diff options
| -rw-r--r-- | flag.go | 24 | ||||
| -rw-r--r-- | flag_test.go | 36 |
2 files changed, 60 insertions, 0 deletions
@@ -145,6 +145,30 @@ func (f *Feature) String() string { // value or an error. type ParseFunc[T any] func(string) (T, error) +// ParseProtobufEnum returns a ParseFunc that will return the +// appropriate enum value or a UnknownEnumValueError if the string +// passed did not match any of the values supplied. +// +// Strings are compared in uppercase only, so ‘FOO,’ ‘foo,’, and ‘fOo’ +// all refer to the same value. +// +// Callers should pass the protoc-generated *_value directly. See +// https://developers.google.com/protocol-buffers/docs/reference/go-generated#enum +// for more details. +func ParseProtobufEnum[T ~int32](values map[string]int32) ParseFunc[T] { + valid := make([]string, 0, len(values)) + for val := range values { + valid = append(valid, val) + } + return func(s string) (T, error) { + val, found := values[strings.ToUpper(s)] + if !found { + return 0, UnknownEnumValueError{s, valid} + } + return T(val), nil + } +} + // ParseStringEnum returns a ParseFunc that will return the string // passed if it matched any of the values supplied. If no such match is // found, an UnknownEnumValueError is returned. diff --git a/flag_test.go b/flag_test.go index 0fb83c7..5b0df0e 100644 --- a/flag_test.go +++ b/flag_test.go @@ -143,6 +143,42 @@ func TestInitFlagSet(s *testing.T) { }) } +func TestParseProtobufEnum(s *testing.T) { + t := &core.T{T: s} + + // That type and map emulate code generated by protoc. + type fakeEnum int32 + values := map[string]int32{ + "FAKE_UNKNOWN": 0, + "FOO": 1, + "BAR": 2, + } + parse := core.ParseProtobufEnum[fakeEnum](values) + + t.Run("Match", func(t *core.T) { + val, err := parse("FOO") + t.AssertErrorIs(nil, err) + t.AssertEqual(fakeEnum(1), val) + }) + + t.Run("MatchCase", func(t *core.T) { + val, err := parse("Foo") + t.AssertErrorIs(nil, err) + t.AssertEqual(fakeEnum(1), val) + }) + + t.Run("UnknownValue", func(t *core.T) { + val, err := parse("BAZ") + var exp core.UnknownEnumValueError + if t.AssertErrorAs(&exp, err) { + t.AssertEqual("BAZ", exp.Actual) + sort.Strings(exp.Expected) + t.AssertEqual([]string{"BAR", "FAKE_UNKNOWN", "FOO"}, exp.Expected) + } + t.AssertEqual(fakeEnum(0), val) + }) +} + func TestParseStringEnum(s *testing.T) { t := &core.T{T: s} parse := core.ParseStringEnum("foo", "bar") |
