diff options
| -rw-r--r-- | flag.go | 14 | ||||
| -rw-r--r-- | flag_test.go | 32 |
2 files changed, 46 insertions, 0 deletions
@@ -184,6 +184,20 @@ func ParseStringEnum(values ...string) ParseFunc[string] { } } +// ParseStringerEnum returns a ParseFunc that will return the first +// value having a string value matching the string passed. +func ParseStringerEnum[T fmt.Stringer](values ...T) ParseFunc[T] { + return func(s string) (T, error) { + for _, val := range values { + if s == val.String() { + return val, nil + } + } + var zero T + return zero, UnknownEnumValueError[T]{s, values} + } +} + // ParseTime parses a string according to the time.RFC3339 format. func ParseTime(s string) (time.Time, error) { return time.Parse(time.RFC3339, s) diff --git a/flag_test.go b/flag_test.go index 3273775..411174a 100644 --- a/flag_test.go +++ b/flag_test.go @@ -203,3 +203,35 @@ func TestParseStringEnum(s *testing.T) { t.AssertEqual("", val) }) } + +func TestParseStringerEnum(s *testing.T) { + t := &core.T{T: s, Options: cmp.Options{fakeEnumComparer}} + parser := core.ParseStringerEnum(fakeEnumFoo, fakeEnumBar) + + t.Run("Match", func(t *core.T) { + val, err := parser("FOO") + t.AssertErrorIs(nil, err) + t.AssertEqual(fakeEnumFoo, val) + }) + + t.Run("UnknownValue", func(t *core.T) { + val, err := parser("baz") + var exp core.UnknownEnumValueError[fakeEnum] + if t.AssertErrorAs(&exp, err) { + t.AssertEqual("baz", exp.Actual) + t.AssertEqual([]fakeEnum{fakeEnumFoo, fakeEnumBar}, exp.Expected) + } + t.AssertEqual(fakeEnum{}, val) + }) +} + +type fakeEnum struct{ string } + +var ( + fakeEnumFoo = fakeEnum{"FOO"} + fakeEnumBar = fakeEnum{"BAR"} + + fakeEnumComparer = cmp.Comparer(func(x, y fakeEnum) bool { return x == y }) +) + +func (e fakeEnum) String() string { return e.string } |
