diff options
| -rw-r--r-- | go.mod | 2 | ||||
| -rw-r--r-- | go.sum | 2 | ||||
| -rw-r--r-- | net.go | 88 | ||||
| -rw-r--r-- | net_test.go | 50 | ||||
| -rw-r--r-- | testing.go | 111 | ||||
| -rw-r--r-- | util.go | 9 | ||||
| -rw-r--r-- | util_test.go | 20 |
7 files changed, 282 insertions, 0 deletions
@@ -1,3 +1,5 @@ module go.awhk.org/core go 1.18 + +require github.com/google/go-cmp v0.5.8 @@ -0,0 +1,2 @@ +github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= +github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= @@ -0,0 +1,88 @@ +package core + +import ( + "context" + "net" + "strings" + "sync" + "syscall" +) + +// Listen is a wrapper around net.Listen. If addr cannot be split in two +// parts around the first colon found, Listen will try to create a UNIX +// or TCP net.Listener depending on whether addr contains a slash. +func Listen(addr string) (net.Listener, error) { + if fields := strings.SplitN(addr, ":", 2); len(fields) == 2 { + return net.Listen(fields[0], fields[1]) + } + if strings.ContainsRune(addr, '/') { + return net.Listen("unix", addr) + } + return net.Listen("tcp", addr) +} + +// PipeListener is a net.Listener that works over a pipe. It provides +// dialer functions that can be used in an HTTP client or gRPC options. +// +// Its zero value is safe to use. PipeListener must not be copied after +// its first use. +type PipeListener struct { + conns chan net.Conn + done chan struct{} + + closeOnce sync.Once + initOnce sync.Once +} + +var _ net.Listener = &PipeListener{} + +func (p *PipeListener) Accept() (net.Conn, error) { + p.initOnce.Do(p.init) + + select { + case conn := <-p.conns: + return conn, nil + case <-p.done: + return nil, syscall.EINVAL + } +} + +func (p *PipeListener) Addr() net.Addr { return pipeListenerAddr{} } + +func (p *PipeListener) Close() error { + p.initOnce.Do(p.init) + p.closeOnce.Do(func() { close(p.done) }) + return nil +} + +func (p *PipeListener) Dial(_, _ string) (net.Conn, error) { + return p.DialContext(context.Background(), "", "") +} + +func (p *PipeListener) DialContext(ctx context.Context, _, _ string) (net.Conn, error) { + p.initOnce.Do(p.init) + + s, c := net.Pipe() + select { + case p.conns <- s: + return c, nil + case <-p.done: + return nil, syscall.ECONNREFUSED + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +func (p *PipeListener) DialContextGRPC(ctx context.Context, _ string) (net.Conn, error) { + return p.DialContext(ctx, "", "") +} + +func (p *PipeListener) init() { + p.conns = make(chan net.Conn) + p.done = make(chan struct{}) +} + +type pipeListenerAddr struct{} + +func (pipeListenerAddr) Network() string { return "pipe" } +func (pipeListenerAddr) String() string { return "pipe" } diff --git a/net_test.go b/net_test.go new file mode 100644 index 0000000..505579e --- /dev/null +++ b/net_test.go @@ -0,0 +1,50 @@ +package core_test + +import ( + "context" + "syscall" + "testing" + + "go.awhk.org/core" +) + +func TestPipeListener(s *testing.T) { + t := core.T{T: s} + + t.Run("Success", func(t *core.T) { + p := &core.PipeListener{} + + t.Go(func() { + conn, err := p.Accept() + t.AssertErrorIs(nil, err) + t.AssertNotEqual(nil, conn) + }) + + conn, err := p.Dial("", "") + t.AssertErrorIs(nil, err) + t.AssertNotEqual(nil, conn) + }) + + t.Run("WhenClosed", func(t *core.T) { + p := &core.PipeListener{} + p.Close() + + conn, err := p.Accept() + t.AssertErrorIs(syscall.EINVAL, err) + t.AssertEqual(nil, conn) + + conn, err = p.Dial("", "") + t.AssertErrorIs(syscall.ECONNREFUSED, err) + t.AssertEqual(nil, conn) + }) + + t.Run("WhenContextCanceled", func(t *core.T) { + p := &core.PipeListener{} + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + conn, err := p.DialContext(ctx, "", "") + t.AssertErrorIs(context.Canceled, err) + t.AssertEqual(nil, conn) + }) +} diff --git a/testing.go b/testing.go new file mode 100644 index 0000000..e18b4e9 --- /dev/null +++ b/testing.go @@ -0,0 +1,111 @@ +package core + +import ( + "errors" + "sync" + "testing" + + "github.com/google/go-cmp/cmp" +) + +type T struct { + *testing.T + Options []cmp.Option + + wg sync.WaitGroup +} + +func (t *T) AssertEqual(exp, actual any) bool { + t.Helper() + + diff := cmp.Diff(exp, actual, t.Options...) + if diff == "" { + return true + } + t.Errorf("\nexpected %#v, got %#v\n%s", exp, actual, diff) + return false +} + +func (t *T) AssertErrorIs(err, target error) bool { + t.Helper() + + if errors.Is(err, target) { + return true + } + t.Errorf("\nexpected error to be %#v, got %#v", err, target) + return false +} + +func (t *T) AssertPanics(f func()) bool { + t.Helper() + return t.AssertPanicsWith(f, nil) +} + +func (t *T) AssertPanicsWith(f func(), exp any) (b bool) { + t.Helper() + + defer func() { + t.Helper() + + actual := recover() + switch { + case actual == nil: + t.Errorf("\nexpected panic") + b = false + case exp == nil: + default: + b = t.AssertEqual(exp, actual) + } + }() + + f() + return true +} + +func (t *T) AssertNotEqual(notExp, actual any) bool { + t.Helper() + + if !cmp.Equal(notExp, actual, t.Options...) { + return true + } + t.Errorf("\nunexpected %#v", actual) + return false +} + +func (t *T) AssertNotPanics(f func()) (b bool) { + t.Helper() + + defer func() { + if actual := recover(); actual != nil { + t.Errorf("\nunexpected panic with %#v", actual) + b = false + } + }() + + f() + return true +} + +func (t *T) Go(f func()) { + t.wg.Add(1) + go func() { + defer t.wg.Done() + f() + }() +} + +func (t *T) Must(b bool) { + if !b { + t.FailNow() + } +} + +func (t *T) Run(name string, f func(t *T)) { + t.T.Run(name, func(s *testing.T) { + t := &T{T: s, Options: t.Options} + f(t) + t.wg.Wait() + }) +} + +func (t *T) Wait() { t.wg.Wait() } @@ -0,0 +1,9 @@ +package core + +// Must panics if err is not nil. It returns val otherwise. +func Must[T any](val T, err error) T { + if err != nil { + panic(err) + } + return val +} diff --git a/util_test.go b/util_test.go new file mode 100644 index 0000000..65cc3bc --- /dev/null +++ b/util_test.go @@ -0,0 +1,20 @@ +package core_test + +import ( + "errors" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + + "go.awhk.org/core" +) + +func TestMust(s *testing.T) { + t := core.T{T: s, Options: []cmp.Option{cmpopts.EquateErrors()}} + + err := errors.New("some error") + t.AssertPanicsWith(func() { core.Must(42, err) }, err) + t.AssertNotPanics(func() { core.Must(42, nil) }) + t.AssertEqual(42, core.Must(42, nil)) +} |
