From 30dd5b9f30912dbcc9c993ed0270be184d75df61 Mon Sep 17 00:00:00 2001 From: Grégoire Duchêne Date: Sun, 25 May 2025 09:51:27 +0100 Subject: Introduce a Promise type --- promise.go | 45 +++++++++++++++++++++++++++++ promise_test.go | 89 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 134 insertions(+) create mode 100644 promise.go create mode 100644 promise_test.go diff --git a/promise.go b/promise.go new file mode 100644 index 0000000..29bd703 --- /dev/null +++ b/promise.go @@ -0,0 +1,45 @@ +// SPDX-FileCopyrightText: © 2024 Grégoire Duchêne +// SPDX-License-Identifier: ISC + +package core + +import ( + "errors" + "sync/atomic" +) + +var ErrPromiseFulfilled = errors.New("promise fulfilled already") + +type Promise[T any] struct { + value chan T + error chan error + closed int32 + + _ NoCopy +} + +func NewPromise[T any]() *Promise[T] { + return &Promise[T]{value: make(chan T, 1), error: make(chan error, 1), closed: 0} +} + +func (p *Promise[T]) Err() <-chan error { return p.error } + +func (p *Promise[T]) FailWith(err error) error { + if !atomic.CompareAndSwapInt32(&p.closed, 0, 1) { + return ErrPromiseFulfilled + } + p.error <- err + close(p.error) + return nil +} + +func (p *Promise[T]) SucceedWith(value T) error { + if !atomic.CompareAndSwapInt32(&p.closed, 0, 1) { + return ErrPromiseFulfilled + } + p.value <- value + close(p.value) + return nil +} + +func (p *Promise[T]) Value() <-chan T { return p.value } diff --git a/promise_test.go b/promise_test.go new file mode 100644 index 0000000..bafa184 --- /dev/null +++ b/promise_test.go @@ -0,0 +1,89 @@ +// SPDX-FileCopyrightText: © 2024 Grégoire Duchêne +// SPDX-License-Identifier: ISC + +package core_test + +import ( + "context" + "errors" + "fmt" + "testing" + "time" + + "go.awhk.org/core" +) + +func TestPromise(s *testing.T) { + t := core.T{T: s} + someError := errors.New("some error") + + t.Run("Success", func(t *core.T) { + p := core.NewPromise[int]() + + t.AssertErrorIs(nil, p.SucceedWith(1)) + t.AssertEqual(1, <-p.Value()) + }) + + t.Run("SuccessThenError", func(t *core.T) { + p := core.NewPromise[int]() + + t.AssertErrorIs(nil, p.SucceedWith(1)) + t.AssertErrorIs(core.ErrPromiseFulfilled, p.FailWith(someError)) + t.AssertEqual(1, <-p.Value()) + }) + + t.Run("Error", func(t *core.T) { + p := core.NewPromise[int]() + + t.AssertErrorIs(nil, p.FailWith(someError)) + t.AssertErrorIs(someError, <-p.Err()) + }) + + t.Run("ErrorThenSuccess", func(t *core.T) { + p := core.NewPromise[int]() + + t.AssertErrorIs(nil, p.FailWith(someError)) + t.AssertErrorIs(core.ErrPromiseFulfilled, p.SucceedWith(1)) + t.AssertErrorIs(someError, <-p.Err()) + }) +} + +func ExamplePromise() { + p := core.NewPromise[string]() + + go func() { + time.Sleep(time.Millisecond) + p.SucceedWith("Hello World!") + }() + + select { + case s := <-p.Value(): + fmt.Printf("Received %q.\n", s) + case err := <-p.Err(): + fmt.Printf("Received an error: %s.\n", err) + } + // Output: Received "Hello World!". +} + +func ExamplePromise_withContext() { + var ( + ctx, cancel = context.WithTimeout(context.Background(), time.Second) + p = core.NewPromise[string]() + ) + defer cancel() + + go func() { + time.Sleep(time.Millisecond) + p.FailWith(errors.New("some error")) + }() + + select { + case s := <-p.Value(): + fmt.Printf("Received %q.\n", s) + case err := <-p.Err(): + fmt.Printf("Received an error: %s.\n", err) + case <-ctx.Done(): + fmt.Printf("Context was cancelled: %s.\n", ctx.Err()) + } + // Output: Received an error: some error. +} -- cgit v1.2.3-70-g09d2