diff options
| author | Grégoire Duchêne <gduchene@awhk.org> | 2022-12-10 14:21:12 +0000 |
|---|---|---|
| committer | Grégoire Duchêne <gduchene@awhk.org> | 2022-12-10 14:39:04 +0000 |
| commit | 3cc43119b40d3a556ae818b69bad5d977cc24014 (patch) | |
| tree | 5b9e557d266169e59ffe2b9b29159078cc75ca40 | |
| parent | 2df5e154434bce61c3e4aa9626b69f8ef5b80598 (diff) | |
Add ParseStringEnum
| -rw-r--r-- | flag.go | 34 | ||||
| -rw-r--r-- | flag_test.go | 27 |
2 files changed, 61 insertions, 0 deletions
@@ -7,6 +7,7 @@ import ( "flag" "fmt" "os" + "sort" "strconv" "strings" "sync/atomic" @@ -144,11 +145,44 @@ func (f *Feature) String() string { // value or an error. type ParseFunc[T any] func(string) (T, error) +// ParseStringEnum returns a ParseFunc that will return the string +// passed if it matched any of the values supplied. If no such match is +// found, an UnknownEnumValueError is returned. +// +// Note that unlike ParseProtobufEnum, comparison is case-sensitive. +func ParseStringEnum(values ...string) ParseFunc[string] { + return func(s string) (string, error) { + for _, val := range values { + if s == val { + return s, nil + } + } + return "", UnknownEnumValueError{s, values} + } +} + // ParseTime parses a string according to the time.RFC3339 format. func ParseTime(s string) (time.Time, error) { return time.Parse(time.RFC3339, s) } +// UnknownEnumValueError is returned by the functions produced by +// ParseProtobufEnum and ParseStringEnum when an unknown value is +// encountered. +type UnknownEnumValueError struct { + Actual string + Expected []string +} + +func (err UnknownEnumValueError) Error() string { + // We sort the expected values so the output is deterministic, which may + // be useful when parsing logs or otherwise examining program output. + if !sort.StringsAreSorted(err.Expected) { + sort.Strings(err.Expected) + } + return fmt.Sprintf("unknown value %s, expected one of %s", err.Actual, err.Expected) +} + type flagFeature struct{ *Feature } func (flagFeature) IsBoolFlag() bool { return true } diff --git a/flag_test.go b/flag_test.go index 1cc9ba8..0fb83c7 100644 --- a/flag_test.go +++ b/flag_test.go @@ -5,6 +5,7 @@ package core_test import ( "flag" + "sort" "strconv" "testing" @@ -141,3 +142,29 @@ func TestInitFlagSet(s *testing.T) { t.AssertEqual(42, *fi) }) } + +func TestParseStringEnum(s *testing.T) { + t := &core.T{T: s} + parse := core.ParseStringEnum("foo", "bar") + + t.Run("Match", func(t *core.T) { + val, err := parse("foo") + t.AssertErrorIs(nil, err) + t.AssertEqual("foo", val) + + val, err = parse("bar") + t.AssertErrorIs(nil, err) + t.AssertEqual("bar", val) + }) + + t.Run("UnknownValue", func(t *core.T) { + val, err := parse("baz") + var exp core.UnknownEnumValueError + if t.AssertErrorAs(&exp, err) { + t.AssertEqual("baz", exp.Actual) + sort.Strings(exp.Expected) + t.AssertEqual([]string{"bar", "foo"}, exp.Expected) + } + t.AssertEqual("", val) + }) +} |
