package pipespy_test

import (
	"errors"
	"fmt"
	"io"
	"sync"
	"testing"

	"github.com/chathaway-codes/home-sensors/v2/internal/pipespy"
	"github.com/google/go-cmp/cmp"
)

// TestPipeSpy runs a three-stage pipeline of echoing simpleWriters and checks
// that the snoops attached to the first two stages each observe the message
// fed into the first stage, and that the pipeline reports no errors.
func TestPipeSpy(t *testing.T) {
	pipe := pipespy.New()
	readPipe, writerPipe := io.Pipe()

	// Only the first stage is given an explicit stdin; it is fed from
	// writerPipe below.
	stage1 := &simpleWriter{t: t, stdIn: readPipe}
	stage2 := &simpleWriter{t: t}
	stdOut1 := pipe.Add(stage1).Snoop()
	stdOut2 := pipe.Add(stage2).Snoop()
	pipe.Add(&simpleWriter{t: t})

	// Keep reading both snoops in the background while the pipeline runs.
	// io.ReadAll reports a clean end of input as a nil error, never io.EOF.
	var buff1, buff2 []byte
	var err1, err2 error
	var wg sync.WaitGroup
	wg.Add(2)
	go func() {
		defer wg.Done()
		buff1, err1 = io.ReadAll(stdOut1)
	}()
	go func() {
		defer wg.Done()
		buff2, err2 = io.ReadAll(stdOut2)
	}()

	wait := pipe.Start()

	msg := "This is a message!"
	if _, err := fmt.Fprint(writerPipe, msg); err != nil {
		t.Fatalf("failed to write message: %v", err)
	}
	if err := writerPipe.Close(); err != nil {
		t.Fatalf("failed to close pipe: %v", err)
	}

	// No stage is configured to fail, so the pipeline should report nothing.
	if errs := wait(); len(errs) != 0 {
		t.Errorf("pipeline returned errors %v; want none", errs)
	}
	wg.Wait()

	if err1 != nil || err2 != nil {
		t.Errorf("Got errs %v and %v; want nil", err1, err2)
	}
	if diff := cmp.Diff(msg, string(buff1)); diff != "" {
		t.Errorf("snoop 1 mismatch (-want +got):\n%s", diff)
	}
	if diff := cmp.Diff(msg, string(buff2)); diff != "" {
		t.Errorf("snoop 2 mismatch (-want +got):\n%s", diff)
	}
}

// TestPipeSpyError checks that when exactly one stage's Run returns an
// error, the pipeline's wait function reports exactly one error.
func TestPipeSpyError(t *testing.T) {
	readPipe, writerPipe := io.Pipe()
	pipe := pipespy.New()
	pipe.Add(&simpleWriter{t: t, stdIn: readPipe})
	pipe.Add(&simpleWriter{t: t, err: errors.New("this is an error")})
	pipe.Add(&simpleWriter{t: t})

	wait := pipe.Start()
	if _, err := fmt.Fprintf(writerPipe, "Hello"); err != nil {
		t.Fatalf("error: %v", err)
	}
	if err := writerPipe.Close(); err != nil {
		t.Fatalf("failed to close pipe: %v", err)
	}

	if errs := wait(); len(errs) != 1 {
		t.Errorf("Expected 1 err; got %d: %v", len(errs), errs)
	}
}

// simpleWriter is a test pipeline stage. When err is set, Run fails
// immediately with it; otherwise Run reads all of stdIn, records it in
// gotData, and writes it unchanged to stdOut.
type simpleWriter struct {
	stdIn   io.Reader
	stdOut  io.Writer
	gotData []byte
	err     error
	t       *testing.T
}

// SetStdin sets the reader that Run consumes.
func (s *simpleWriter) SetStdin(r io.Reader) { s.stdIn = r }

// SetStdout sets the writer that Run echoes its input to.
func (s *simpleWriter) SetStdout(w io.Writer) { s.stdOut = w }

// Run returns s.err if it is set. Otherwise it reads stdIn to the end,
// stores the data in gotData, writes it to stdOut, and returns nil. Read
// failures are reported through s.t rather than returned.
func (s *simpleWriter) Run() error {
	if s.err != nil {
		return s.err
	}
	data, err := io.ReadAll(s.stdIn)
	if err != nil {
		s.t.Errorf("Unexpected err: %v", err)
	}
	s.gotData = data
	// The write result is intentionally neither checked nor returned:
	// returning it could add errors beyond the one stage that is configured
	// to fail, which TestPipeSpyError counts exactly.
	// NOTE(review): confirm how pipespy treats this writer when a later stage fails.
	_, _ = s.stdOut.Write(data)
	return nil
}