aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGrégoire Duchêne <gduchene@awhk.org>2025-05-25 09:51:27 +0100
committerGrégoire Duchêne <gduchene@awhk.org>2025-05-25 09:51:27 +0100
commit30dd5b9f30912dbcc9c993ed0270be184d75df61 (patch)
treee02fb682c3c4b366eb2fed2091243d0938568908
parentbca1b0ccdece7b992cc88712aca36a9a127b1381 (diff)
Introduce a Promise typeHEADmain
-rw-r--r--promise.go45
-rw-r--r--promise_test.go89
2 files changed, 134 insertions, 0 deletions
diff --git a/promise.go b/promise.go
new file mode 100644
index 0000000..29bd703
--- /dev/null
+++ b/promise.go
@@ -0,0 +1,45 @@
+// SPDX-FileCopyrightText: © 2024 Grégoire Duchêne <gduchene@awhk.org>
+// SPDX-License-Identifier: ISC
+
+package core
+
+import (
+ "errors"
+ "sync/atomic"
+)
+
+var ErrPromiseFulfilled = errors.New("promise fulfilled already")
+
+type Promise[T any] struct {
+ value chan T
+ error chan error
+ closed int32
+
+ _ NoCopy
+}
+
+func NewPromise[T any]() *Promise[T] {
+ return &Promise[T]{value: make(chan T, 1), error: make(chan error, 1), closed: 0}
+}
+
+func (p *Promise[T]) Err() <-chan error { return p.error }
+
+func (p *Promise[T]) FailWith(err error) error {
+ if !atomic.CompareAndSwapInt32(&p.closed, 0, 1) {
+ return ErrPromiseFulfilled
+ }
+ p.error <- err
+ close(p.error)
+ return nil
+}
+
+func (p *Promise[T]) SucceedWith(value T) error {
+ if !atomic.CompareAndSwapInt32(&p.closed, 0, 1) {
+ return ErrPromiseFulfilled
+ }
+ p.value <- value
+ close(p.value)
+ return nil
+}
+
+func (p *Promise[T]) Value() <-chan T { return p.value }
diff --git a/promise_test.go b/promise_test.go
new file mode 100644
index 0000000..bafa184
--- /dev/null
+++ b/promise_test.go
@@ -0,0 +1,89 @@
+// SPDX-FileCopyrightText: © 2024 Grégoire Duchêne <gduchene@awhk.org>
+// SPDX-License-Identifier: ISC
+
+package core_test
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "testing"
+ "time"
+
+ "go.awhk.org/core"
+)
+
+func TestPromise(s *testing.T) {
+ t := core.T{T: s}
+ someError := errors.New("some error")
+
+ t.Run("Success", func(t *core.T) {
+ p := core.NewPromise[int]()
+
+ t.AssertErrorIs(nil, p.SucceedWith(1))
+ t.AssertEqual(1, <-p.Value())
+ })
+
+ t.Run("SuccessThenError", func(t *core.T) {
+ p := core.NewPromise[int]()
+
+ t.AssertErrorIs(nil, p.SucceedWith(1))
+ t.AssertErrorIs(core.ErrPromiseFulfilled, p.FailWith(someError))
+ t.AssertEqual(1, <-p.Value())
+ })
+
+ t.Run("Error", func(t *core.T) {
+ p := core.NewPromise[int]()
+
+ t.AssertErrorIs(nil, p.FailWith(someError))
+ t.AssertErrorIs(someError, <-p.Err())
+ })
+
+ t.Run("ErrorThenSuccess", func(t *core.T) {
+ p := core.NewPromise[int]()
+
+ t.AssertErrorIs(nil, p.FailWith(someError))
+ t.AssertErrorIs(core.ErrPromiseFulfilled, p.SucceedWith(1))
+ t.AssertErrorIs(someError, <-p.Err())
+ })
+}
+
+func ExamplePromise() {
+ p := core.NewPromise[string]()
+
+ go func() {
+ time.Sleep(time.Millisecond)
+ p.SucceedWith("Hello World!")
+ }()
+
+ select {
+ case s := <-p.Value():
+ fmt.Printf("Received %q.\n", s)
+ case err := <-p.Err():
+ fmt.Printf("Received an error: %s.\n", err)
+ }
+ // Output: Received "Hello World!".
+}
+
+func ExamplePromise_withContext() {
+ var (
+ ctx, cancel = context.WithTimeout(context.Background(), time.Second)
+ p = core.NewPromise[string]()
+ )
+ defer cancel()
+
+ go func() {
+ time.Sleep(time.Millisecond)
+ p.FailWith(errors.New("some error"))
+ }()
+
+ select {
+ case s := <-p.Value():
+ fmt.Printf("Received %q.\n", s)
+ case err := <-p.Err():
+ fmt.Printf("Received an error: %s.\n", err)
+ case <-ctx.Done():
+ fmt.Printf("Context was cancelled: %s.\n", ctx.Err())
+ }
+ // Output: Received an error: some error.
+}