diff --git a/flag_bool_with_inverse.go b/flag_bool_with_inverse.go index cc401bb9fc..ecc090b478 100644 --- a/flag_bool_with_inverse.go +++ b/flag_bool_with_inverse.go @@ -106,7 +106,7 @@ func (parent *BoolWithInverseFlag) initialize() { sources := []ValueSource{} for _, envVar := range child.GetEnvVars() { - sources = append(sources, &envVarValueSource{Key: strings.ToUpper(parent.InversePrefix) + envVar}) + sources = append(sources, EnvVar(strings.ToUpper(parent.InversePrefix)+envVar)) } parent.negativeFlag.Sources = NewValueSourceChain(sources...) } diff --git a/godoc-current.txt b/godoc-current.txt index 079b09b459..3550ebe1f4 100644 --- a/godoc-current.txt +++ b/godoc-current.txt @@ -591,6 +591,12 @@ type DocGenerationMultiValueFlag interface { type DurationFlag = FlagBase[time.Duration, NoConfig, durationValue] +type EnvValueSource interface { + IsFromEnv() bool + Key() string +} + EnvValueSource is to specifically detect env sources when printing help text + type ErrorFormatter interface { Format(s fmt.State, verb rune) } @@ -848,6 +854,20 @@ func (i MapBase[T, C, VC]) ToString(t map[string]T) string func (i *MapBase[T, C, VC]) Value() map[string]T Value returns the mapping of values set by this flag +type MapSource interface { + fmt.Stringer + fmt.GoStringer + + // Lookup returns the value from the source based on key + // and if it was found + // or returns an empty string and false + Lookup(string) (any, bool) +} + MapSource is a source which can be used to look up a value based on a key + typically for use with a cli.Flag + +func NewMapSource(name string, m map[any]any) MapSource + type MultiError interface { error Errors() []error @@ -1000,6 +1020,8 @@ func EnvVar(key string) ValueSource func File(path string) ValueSource +func NewMapValueSource(key string, ms MapSource) ValueSource + type ValueSourceChain struct { Chain []ValueSource } diff --git a/testdata/godoc-v3.x.txt b/testdata/godoc-v3.x.txt index 079b09b459..3550ebe1f4 100644 --- a/testdata/godoc-v3.x.txt +++ b/testdata/godoc-v3.x.txt @@ -591,6 +591,12 @@ type DocGenerationMultiValueFlag interface { type DurationFlag = FlagBase[time.Duration, NoConfig, durationValue] +type EnvValueSource interface { + IsFromEnv() bool + Key() string +} + EnvValueSource is to specifically detect env sources when printing help text + type ErrorFormatter interface { Format(s fmt.State, verb rune) } @@ -848,6 +854,20 @@ func (i MapBase[T, C, VC]) ToString(t map[string]T) string func (i *MapBase[T, C, VC]) Value() map[string]T Value returns the mapping of values set by this flag +type MapSource interface { + fmt.Stringer + fmt.GoStringer + + // Lookup returns the value from the source based on key + // and if it was found + // or returns an empty string and false + Lookup(string) (any, bool) +} + MapSource is a source which can be used to look up a value based on a key + typically for use with a cli.Flag + +func NewMapSource(name string, m map[any]any) MapSource + type MultiError interface { error Errors() []error @@ -1000,6 +1020,8 @@ func EnvVar(key string) ValueSource func File(path string) ValueSource +func NewMapValueSource(key string, ms MapSource) ValueSource + type ValueSourceChain struct { Chain []ValueSource } diff --git a/value_source.go b/value_source.go index 7d3f7ee45c..edc75d2e4f 100644 --- a/value_source.go +++ b/value_source.go @@ -17,6 +17,26 @@ type ValueSource interface { Lookup() (string, bool) } +// EnvValueSource is to specifically detect env sources when +// printing help text +type EnvValueSource interface { + IsFromEnv() bool + Key() string +} + +// MapSource is a source which can be used to look up a value +// based on a key +// typically for use with a cli.Flag +type MapSource interface { + fmt.Stringer + fmt.GoStringer + + // Lookup returns the value from the source based on key + // and if it was found + // or returns an empty string and false + Lookup(string) (any, bool) +} + // ValueSourceChain contains an ordered series of ValueSource that // allows for lookup where the first ValueSource to resolve is // returned @@ -38,8 +58,8 @@ func (vsc *ValueSourceChain) EnvKeys() []string { vals := []string{} for _, src := range vsc.Chain { - if v, ok := src.(*envVarValueSource); ok { - vals = append(vals, v.Key) + if v, ok := src.(EnvValueSource); ok && v.IsFromEnv() { + vals = append(vals, v.Key()) } } @@ -83,21 +103,29 @@ func (vsc *ValueSourceChain) LookupWithSource() (string, ValueSource, bool) { // envVarValueSource encapsulates a ValueSource from an environment variable type envVarValueSource struct { - Key string + key string } func (e *envVarValueSource) Lookup() (string, bool) { - return os.LookupEnv(strings.TrimSpace(string(e.Key))) + return os.LookupEnv(strings.TrimSpace(string(e.key))) } -func (e *envVarValueSource) String() string { return fmt.Sprintf("environment variable %[1]q", e.Key) } +func (e *envVarValueSource) IsFromEnv() bool { + return true +} + +func (e *envVarValueSource) Key() string { + return e.key +} + +func (e *envVarValueSource) String() string { return fmt.Sprintf("environment variable %[1]q", e.key) } func (e *envVarValueSource) GoString() string { - return fmt.Sprintf("&envVarValueSource{Key:%[1]q}", e.Key) + return fmt.Sprintf("&envVarValueSource{Key:%[1]q}", e.key) } func EnvVar(key string) ValueSource { return &envVarValueSource{ - Key: key, + key: key, } } @@ -107,7 +135,7 @@ func EnvVars(keys ...string) ValueSourceChain { vsc := ValueSourceChain{Chain: []ValueSource{}} for _, key := range keys { - vsc.Chain = append(vsc.Chain, &envVarValueSource{Key: key}) + vsc.Chain = append(vsc.Chain, EnvVar(key)) } return vsc @@ -138,8 +166,85 @@ func Files(paths ...string) ValueSourceChain { vsc := ValueSourceChain{Chain: []ValueSource{}} for _, path := range paths { - vsc.Chain = append(vsc.Chain, &fileValueSource{Path: path}) + vsc.Chain = append(vsc.Chain, File(path)) } return vsc } + +type mapSource struct { + name string + m map[any]any +} + +func NewMapSource(name string, m map[any]any) MapSource { + return &mapSource{ + name: name, + m: m, + } +} + +func (ms *mapSource) String() string { return fmt.Sprintf("map source %[1]q", ms.name) } +func (ms *mapSource) GoString() string { + return fmt.Sprintf("&mapSource{name:%[1]q}", ms.name) +} + +func (ms *mapSource) Lookup(name string) (any, bool) { + // nestedVal checks if the name has '.' delimiters. + // If so, it tries to traverse the tree by the '.' delimited sections to find + // a nested value for the key. + if sections := strings.Split(name, "."); len(sections) > 1 { + node := ms.m + for _, section := range sections[:len(sections)-1] { + child, ok := node[section] + if !ok { + return nil, false + } + + switch child := child.(type) { + case map[string]any: + node = make(map[any]any, len(child)) + for k, v := range child { + node[k] = v + } + case map[any]any: + node = child + default: + return nil, false + } + } + if val, ok := node[sections[len(sections)-1]]; ok { + return val, true + } + } + + return nil, false +} + +type mapValueSource struct { + key string + ms MapSource +} + +func NewMapValueSource(key string, ms MapSource) ValueSource { + return &mapValueSource{ + key: key, + ms: ms, + } +} + +func (mvs *mapValueSource) String() string { + return fmt.Sprintf("key %[1]q from %[2]s", mvs.key, mvs.ms.String()) +} + +func (mvs *mapValueSource) GoString() string { + return fmt.Sprintf("&mapValueSource{key:%[1]q, src:%[2]s}", mvs.key, mvs.ms.GoString()) +} + +func (mvs *mapValueSource) Lookup() (string, bool) { + if v, ok := mvs.ms.Lookup(mvs.key); !ok { + return "", false + } else { + return fmt.Sprintf("%+v", v), true + } +} diff --git a/value_source_test.go b/value_source_test.go index 57e9d49e28..fa02d1f54e 100644 --- a/value_source_test.go +++ b/value_source_test.go @@ -189,3 +189,142 @@ func (svs *staticValueSource) GoString() string { } func (svs *staticValueSource) String() string { return svs.v } func (svs *staticValueSource) Lookup() (string, bool) { return svs.v, true } + +func TestMapValueSource(t *testing.T) { + tests := []struct { + name string + m map[any]any + key string + val string + found bool + }{ + { + name: "No map no key", + }, + { + name: "No map with key", + key: "foo", + }, + { + name: "Empty map no key", + m: map[any]any{}, + }, + { + name: "Empty map with key", + key: "foo", + m: map[any]any{}, + }, + { + name: "Level 1 no key", + key: ".foob", + m: map[any]any{ + "foo": 10, + }, + }, + { + name: "Level 2", + key: "foo.bar", + m: map[any]any{ + "foo": map[any]any{ + "bar": 10, + }, + }, + val: "10", + found: true, + }, + { + name: "Level 2 invalid key", + key: "foo.bar1", + m: map[any]any{ + "foo": map[any]any{ + "bar": "10", + }, + }, + }, + { + name: "Level 3 no entry", + key: "foo.bar.t", + m: map[any]any{ + "foo": map[any]any{ + "bar": "sss", + }, + }, + }, + { + name: "Level 3", + key: "foo.bar.t", + m: map[any]any{ + "foo": map[any]any{ + "bar": map[any]any{ + "t": "sss", + }, + }, + }, + val: "sss", + found: true, + }, + { + name: "Level 3 invalid key", + key: "foo.bar.t", + m: map[any]any{ + "foo": map[any]any{ + "bar": map[any]any{ + "t1": 10, + }, + }, + }, + }, + { + name: "Level 4 no entry", + key: "foo.bar.t.gh", + m: map[any]any{ + "foo": map[any]any{ + "bar": map[any]any{ + "t1": 10, + }, + }, + }, + }, + { + name: "Level 4 slice entry", + key: "foo.bar.t.gh", + m: map[any]any{ + "foo": map[any]any{ + "bar": map[string]any{ + "t": map[any]any{ + "gh": []int{10}, + }, + }, + }, + }, + val: "[10]", + found: true, + }, + } + + for _, test := range tests { + t.Run(test.key, func(t *testing.T) { + ms := NewMapSource("test", test.m) + m := NewMapValueSource(test.key, ms) + val, b := m.Lookup() + if !test.found { + assert.False(t, b) + } else { + assert.True(t, b) + assert.Equal(t, val, test.val) + } + }) + } +} + +func TestMapValueSourceStringer(t *testing.T) { + m := map[any]any{ + "foo": map[any]any{ + "bar": 10, + }, + } + mvs := NewMapValueSource("bar", NewMapSource("test", m)) + + assert.Equal(t, `&mapValueSource{key:"bar", src:&mapSource{name:"test"}}`, mvs.GoString()) + assert.Equal(t, `key "bar" from map source "test"`, mvs.String()) +}