aboutsummaryrefslogtreecommitdiff
path: root/flag.go
diff options
context:
space:
mode:
authorGrégoire Duchêne <gduchene@awhk.org>2022-12-03 11:58:01 +0000
committerGrégoire Duchêne <gduchene@awhk.org>2022-12-03 11:58:01 +0000
commit552c8aaa156f39725329c6d00ea715e0312c6fb8 (patch)
tree53cb46ed59c6f129ca5702d256511b3ea9f77787 /flag.go
parent5b88af5109405a51b0c6f8237016707604109ca7 (diff)
Add InitFlagSet
Diffstat (limited to 'flag.go')
-rw-r--r--flag.go49
1 files changed, 49 insertions, 0 deletions
diff --git a/flag.go b/flag.go
index 9200b6f..d48c091 100644
--- a/flag.go
+++ b/flag.go
@@ -55,6 +55,53 @@ func FlagTSliceVar[T any](fs *flag.FlagSet, p *[]T, name string, values []T, usa
fs.Var(&flagValueSlice[T]{Parse: parse, Separator: sep, Values: p}, name, usage)
}
+// InitFlagSet initializes a flag.FlagSet by setting flags in the
+// following order: environment variables, then an arbitrary map, then
+// command line arguments.
+//
+// Note that InitFlagSet does not require the use of any of the Flag
+// functions defined in this package. Standard flags will work just as
+// well.
+func InitFlagSet(fs *flag.FlagSet, env []string, cfg map[string]string, args []string) (err error) {
+ var environ map[string]string
+ if env != nil {
+ environ = make(map[string]string, len(env))
+ for _, kv := range env {
+ if buf := strings.SplitN(kv, "=", 2); len(buf) == 2 {
+ environ[buf[0]] = buf[1]
+ continue
+ }
+ if val, ok := os.LookupEnv(kv); ok {
+ environ[kv] = val
+ }
+ }
+ }
+
+ fs.VisitAll(func(f *flag.Flag) {
+ if err != nil {
+ return
+ }
+
+ var next string
+ if val, found := environ[strings.ToUpper(strings.ReplaceAll(f.Name, "-", "_"))]; found {
+ next = val
+ }
+ if val, found := cfg[f.Name]; found {
+ next = val
+ }
+ if next != "" {
+ err = f.Value.Set(next)
+ }
+ if f, ok := f.Value.(interface{ resetShouldAppend() }); ok {
+ f.resetShouldAppend()
+ }
+ })
+ if err == nil && !fs.Parsed() {
+ return fs.Parse(args)
+ }
+ return err
+}
+
// ParseString returns the string passed with no error set.
func ParseString(s string) (string, error) {
return s, nil
@@ -130,3 +177,5 @@ func (f *flagValueSlice[T]) String() string {
}
return fmt.Sprintf("%v", *f.Values)
}
+
+func (f *flagValueSlice[T]) resetShouldAppend() { f.shouldAppend = false }