diff options
| -rw-r--r-- | go.mod | 5 | ||||
| -rw-r--r-- | go.sum | 2 | ||||
| -rw-r--r-- | pipeln.go | 19 | ||||
| -rw-r--r-- | pipeln_test.go | 25 |
4 files changed, 38 insertions, 13 deletions
@@ -2,4 +2,7 @@ module go.awhk.org/pipeln go 1.15 -require github.com/stretchr/testify v1.7.0 +require ( + github.com/stretchr/testify v1.7.0 + golang.org/x/sys v0.0.0-20210316092937-0b90fd5c4c48 +) @@ -5,6 +5,8 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +golang.org/x/sys v0.0.0-20210316092937-0b90fd5c4c48 h1:70qalHWW1n9yoI8B8zEQxFJO/D6NUWIX8SNmJO+rvNw= +golang.org/x/sys v0.0.0-20210316092937-0b90fd5c4c48/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= @@ -4,13 +4,9 @@ package pipeln import ( "context" - "errors" "net" -) -var ( - ErrBadAddress = errors.New("bad address") - ErrClosed = errors.New("closed listener") + "golang.org/x/sys/unix" ) type addr struct { @@ -34,6 +30,7 @@ type PipeListenerDialer struct { addr string conns chan net.Conn done chan struct{} + ok bool } var _ net.Listener = &PipeListenerDialer{} @@ -43,7 +40,7 @@ func (ln *PipeListenerDialer) Accept() (net.Conn, error) { case conn := <-ln.conns: return conn, nil case <-ln.done: - return nil, ErrClosed + return nil, unix.EINVAL } } @@ -52,20 +49,24 @@ func (ln *PipeListenerDialer) Addr() net.Addr { } func (ln *PipeListenerDialer) Close() error { + if !ln.ok { + return unix.EINVAL + } close(ln.done) + ln.ok = false return nil } func (ln *PipeListenerDialer) Dial(_, addr string) (net.Conn, error) { if addr != ln.addr { - return nil, ErrBadAddress + return nil, unix.EINVAL } s, c := net.Pipe() select { case ln.conns <- s: return c, nil case <-ln.done: - return nil, ErrClosed + return nil, unix.ECONNREFUSED } } @@ -78,5 +79,5 @@ func (ln *PipeListenerDialer) DialContextAddr(_ context.Context, addr string) (n } func New(addr string) *PipeListenerDialer { - return &PipeListenerDialer{addr, make(chan net.Conn), make(chan struct{})} + return &PipeListenerDialer{addr, make(chan net.Conn), make(chan struct{}), true} } diff --git a/pipeln_test.go b/pipeln_test.go index e5de8d0..9673fcc 100644 --- a/pipeln_test.go +++ b/pipeln_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/sys/unix" ) func Test(t *testing.T) { @@ -22,9 +23,27 @@ func Test(t *testing.T) { go srv.Serve(ln) client := http.Client{Transport: &http.Transport{Dial: ln.Dial}} - resp, err := client.Get("http://test/endpoint") - require.NoError(t, err) - assert.Equal(t, http.StatusOK, resp.StatusCode) + + t.Run("OK", func(t *testing.T) { + resp, err := client.Get("http://test/endpoint") + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + }) + + t.Run("Address Mismatch", func(t *testing.T) { + _, err := client.Get("http://other-test/endpoint") + assert.ErrorIs(t, err, unix.EINVAL) + }) srv.Shutdown(context.Background()) + + t.Run("Remote Connection Closed", func(t *testing.T) { + _, err := client.Get("http://test/endpoint") + assert.ErrorIs(t, err, unix.ECONNREFUSED) + }) + + t.Run("Already-closed Listener", func(t *testing.T) { + srv = http.Server{Handler: mux} + assert.ErrorIs(t, srv.Serve(ln), unix.EINVAL) + }) } |
