From a99d7b7e2c57380fbcf85c571194df26d86f9d76 Mon Sep 17 00:00:00 2001 From: Grégoire Duchêne Date: Sun, 26 Jun 2022 00:31:10 +0100 Subject: Add a few flag helper functions --- flag.go | 84 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ flag_test.go | 50 ++++++++++++++++++++++++++++++++++++ 2 files changed, 134 insertions(+) create mode 100644 flag.go create mode 100644 flag_test.go diff --git a/flag.go b/flag.go new file mode 100644 index 0000000..bed32eb --- /dev/null +++ b/flag.go @@ -0,0 +1,84 @@ +// SPDX-FileCopyrightText: © 2022 Grégoire Duchêne +// SPDX-License-Identifier: ISC + +package core + +import ( + "flag" + "fmt" + "time" +) + +func FlagVar[T any](fs *flag.FlagSet, name, usage string, parse ParseFunc[T]) *T { + v := &flagValue[T]{Parse: parse, Value: new(T)} + fs.Var(v, name, usage) + return v.Value +} + +func FlagVarPtr[T any](fs *flag.FlagSet, name, usage string, parse ParseFunc[T], val *T) { + fs.Var(&flagValue[T]{Parse: parse, Value: val}, name, usage) +} + +func FlagVarSlice[T any](fs *flag.FlagSet, name, usage string, parse ParseFunc[T]) *[]T { + v := &flagValueSlice[T]{Parse: parse, Values: new([]T)} + fs.Var(v, name, usage) + return v.Values +} + +func FlagVarSlicePtr[T any](fs *flag.FlagSet, name, usage string, parse ParseFunc[T], vals *[]T) { + fs.Var(&flagValueSlice[T]{Parse: parse, Values: vals}, name, usage) +} + +// ParseString returns the string passed with no error set. +func ParseString(s string) (string, error) { + return s, nil +} + +// ParseTime parses a string according to the time.RFC3339 format. +func ParseTime(s string) (time.Time, error) { + return time.Parse(time.RFC3339, s) +} + +// ParseFunc describes functions that will parse a string and return a +// value or an error. +type ParseFunc[T any] func(string) (T, error) + +type flagValue[T any] struct { + Parse ParseFunc[T] + Value *T +} + +var _ flag.Value = &flagValue[any]{} + +func (f *flagValue[T]) Set(s string) error { + val, err := f.Parse(s) + if err != nil { + return err + } + *f.Value = val + return nil +} + +func (f *flagValue[T]) String() string { + return fmt.Sprintf("%v", f.Value) +} + +type flagValueSlice[T any] struct { + Parse ParseFunc[T] + Values *[]T +} + +var _ flag.Value = &flagValueSlice[any]{} + +func (f *flagValueSlice[T]) Set(s string) error { + val, err := f.Parse(s) + if err != nil { + return err + } + *f.Values = append(*f.Values, val) + return nil +} + +func (f *flagValueSlice[T]) String() string { + return fmt.Sprintf("%v", f.Values) +} diff --git a/flag_test.go b/flag_test.go new file mode 100644 index 0000000..56711d5 --- /dev/null +++ b/flag_test.go @@ -0,0 +1,50 @@ +// SPDX-FileCopyrightText: © 2022 Grégoire Duchêne +// SPDX-License-Identifier: ISC + +package core_test + +import ( + "flag" + "strconv" + "testing" + + "go.awhk.org/core" +) + +func TestFlagVar(s *testing.T) { + t := core.T{T: s} + + fs := flag.NewFlagSet("", flag.ContinueOnError) + fl := core.FlagVar(fs, "test", "", strconv.ParseBool) + t.AssertErrorIs(nil, fs.Parse([]string{"-test=true"})) + t.AssertEqual(true, *fl) +} + +func TestFlagVarPtr(s *testing.T) { + t := core.T{T: s} + + fs := flag.NewFlagSet("", flag.ContinueOnError) + fl := false + core.FlagVarPtr(fs, "test", "", strconv.ParseBool, &fl) + t.AssertErrorIs(nil, fs.Parse([]string{"-test=true"})) + t.AssertEqual(true, fl) +} + +func TestFlagVarSlice(s *testing.T) { + t := core.T{T: s} + + fs := flag.NewFlagSet("", flag.ContinueOnError) + fl := core.FlagVarSlice(fs, "test", "", strconv.Atoi) + t.AssertErrorIs(nil, fs.Parse([]string{"-test=1", "-test=2", "-test=42"})) + t.AssertEqual([]int{1, 2, 42}, *fl) +} + +func TestFlagVarSlicePtr(s *testing.T) { + t := core.T{T: s} + + fs := flag.NewFlagSet("", flag.ContinueOnError) + fl := []int{} + core.FlagVarSlicePtr(fs, "test", "", strconv.Atoi, &fl) + t.AssertErrorIs(nil, fs.Parse([]string{"-test=1", "-test=2", "-test=42"})) + t.AssertEqual([]int{1, 2, 42}, fl) +} -- cgit v1.2.3-70-g09d2