aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGrégoire Duchêne <gduchene@awhk.org>2022-12-10 14:21:12 +0000
committerGrégoire Duchêne <gduchene@awhk.org>2022-12-10 14:39:04 +0000
commit3cc43119b40d3a556ae818b69bad5d977cc24014 (patch)
tree5b9e557d266169e59ffe2b9b29159078cc75ca40
parent2df5e154434bce61c3e4aa9626b69f8ef5b80598 (diff)
Add ParseStringEnum
-rw-r--r--flag.go34
-rw-r--r--flag_test.go27
2 files changed, 61 insertions, 0 deletions
diff --git a/flag.go b/flag.go
index 51c78aa..01e5f54 100644
--- a/flag.go
+++ b/flag.go
@@ -7,6 +7,7 @@ import (
"flag"
"fmt"
"os"
+ "sort"
"strconv"
"strings"
"sync/atomic"
@@ -144,11 +145,44 @@ func (f *Feature) String() string {
// value or an error.
type ParseFunc[T any] func(string) (T, error)
+// ParseStringEnum returns a ParseFunc that will return the string
+// passed if it matched any of the values supplied. If no such match is
+// found, an UnknownEnumValueError is returned.
+//
+// Note that unlike ParseProtobufEnum, comparison is case-sensitive.
+func ParseStringEnum(values ...string) ParseFunc[string] {
+ return func(s string) (string, error) {
+ for _, val := range values {
+ if s == val {
+ return s, nil
+ }
+ }
+ return "", UnknownEnumValueError{s, values}
+ }
+}
+
// ParseTime parses a string according to the time.RFC3339 format.
func ParseTime(s string) (time.Time, error) {
return time.Parse(time.RFC3339, s)
}
+// UnknownEnumValueError is returned by the functions produced by
+// ParseProtobufEnum and ParseStringEnum when an unknown value is
+// encountered.
+type UnknownEnumValueError struct {
+ Actual string
+ Expected []string
+}
+
+func (err UnknownEnumValueError) Error() string {
+ // We sort the expected values so the output is deterministic, which may
+ // be useful when parsing logs or otherwise examining program output.
+ if !sort.StringsAreSorted(err.Expected) {
+ sort.Strings(err.Expected)
+ }
+ return fmt.Sprintf("unknown value %s, expected one of %s", err.Actual, err.Expected)
+}
+
type flagFeature struct{ *Feature }
func (flagFeature) IsBoolFlag() bool { return true }
diff --git a/flag_test.go b/flag_test.go
index 1cc9ba8..0fb83c7 100644
--- a/flag_test.go
+++ b/flag_test.go
@@ -5,6 +5,7 @@ package core_test
import (
"flag"
+ "sort"
"strconv"
"testing"
@@ -141,3 +142,29 @@ func TestInitFlagSet(s *testing.T) {
t.AssertEqual(42, *fi)
})
}
+
+func TestParseStringEnum(s *testing.T) {
+ t := &core.T{T: s}
+ parse := core.ParseStringEnum("foo", "bar")
+
+ t.Run("Match", func(t *core.T) {
+ val, err := parse("foo")
+ t.AssertErrorIs(nil, err)
+ t.AssertEqual("foo", val)
+
+ val, err = parse("bar")
+ t.AssertErrorIs(nil, err)
+ t.AssertEqual("bar", val)
+ })
+
+ t.Run("UnknownValue", func(t *core.T) {
+ val, err := parse("baz")
+ var exp core.UnknownEnumValueError
+ if t.AssertErrorAs(&exp, err) {
+ t.AssertEqual("baz", exp.Actual)
+ sort.Strings(exp.Expected)
+ t.AssertEqual([]string{"bar", "foo"}, exp.Expected)
+ }
+ t.AssertEqual("", val)
+ })
+}