diff --git a/CHANGELOG.md b/CHANGELOG.md index cd50a20e..4eeebeee 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,7 +5,16 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html). ## Unreleased -- No changes yet. +### Added +- Support for consuming named value groups as `map[string]T` in addition to `[]T`. + Named value groups can now be consumed as maps where names become keys, enabling + both direct named access and map-based access to the same providers. +- Simultaneous `dig.Name()` and `dig.Group()` support, removing previous mutual + exclusivity to enable named value group patterns. +- Comprehensive validation for slice decorators with named value groups, preventing + incompatible patterns and providing clear guidance for correct usage. +- Soft value groups support with map consumption, maintaining consistent behavior + with slice consumption patterns. ## [1.19.0] - 2025-05-13 diff --git a/constructor.go b/constructor.go index d46a273d..a731b183 100644 --- a/constructor.go +++ b/constructor.go @@ -213,7 +213,7 @@ func (n *constructorNode) Call(c containerStore) (err error) { // would be made to a containerWriter and defers them until Commit is called. type stagingContainerWriter struct { values map[key]reflect.Value - groups map[key][]reflect.Value + groups map[key][]keyedGroupValue } var _ containerWriter = (*stagingContainerWriter)(nil) @@ -221,7 +221,7 @@ var _ containerWriter = (*stagingContainerWriter)(nil) func newStagingContainerWriter() *stagingContainerWriter { return &stagingContainerWriter{ values: make(map[key]reflect.Value), - groups: make(map[key][]reflect.Value), + groups: make(map[key][]keyedGroupValue), } } @@ -233,12 +233,12 @@ func (sr *stagingContainerWriter) setDecoratedValue(_ string, _ reflect.Type, _ digerror.BugPanicf("stagingContainerWriter.setDecoratedValue must never be called") } -func (sr *stagingContainerWriter) submitGroupedValue(group string, t reflect.Type, v reflect.Value) { +func (sr *stagingContainerWriter) submitGroupedValue(group, mapKey string, t reflect.Type, v reflect.Value) { k := key{t: t, group: group} - sr.groups[k] = append(sr.groups[k], v) + sr.groups[k] = append(sr.groups[k], keyedGroupValue{key: mapKey, value: v}) } -func (sr *stagingContainerWriter) submitDecoratedGroupedValue(_ string, _ reflect.Type, _ reflect.Value) { +func (sr *stagingContainerWriter) submitDecoratedGroupedValue(_, _ string, _ reflect.Type, _ reflect.Value) { digerror.BugPanicf("stagingContainerWriter.submitDecoratedGroupedValue must never be called") } @@ -248,9 +248,9 @@ func (sr *stagingContainerWriter) Commit(cw containerWriter) { cw.setValue(k.name, k.t, v) } - for k, vs := range sr.groups { - for _, v := range vs { - cw.submitGroupedValue(k.group, k.t, v) + for k, kgvs := range sr.groups { + for _, kgv := range kgvs { + cw.submitGroupedValue(k.group, kgv.key, k.t, kgv.value) } } } diff --git a/container.go b/container.go index a875b5e0..2d1336a2 100644 --- a/container.go +++ b/container.go @@ -82,12 +82,12 @@ type containerWriter interface { setDecoratedValue(name string, t reflect.Type, v reflect.Value) // submitGroupedValue submits a value to the value group with the provided - // name. - submitGroupedValue(name string, t reflect.Type, v reflect.Value) + // name and optional map key. + submitGroupedValue(name, mapKey string, t reflect.Type, v reflect.Value) // submitDecoratedGroupedValue submits a decorated value to the value group - // with the provided name. - submitDecoratedGroupedValue(name string, t reflect.Type, v reflect.Value) + // with the provided name and optional map key. + submitDecoratedGroupedValue(name, mapKey string, t reflect.Type, v reflect.Value) } // containerStore provides access to the Container's underlying data store. @@ -109,7 +109,7 @@ type containerStore interface { // Retrieves all values for the provided group and type. // // The order in which the values are returned is undefined. - getValueGroup(name string, t reflect.Type) []reflect.Value + getValueGroup(name string, t reflect.Type) []keyedGroupValue // Retrieves all decorated values for the provided group and type, if any. getDecoratedValueGroup(name string, t reflect.Type) (reflect.Value, bool) @@ -292,8 +292,8 @@ func (bs byTypeName) Swap(i int, j int) { bs[i], bs[j] = bs[j], bs[i] } -func shuffledCopy(rand *rand.Rand, items []reflect.Value) []reflect.Value { - newItems := make([]reflect.Value, len(items)) +func shuffledCopy(rand *rand.Rand, items []keyedGroupValue) []keyedGroupValue { + newItems := make([]keyedGroupValue, len(items)) for i, j := range rand.Perm(len(items)) { newItems[i] = items[j] } diff --git a/decorate.go b/decorate.go index 5cc500d7..4c4aabb9 100644 --- a/decorate.go +++ b/decorate.go @@ -309,13 +309,15 @@ func findResultKeys(r resultList) ([]key, error) { case resultSingle: keys = append(keys, key{t: innerResult.Type, name: innerResult.Name}) case resultGrouped: - if innerResult.Type.Kind() != reflect.Slice { + isMap := innerResult.Type.Kind() == reflect.Map && innerResult.Type.Key().Kind() == reflect.String + isSlice := innerResult.Type.Kind() == reflect.Slice + if !isMap && !isSlice { return nil, newErrInvalidInput("decorating a value group requires decorating the entire value group, not a single value", nil) } keys = append(keys, key{t: innerResult.Type.Elem(), group: innerResult.Group}) case resultObject: for _, f := range innerResult.Fields { - q = append(q, f.Result) + q = append(q, f.Results...) } case resultList: q = append(q, innerResult.Results...) diff --git a/decorate_test.go b/decorate_test.go index 611a0f1c..2eb192f9 100644 --- a/decorate_test.go +++ b/decorate_test.go @@ -216,6 +216,64 @@ func TestDecorateSuccess(t *testing.T) { })) }) + t.Run("map is treated as an ordinary dependency without group tag, named or unnamed, and passes through multiple scopes", func(t *testing.T) { + type params struct { + dig.In + + Strings1 map[string]string + Strings2 map[string]string `name:"strings2"` + } + + type childResult struct { + dig.Out + + Strings1 map[string]string + Strings2 map[string]string `name:"strings2"` + } + + type A map[string]string + type B map[string]string + + parent := digtest.New(t) + parent.RequireProvide(func() map[string]string { return map[string]string{"key1": "val1", "key2": "val2"} }) + parent.RequireProvide(func() map[string]string { return map[string]string{"key1": "val21", "key2": "val22"} }, dig.Name("strings2")) + + parent.RequireProvide(func(p params) A { return A(p.Strings1) }) + parent.RequireProvide(func(p params) B { return B(p.Strings2) }) + + child := parent.Scope("child") + + parent.RequireDecorate(func(p params) childResult { + res := childResult{Strings1: make(map[string]string, len(p.Strings1))} + for k, s := range p.Strings1 { + res.Strings1[k] = strings.ToUpper(s) + } + res.Strings2 = p.Strings2 + return res + }) + + child.RequireDecorate(func(p params) childResult { + res := childResult{Strings2: make(map[string]string, len(p.Strings2))} + for k, s := range p.Strings2 { + res.Strings2[k] = strings.ToUpper(s) + } + res.Strings1 = p.Strings1 + res.Strings1["key3"] = "newval" + return res + }) + + require.NoError(t, child.Invoke(func(p params) { + require.Len(t, p.Strings1, 3) + assert.Equal(t, "VAL1", p.Strings1["key1"]) + assert.Equal(t, "VAL2", p.Strings1["key2"]) + assert.Equal(t, "newval", p.Strings1["key3"]) + require.Len(t, p.Strings2, 2) + assert.Equal(t, "VAL21", p.Strings2["key1"]) + assert.Equal(t, "VAL22", p.Strings2["key2"]) + + })) + + }) t.Run("decorate values in soft group", func(t *testing.T) { type params struct { dig.In @@ -394,6 +452,46 @@ func TestDecorateSuccess(t *testing.T) { assert.Equal(t, `[]string[group = "animals"]`, info.Inputs[0].String()) }) + t.Run("decorate with map value groups", func(t *testing.T) { + type Params struct { + dig.In + + Animals map[string]string `group:"animals"` + } + + type Result struct { + dig.Out + + Animals map[string]string `group:"animals"` + } + + c := digtest.New(t) + c.RequireProvide(func() string { return "dog" }, dig.Name("animal1"), dig.Group("animals")) + c.RequireProvide(func() string { return "cat" }, dig.Name("animal2"), dig.Group("animals")) + c.RequireProvide(func() string { return "gopher" }, dig.Name("animal3"), dig.Group("animals")) + + var info dig.DecorateInfo + c.RequireDecorate(func(p Params) Result { + animals := p.Animals + for k, v := range animals { + animals[k] = "good " + v + } + return Result{ + Animals: animals, + } + }, dig.FillDecorateInfo(&info)) + + c.RequireInvoke(func(p Params) { + assert.Len(t, p.Animals, 3) + assert.Equal(t, "good dog", p.Animals["animal1"]) + assert.Equal(t, "good cat", p.Animals["animal2"]) + assert.Equal(t, "good gopher", p.Animals["animal3"]) + }) + + require.Equal(t, 1, len(info.Inputs)) + assert.Equal(t, `map[string]string[group = "animals"]`, info.Inputs[0].String()) + }) + t.Run("decorate with optional parameter", func(t *testing.T) { c := digtest.New(t) @@ -919,6 +1017,7 @@ func TestMultipleDecorates(t *testing.T) { assert.ElementsMatch(t, []int{2, 3, 4}, a.Values) }) }) + } func TestFillDecorateInfoString(t *testing.T) { diff --git a/dig_test.go b/dig_test.go index 647c1494..503224a7 100644 --- a/dig_test.go +++ b/dig_test.go @@ -28,6 +28,7 @@ import ( "math/rand" "os" "reflect" + "strings" "testing" "time" @@ -796,6 +797,53 @@ func TestEndToEndSuccess(t *testing.T) { assert.ElementsMatch(t, actualStrs, expectedStrs, "list of strings provided must match") }) + t.Run("multiple As with Group and Name", func(t *testing.T) { + c := digtest.New(t) + expectedNames := []string{"inst1", "inst2"} + expectedStrs := []string{"foo", "bar"} + for i, s := range expectedStrs { + s := s + c.RequireProvide(func() *bytes.Buffer { + return bytes.NewBufferString(s) + }, dig.Group("buffs"), dig.Name(expectedNames[i]), + dig.As(new(io.Reader), new(io.Writer))) + } + + type in struct { + dig.In + + Reader1 io.Reader `name:"inst1"` + Reader2 io.Reader `name:"inst2"` + Readers []io.Reader `group:"buffs"` + Writers []io.Writer `group:"buffs"` + } + + var actualStrs []string + var actualStrsName []string + + c.RequireInvoke(func(got in) { + require.Len(t, got.Readers, 2) + buf := make([]byte, 3) + for i, r := range got.Readers { + _, err := r.Read(buf) + require.NoError(t, err) + actualStrs = append(actualStrs, string(buf)) + // put the text back + got.Writers[i].Write(buf) + } + _, err := got.Reader1.Read(buf) + require.NoError(t, err) + actualStrsName = append(actualStrsName, string(buf)) + _, err = got.Reader2.Read(buf) + require.NoError(t, err) + actualStrsName = append(actualStrsName, string(buf)) + require.Len(t, got.Writers, 2) + }) + + assert.ElementsMatch(t, actualStrs, expectedStrs, "list of strings provided must match") + assert.ElementsMatch(t, actualStrsName, expectedStrs, "names: list of strings provided must match") + }) + t.Run("As same interface", func(t *testing.T) { c := digtest.New(t) c.RequireProvide(func() io.Reader { @@ -1145,6 +1193,48 @@ func TestGroups(t *testing.T) { }) }) + t.Run("values are provided; coexist with name", func(t *testing.T) { + c := digtest.New(t, dig.SetRand(rand.New(rand.NewSource(0)))) + + type out struct { + dig.Out + + Value int `group:"val"` + } + + type out2 struct { + dig.Out + + Value int `name:"inst1" group:"val"` + } + + provide := func(i int) { + c.RequireProvide(func() out { + return out{Value: i} + }) + } + + provide(1) + provide(2) + provide(3) + + c.RequireProvide(func() out2 { + return out2{Value: 4} + }) + + type in struct { + dig.In + + SingleValue int `name:"inst1"` + Values []int `group:"val"` + } + + c.RequireInvoke(func(i in) { + assert.Equal(t, []int{1, 2, 3, 4}, i.Values) + assert.Equal(t, 4, i.SingleValue) + }) + }) + t.Run("groups are provided via option", func(t *testing.T) { c := digtest.New(t, dig.SetRand(rand.New(rand.NewSource(0)))) @@ -1169,6 +1259,57 @@ func TestGroups(t *testing.T) { }) }) + t.Run("groups are provided via option; coexist with name", func(t *testing.T) { + c := digtest.New(t, dig.SetRand(rand.New(rand.NewSource(0)))) + + provide := func(i int) { + c.RequireProvide(func() int { + return i + }, dig.Group("val")) + } + + provide(1) + provide(2) + provide(3) + + c.RequireProvide(func() int { + return 4 + }, dig.Group("val"), dig.Name("inst1")) + + type in struct { + dig.In + + SingleValue int `name:"inst1"` + Values []int `group:"val"` + } + + c.RequireInvoke(func(i in) { + assert.Equal(t, []int{1, 2, 3, 4}, i.Values) + assert.Equal(t, 4, i.SingleValue) + }) + }) + + t.Run("provide multiple with the same name and group but different type", func(t *testing.T) { + c := digtest.New(t) + type A struct{} + type B struct{} + type ret1 struct { + dig.Out + *A `name:"foo" group:"foos"` + } + type ret2 struct { + dig.Out + *B `name:"foo" group:"foos"` + } + c.RequireProvide(func() ret1 { + return ret1{A: &A{}} + }) + + c.RequireProvide(func() ret2 { + return ret2{B: &B{}} + }) + }) + t.Run("different types may be grouped", func(t *testing.T) { c := digtest.New(t, dig.SetRand(rand.New(rand.NewSource(0)))) @@ -1447,169 +1588,1301 @@ func TestGroups(t *testing.T) { assert.Equal(t, gaveErr, dig.RootCause(err)) }) - t.Run("flatten collects slices", func(t *testing.T) { - c := digtest.New(t, dig.SetRand(rand.New(rand.NewSource(0)))) + t.Run("flatten collects slices", func(t *testing.T) { + c := digtest.New(t, dig.SetRand(rand.New(rand.NewSource(0)))) + + type out struct { + dig.Out + + Value []int `group:"val,flatten"` + } + + provide := func(i []int) { + c.RequireProvide(func() out { + return out{Value: i} + }) + } + + provide([]int{1, 2}) + provide([]int{3, 4}) + + type in struct { + dig.In + + Values []int `group:"val"` + } + + c.RequireInvoke(func(i in) { + assert.Equal(t, []int{2, 3, 4, 1}, i.Values) + }) + }) + + t.Run("flatten collects slices but also handles name", func(t *testing.T) { + c := digtest.New(t, dig.SetRand(rand.New(rand.NewSource(0)))) + + type out1 struct { + dig.Out + + Value []int `name:"foo1" group:"val,flatten"` + } + + type out2 struct { + dig.Out + + Value []int `name:"foo2" group:"val,flatten"` + } + + c.RequireProvide(func() out1 { + return out1{Value: []int{1, 2}} + }) + + c.RequireProvide(func() out2 { + return out2{Value: []int{3, 4}} + }) + + type in struct { + dig.In + + NotFlattenedSlice1 []int `name:"foo1"` + NotFlattenedSlice2 []int `name:"foo2"` + Values []int `group:"val"` + } + + c.RequireInvoke(func(i in) { + assert.Equal(t, []int{2, 3, 4, 1}, i.Values) + assert.Equal(t, []int{1, 2}, i.NotFlattenedSlice1) + assert.Equal(t, []int{3, 4}, i.NotFlattenedSlice2) + }) + }) + + t.Run("flatten via option", func(t *testing.T) { + c := digtest.New(t, dig.SetRand(rand.New(rand.NewSource(0)))) + c.RequireProvide(func() []int { + return []int{1, 2, 3} + }, dig.Group("val,flatten")) + + type in struct { + dig.In + + Values []int `group:"val"` + } + + c.RequireInvoke(func(i in) { + assert.Equal(t, []int{2, 3, 1}, i.Values) + }) + }) + + t.Run("flatten via option also handles name", func(t *testing.T) { + c := digtest.New(t, dig.SetRand(rand.New(rand.NewSource(0)))) + c.RequireProvide(func() []int { + return []int{1, 2} + }, dig.Group("val,flatten"), dig.Name("foo1")) + + c.RequireProvide(func() []int { + return []int{3} + }, dig.Group("val,flatten"), dig.Name("foo2")) + + type in struct { + dig.In + + NotFlattenedSlice1 []int `name:"foo1"` + NotFlattenedSlice2 []int `name:"foo2"` + Values []int `group:"val"` + } + + c.RequireInvoke(func(i in) { + assert.Equal(t, []int{2, 3, 1}, i.Values) + assert.Equal(t, []int{1, 2}, i.NotFlattenedSlice1) + assert.Equal(t, []int{3}, i.NotFlattenedSlice2) + }) + }) + + t.Run("flatten via option error if not a slice", func(t *testing.T) { + c := digtest.New(t, dig.SetRand(rand.New(rand.NewSource(0)))) + err := c.Provide(func() int { return 1 }, dig.Group("val,flatten")) + require.Error(t, err, "failed to provide") + assert.Contains(t, err.Error(), "flatten can be applied to slices only") + }) + + t.Run("a soft value group provider is not called when only that value group is consumed", func(t *testing.T) { + type Param struct { + dig.In + + Values []string `group:"foo,soft"` + } + type Result struct { + dig.Out + + Value string `group:"foo"` + } + c := digtest.New(t) + + c.RequireProvide(func() (Result, int) { + require.FailNow(t, "this function should not be called") + return Result{Value: "sad times"}, 20 + }) + + c.RequireInvoke(func(p Param) { + assert.ElementsMatch(t, []string{}, p.Values) + }) + }) + + t.Run("soft value group is provided", func(t *testing.T) { + type Param1 struct { + dig.In + + Values []string `group:"foo,soft"` + } + type Result struct { + dig.Out + + Value1 string `group:"foo"` + Value2 int + } + + c := digtest.New(t) + c.RequireProvide(func() Result { return Result{Value1: "a", Value2: 2} }) + c.RequireProvide(func() string { return "b" }, dig.Group("foo")) + + // The only value that must be in the group is the one that's provided + // because it would be provided anyways by another dependency, in + // this case we need an int, so the first constructor is called, and + // this provides a string, which is the one in the group + c.RequireInvoke(func(p2 int, p1 Param1) { + assert.ElementsMatch(t, []string{"a"}, p1.Values) + }) + }) + + t.Run("two soft group values provided by one constructor", func(t *testing.T) { + type param struct { + dig.In + + Value1 []string `group:"foo,soft"` + Value2 []int `group:"bar,soft"` + Value3 float32 + } + + type result struct { + dig.Out + + Result1 []string `group:"foo,flatten"` + Result2 int `group:"bar"` + } + c := digtest.New(t) + + c.RequireProvide(func() result { + return result{ + Result1: []string{"a", "b", "c"}, + Result2: 4, + } + }) + c.RequireProvide(func() float32 { return 3.1416 }) + + c.RequireInvoke(func(p param) { + assert.ElementsMatch(t, []string{}, p.Value1) + assert.ElementsMatch(t, []int{}, p.Value2) + assert.Equal(t, float32(3.1416), p.Value3) + }) + }) + t.Run("soft in a result value group", func(t *testing.T) { + c := digtest.New(t) + err := c.Provide(func() int { return 10 }, dig.Group("foo,soft")) + require.Error(t, err, "failed to privide") + assert.Contains(t, err.Error(), "cannot use soft with result value groups") + }) + t.Run("value group provided after a hard dependency is provided", func(t *testing.T) { + type Param struct { + dig.In + + Value []string `group:"foo,soft"` + } + + type Result struct { + dig.Out + + Value1 string `group:"foo"` + } + + c := digtest.New(t) + c.Provide(func() (Result, int) { return Result{Value1: "a"}, 2 }) + + c.RequireInvoke(func(param Param) { + assert.ElementsMatch(t, []string{}, param.Value) + }) + c.RequireInvoke(func(int) {}) + c.RequireInvoke(func(param Param) { + assert.ElementsMatch(t, []string{"a"}, param.Value) + }) + }) + /* map tests */ + t.Run("empty map received without provides", func(t *testing.T) { + c := digtest.New(t) + + type in struct { + dig.In + + Values map[string]int `group:"foo"` + } + + c.RequireInvoke(func(i in) { + require.Empty(t, i.Values) + }) + }) + + t.Run("soft map value groups", func(t *testing.T) { + t.Run("soft map provider not called when only soft group consumed", func(t *testing.T) { + c := digtest.New(t) + + type Result struct { + dig.Out + Value string `name:"val1" group:"handlers"` + } + + // This provider should NOT be called because we're only consuming + // the soft group and there are no other dependencies forcing it to run + c.RequireProvide(func() (Result, int) { + require.FailNow(t, "this function should not be called for soft map groups") + return Result{Value: "should not see this"}, 42 + }) + + type SoftMapConsumer struct { + dig.In + Handlers map[string]string `group:"handlers,soft"` + } + + c.RequireInvoke(func(p SoftMapConsumer) { + assert.Empty(t, p.Handlers, "soft map group should be empty when no providers executed") + }) + }) + + t.Run("soft map gets values from already-executed constructors", func(t *testing.T) { + c := digtest.New(t) + + type HandlerResult struct { + dig.Out + Handler string `name:"handler1" group:"handlers"` + Service int // This forces the constructor to run + } + + // This provider will be called because we need the int service + c.RequireProvide(func() HandlerResult { + return HandlerResult{ + Handler: "executed_handler", + Service: 100, + } + }) + + // Additional provider that won't be executed + type UnexecutedResult struct { + dig.Out + Handler string `name:"handler2" group:"handlers"` + } + c.RequireProvide(func() UnexecutedResult { + require.FailNow(t, "this should not be called") + return UnexecutedResult{Handler: "never_called"} + }) + + type ConsumerParams struct { + dig.In + Service int // This triggers the first provider + SoftHandlerMap map[string]string `group:"handlers,soft"` + SoftHandlerSlice []string `group:"handlers,soft"` + } + + c.RequireInvoke(func(p ConsumerParams) { + assert.Equal(t, 100, p.Service) + + // Soft map should only contain the handler from the executed constructor + assert.Len(t, p.SoftHandlerMap, 1) + assert.Equal(t, "executed_handler", p.SoftHandlerMap["handler1"]) + assert.NotContains(t, p.SoftHandlerMap, "handler2") + + // Verify slice consumption works the same way + assert.Len(t, p.SoftHandlerSlice, 1) + assert.Equal(t, "executed_handler", p.SoftHandlerSlice[0]) + }) + }) + + t.Run("soft map combined with regular map consumption", func(t *testing.T) { + c := digtest.New(t) + + type ServiceResult struct { + dig.Out + Service string `name:"service1" group:"services"` + Config int // Forces execution + } + + c.RequireProvide(func() ServiceResult { + return ServiceResult{ + Service: "auth_service", + Config: 42, + } + }) + + type MultiConsumerParams struct { + dig.In + Config int // Triggers provider + SoftServices map[string]string `group:"services,soft"` + RegularServices map[string]string `group:"services"` + } + + c.RequireInvoke(func(p MultiConsumerParams) { + assert.Equal(t, 42, p.Config) + + // Both should have the same content since the provider was executed + assert.Len(t, p.SoftServices, 1) + assert.Equal(t, "auth_service", p.SoftServices["service1"]) + + assert.Len(t, p.RegularServices, 1) + assert.Equal(t, "auth_service", p.RegularServices["service1"]) + + // They should be equivalent + assert.Equal(t, p.SoftServices, p.RegularServices) + }) + }) + }) + + t.Run("map value group using dig.Name and dig.Group", func(t *testing.T) { + c := digtest.New(t, dig.SetRand(rand.New(rand.NewSource(0)))) + + c.RequireProvide(func() int { + return 1 + }, dig.Name("value1"), dig.Group("val")) + c.RequireProvide(func() int { + return 2 + }, dig.Name("value2"), dig.Group("val")) + c.RequireProvide(func() int { + return 3 + }, dig.Name("value3"), dig.Group("val")) + + type in struct { + dig.In + + Value1 int `name:"value1"` + Value2 int `name:"value2"` + Value3 int `name:"value3"` + Values []int `group:"val"` + ValueMap map[string]int `group:"val"` + } + + c.RequireInvoke(func(i in) { + assert.Equal(t, []int{2, 3, 1}, i.Values) + assert.Equal(t, i.ValueMap["value1"], 1) + assert.Equal(t, i.ValueMap["value2"], 2) + assert.Equal(t, i.ValueMap["value3"], 3) + assert.Equal(t, i.Value1, 1) + assert.Equal(t, i.Value2, 2) + assert.Equal(t, i.Value3, 3) + }) + }) + t.Run("values are provided, map and name and slice", func(t *testing.T) { + c := digtest.New(t, dig.SetRand(rand.New(rand.NewSource(0)))) + type out struct { + dig.Out + + Value1 int `name:"value1" group:"val"` + Value2 int `name:"value2" group:"val"` + Value3 int `name:"value3" group:"val"` + } + + c.RequireProvide(func() out { + return out{Value1: 1, Value2: 2, Value3: 3} + }) + + type in struct { + dig.In + + Value1 int `name:"value1"` + Value2 int `name:"value2"` + Value3 int `name:"value3"` + Values []int `group:"val"` + ValueMap map[string]int `group:"val"` + } + + c.RequireInvoke(func(i in) { + assert.Equal(t, []int{2, 3, 1}, i.Values) + assert.Equal(t, i.ValueMap["value1"], 1) + assert.Equal(t, i.ValueMap["value2"], 2) + assert.Equal(t, i.ValueMap["value3"], 3) + assert.Equal(t, i.Value1, 1) + assert.Equal(t, i.Value2, 2) + assert.Equal(t, i.Value3, 3) + }) + }) + + t.Run("Every item used in a map must have a named key", func(t *testing.T) { + c := digtest.New(t, dig.SetRand(rand.New(rand.NewSource(0)))) + + type out struct { + dig.Out + + Value1 int `name:"value1" group:"val"` + Value2 int `name:"value2" group:"val"` + Value3 int `group:"val"` + } + + c.RequireProvide(func() out { + return out{Value1: 1, Value2: 2, Value3: 3} + }) + + type in struct { + dig.In + + ValueMap map[string]int `group:"val"` + } + var called = false + err := c.Invoke(func(i in) { called = true }) + dig.AssertErrorMatches(t, err, + `could not build arguments for function "go.uber.org/dig_test".TestGroups\S+`, + `dig_test.go:\d+`, // file:line + `every entry in a map value groups must have a name, group "val" is missing a name`) + assert.False(t, called, "shouldn't call invoked function when deps aren't available") + }) + + t.Run("map value groups with interface types", func(t *testing.T) { + c := digtest.New(t, dig.SetRand(rand.New(rand.NewSource(0)))) + + // Provide different implementations of io.Reader + c.RequireProvide(func() io.Reader { + return strings.NewReader("hello") + }, dig.Name("string_reader"), dig.Group("readers")) + + c.RequireProvide(func() io.Reader { + return bytes.NewBufferString("world") + }, dig.Name("bytes_reader"), dig.Group("readers")) + + type in struct { + dig.In + + StringReader io.Reader `name:"string_reader"` + BytesReader io.Reader `name:"bytes_reader"` + ReaderSlice []io.Reader `group:"readers"` + ReaderMap map[string]io.Reader `group:"readers"` + } + + c.RequireInvoke(func(i in) { + // Test individual named access + assert.NotNil(t, i.StringReader) + assert.NotNil(t, i.BytesReader) + + // Test slice access (traditional) + require.Len(t, i.ReaderSlice, 2) + + // Test map access (new feature) + require.Len(t, i.ReaderMap, 2) + assert.NotNil(t, i.ReaderMap["string_reader"]) + assert.NotNil(t, i.ReaderMap["bytes_reader"]) + + // Verify we can actually use the interface methods + buf := make([]byte, 5) + n, err := i.ReaderMap["string_reader"].Read(buf) + assert.NoError(t, err) + assert.Equal(t, 5, n) + assert.Equal(t, "hello", string(buf)) + }) + }) + + t.Run("map value groups with interface types using struct annotations", func(t *testing.T) { + c := digtest.New(t, dig.SetRand(rand.New(rand.NewSource(0)))) + + type out struct { + dig.Out + + StringReader io.Reader `name:"str_reader" group:"readers"` + BytesReader io.Reader `name:"buf_reader" group:"readers"` + } + + c.RequireProvide(func() out { + return out{ + StringReader: strings.NewReader("test1"), + BytesReader: bytes.NewBufferString("test2"), + } + }) + + type in struct { + dig.In + + ReaderMap map[string]io.Reader `group:"readers"` + } + + c.RequireInvoke(func(i in) { + require.Len(t, i.ReaderMap, 2) + + // Test that we got the right implementations + buf1 := make([]byte, 5) + buf2 := make([]byte, 5) + + i.ReaderMap["str_reader"].Read(buf1) + i.ReaderMap["buf_reader"].Read(buf2) + + assert.Equal(t, "test1", string(buf1)) + assert.Equal(t, "test2", string(buf2)) + }) + }) + + t.Run("map value groups with pointer types", func(t *testing.T) { + c := digtest.New(t, dig.SetRand(rand.New(rand.NewSource(0)))) + + type MyStruct struct { + Value string + } + + // Provide pointers using function options + c.RequireProvide(func() *MyStruct { + return &MyStruct{Value: "first"} + }, dig.Name("struct1"), dig.Group("structs")) + + c.RequireProvide(func() *MyStruct { + return &MyStruct{Value: "second"} + }, dig.Name("struct2"), dig.Group("structs")) + + type in struct { + dig.In + + Struct1 *MyStruct `name:"struct1"` + Struct2 *MyStruct `name:"struct2"` + StructSlice []*MyStruct `group:"structs"` + StructMap map[string]*MyStruct `group:"structs"` + } + + c.RequireInvoke(func(i in) { + // Test individual named access + require.NotNil(t, i.Struct1) + require.NotNil(t, i.Struct2) + assert.Equal(t, "first", i.Struct1.Value) + assert.Equal(t, "second", i.Struct2.Value) + + // Test slice access + require.Len(t, i.StructSlice, 2) + + // Test map access + require.Len(t, i.StructMap, 2) + require.NotNil(t, i.StructMap["struct1"]) + require.NotNil(t, i.StructMap["struct2"]) + assert.Equal(t, "first", i.StructMap["struct1"].Value) + assert.Equal(t, "second", i.StructMap["struct2"].Value) + + // Verify pointers are the same instances + assert.Same(t, i.Struct1, i.StructMap["struct1"]) + assert.Same(t, i.Struct2, i.StructMap["struct2"]) + }) + }) + + t.Run("map value groups with pointer types using struct annotations", func(t *testing.T) { + c := digtest.New(t, dig.SetRand(rand.New(rand.NewSource(0)))) + + type ConfigItem struct { + Name string + Value int + } + + type out struct { + dig.Out + + Config1 *ConfigItem `name:"db_config" group:"configs"` + Config2 *ConfigItem `name:"cache_config" group:"configs"` + Config3 *ConfigItem `name:"auth_config" group:"configs"` + } + + c.RequireProvide(func() out { + return out{ + Config1: &ConfigItem{Name: "database", Value: 5432}, + Config2: &ConfigItem{Name: "cache", Value: 6379}, + Config3: &ConfigItem{Name: "auth", Value: 8080}, + } + }) + + type in struct { + dig.In + + ConfigSlice []*ConfigItem `group:"configs"` + ConfigMap map[string]*ConfigItem `group:"configs"` + } + + c.RequireInvoke(func(i in) { + // Test slice access + require.Len(t, i.ConfigSlice, 3) + + // Test map access with meaningful keys + require.Len(t, i.ConfigMap, 3) + + dbConfig := i.ConfigMap["db_config"] + require.NotNil(t, dbConfig) + assert.Equal(t, "database", dbConfig.Name) + assert.Equal(t, 5432, dbConfig.Value) + + cacheConfig := i.ConfigMap["cache_config"] + require.NotNil(t, cacheConfig) + assert.Equal(t, "cache", cacheConfig.Name) + assert.Equal(t, 6379, cacheConfig.Value) + + authConfig := i.ConfigMap["auth_config"] + require.NotNil(t, authConfig) + assert.Equal(t, "auth", authConfig.Name) + assert.Equal(t, 8080, authConfig.Value) + }) + }) + + t.Run("map value groups with dig.As interface transformation", func(t *testing.T) { + c := digtest.New(t, dig.SetRand(rand.New(rand.NewSource(0)))) + + // Provide concrete types that get transformed to interfaces via dig.As + c.RequireProvide(func() *bytes.Buffer { + return bytes.NewBufferString("buffer1") + }, dig.Name("buf1"), dig.Group("readers"), dig.As(new(io.Reader))) + + c.RequireProvide(func() *strings.Reader { + return strings.NewReader("reader1") + }, dig.Name("str1"), dig.Group("readers"), dig.As(new(io.Reader))) + + type in struct { + dig.In + + // Individual named access should work + Buf1 io.Reader `name:"buf1"` + Str1 io.Reader `name:"str1"` + + // Traditional slice access should work + ReaderSlice []io.Reader `group:"readers"` + + // NEW: Map access with dig.As should work + ReaderMap map[string]io.Reader `group:"readers"` + } + + c.RequireInvoke(func(i in) { + // Test individual named access works + require.NotNil(t, i.Buf1) + require.NotNil(t, i.Str1) + + // Test slice access works + require.Len(t, i.ReaderSlice, 2) + + // Test map access works with dig.As + require.Len(t, i.ReaderMap, 2) + + buf1Reader := i.ReaderMap["buf1"] + require.NotNil(t, buf1Reader) + + str1Reader := i.ReaderMap["str1"] + require.NotNil(t, str1Reader) + + // Verify we can actually use the interface methods + buf := make([]byte, 7) + n, err := buf1Reader.Read(buf) + assert.NoError(t, err) + assert.Equal(t, 7, n) + assert.Equal(t, "buffer1", string(buf)) + + buf2 := make([]byte, 7) + n2, err2 := str1Reader.Read(buf2) + assert.NoError(t, err2) + assert.Equal(t, 7, n2) + assert.Equal(t, "reader1", string(buf2)) + + // Verify same instances across access patterns + assert.Same(t, i.Buf1, i.ReaderMap["buf1"]) + assert.Same(t, i.Str1, i.ReaderMap["str1"]) + }) + }) + + t.Run("map value groups with dig.As multiple interface transformation", func(t *testing.T) { + c := digtest.New(t, dig.SetRand(rand.New(rand.NewSource(0)))) + + // Provide a type that implements multiple interfaces + c.RequireProvide(func() *bytes.Buffer { + return bytes.NewBufferString("multi") + }, dig.Name("multi_buf"), dig.Group("readwriters"), + dig.As(new(io.Reader), new(io.Writer))) + + c.RequireProvide(func() *bytes.Buffer { + return bytes.NewBufferString("another") + }, dig.Name("another_buf"), dig.Group("readwriters"), + dig.As(new(io.Reader), new(io.Writer))) + + type in struct { + dig.In + + // Access as readers + ReaderMap map[string]io.Reader `group:"readwriters"` + + // Access as writers + WriterMap map[string]io.Writer `group:"readwriters"` + + // Access as slice of readers + ReaderSlice []io.Reader `group:"readwriters"` + } + + c.RequireInvoke(func(i in) { + // Test both interface maps work + require.Len(t, i.ReaderMap, 2) + require.Len(t, i.WriterMap, 2) + require.Len(t, i.ReaderSlice, 2) + + // Test we can read from the reader interface + multiBufReader := i.ReaderMap["multi_buf"] + require.NotNil(t, multiBufReader) + + readBuf := make([]byte, 5) + n, err := multiBufReader.Read(readBuf) + assert.NoError(t, err) + assert.Equal(t, 5, n) + assert.Equal(t, "multi", string(readBuf)) + + // Test we can write to the writer interface + multiBufWriter := i.WriterMap["multi_buf"] + require.NotNil(t, multiBufWriter) + + n2, err2 := multiBufWriter.Write([]byte("_test")) + assert.NoError(t, err2) + assert.Equal(t, 5, n2) + + // Verify both interfaces point to the same underlying object + // We can't use assert.Same here since they're different interface values + // but we can verify they affect the same buffer + anotherReader := i.ReaderMap["another_buf"] + anotherWriter := i.WriterMap["another_buf"] + + // Write something + anotherWriter.Write([]byte("_added")) + + // Read it back (should include both original + added) + fullBuf := make([]byte, 13) + n3, err3 := anotherReader.Read(fullBuf) + assert.NoError(t, err3) + assert.Equal(t, 13, n3) + assert.Equal(t, "another_added", string(fullBuf)) + }) + }) + + t.Run("slice decorator works with unnamed groups", func(t *testing.T) { + c := digtest.New(t) + + // Provide values WITHOUT names (unnamed group) + c.RequireProvide(func() int { return 10 }, dig.Group("numbers")) + c.RequireProvide(func() int { return 20 }, dig.Group("numbers")) + + // Register slice decorator using struct parameters - this should work for unnamed groups + type DecorateParams struct { + dig.In + Nums []int `group:"numbers"` + } + type DecorateResult struct { + dig.Out + Nums []int `group:"numbers"` + } + var decoratorCalled bool + c.RequireDecorate(func(p DecorateParams) DecorateResult { + decoratorCalled = true + t.Logf("Slice decorator called with: %v", p.Nums) + result := make([]int, len(p.Nums)) + for i, n := range p.Nums { + result[i] = n * 100 + } + return DecorateResult{Nums: result} + }) + + // Consume as slice - should get decorated values + type in struct { + dig.In + Nums []int `group:"numbers"` + } + + err := c.Invoke(func(i in) { + t.Logf("Got slice: %v", i.Nums) + t.Logf("Decorator called: %v", decoratorCalled) + if decoratorCalled { + // Should be [1000, 2000] from decoration + assert.ElementsMatch(t, []int{1000, 2000}, i.Nums) + } else { + // Original values [10, 20] - this means decorator wasn't called + assert.ElementsMatch(t, []int{10, 20}, i.Nums) + } + }) + require.NoError(t, err) + }) + + t.Run("slice decorators forbidden with named groups", func(t *testing.T) { + c := digtest.New(t) + + // Provide values with names and group (this creates named group) + c.RequireProvide(func() int { return 10 }, dig.Name("first"), dig.Group("numbers")) + c.RequireProvide(func() int { return 20 }, dig.Name("second"), dig.Group("numbers")) + + // Register slice decorator using struct parameters - this should succeed at registration time + type DecorateParams struct { + dig.In + Nums []int `group:"numbers"` + } + type DecorateResult struct { + dig.Out + Nums []int `group:"numbers"` + } + c.RequireDecorate(func(p DecorateParams) DecorateResult { + // Note: This decorator will be called but its results will be blocked + t.Logf("Slice decorator called (results will be blocked): %v", p.Nums) + result := make([]int, len(p.Nums)) + for i, n := range p.Nums { + result[i] = n * 100 + } + return DecorateResult{Nums: result} + }) + + // Try to consume as slice - should fail with validation error + type in struct { + dig.In + Nums []int `group:"numbers"` + } + err := c.Invoke(func(i in) { + t.Logf("Should not reach here - slice consumption should fail") + }) + + require.Error(t, err) + require.Contains(t, err.Error(), "cannot use slice decoration for value group \"numbers\"") + require.Contains(t, err.Error(), "group contains named values") + require.Contains(t, err.Error(), "use map[string]T decorator instead") + }) + + t.Run("decoration edge cases", func(t *testing.T) { + t.Run("multiple slice decorators forbidden", func(t *testing.T) { + c := digtest.New(t) + + // Provide unnamed values + c.RequireProvide(func() int { return 10 }, dig.Group("numbers")) + c.RequireProvide(func() int { return 20 }, dig.Group("numbers")) + + // Register first decorator + type DecorateParams struct { + dig.In + Nums []int `group:"numbers"` + } + type DecorateResult struct { + dig.Out + Nums []int `group:"numbers"` + } + c.RequireDecorate(func(p DecorateParams) DecorateResult { + return DecorateResult{Nums: p.Nums} + }) + + // Try to register second decorator - should fail + err := c.Decorate(func(p DecorateParams) DecorateResult { + return DecorateResult{Nums: p.Nums} + }) + + require.Error(t, err) + require.Contains(t, err.Error(), "already decorated") + }) + + t.Run("map decorator with unnamed values fails", func(t *testing.T) { + c := digtest.New(t) + + // Provide VALUES WITHOUT NAMES + c.RequireProvide(func() int { return 10 }, dig.Group("numbers")) + c.RequireProvide(func() int { return 20 }, dig.Group("numbers")) + + // Register MAP decorator for unnamed group - registration succeeds + type MapDecorateParams struct { + dig.In + NumMap map[string]int `group:"numbers"` + } + type MapDecorateResult struct { + dig.Out + NumMap map[string]int `group:"numbers"` + } + c.RequireDecorate(func(p MapDecorateParams) MapDecorateResult { + t.Logf("This should never be called") + return MapDecorateResult{NumMap: p.NumMap} + }) + + // Try to consume as slice - should fail because decorator can't run + type SliceConsumer struct { + dig.In + Nums []int `group:"numbers"` + } + err := c.Invoke(func(s SliceConsumer) { + t.Logf("Should not reach here") + }) + + require.Error(t, err) + require.Contains(t, err.Error(), "every entry in a map value groups must have a name") + }) + + t.Run("mixed named and unnamed values with slice decorator", func(t *testing.T) { + c := digtest.New(t) + + // Mix of named and unnamed values + c.RequireProvide(func() int { return 10 }, dig.Group("numbers")) // unnamed + c.RequireProvide(func() int { return 20 }, dig.Name("twenty"), dig.Group("numbers")) // named + c.RequireProvide(func() int { return 30 }, dig.Group("numbers")) // unnamed + + // Register slice decorator + type DecorateParams struct { + dig.In + Nums []int `group:"numbers"` + } + type DecorateResult struct { + dig.Out + Nums []int `group:"numbers"` + } + c.RequireDecorate(func(p DecorateParams) DecorateResult { + // Note: This decorator will be called but its results will be blocked + t.Logf("Slice decorator called (results will be blocked): %v", p.Nums) + return DecorateResult{Nums: p.Nums} + }) + + // Try to consume as slice - should fail due to validation + type SliceConsumer struct { + dig.In + Nums []int `group:"numbers"` + } + err := c.Invoke(func(s SliceConsumer) { + t.Logf("Should not reach here") + }) + + require.Error(t, err) + require.Contains(t, err.Error(), "cannot use slice decoration for value group") + require.Contains(t, err.Error(), "group contains named values") + }) + + t.Run("empty group decoration", func(t *testing.T) { + c := digtest.New(t) - type out struct { - dig.Out + // No providers yet, just register decorator + type DecorateParams struct { + dig.In + Nums []int `group:"numbers"` + } + type DecorateResult struct { + dig.Out + Nums []int `group:"numbers"` + } + var decoratorCalled bool + c.RequireDecorate(func(p DecorateParams) DecorateResult { + decoratorCalled = true + t.Logf("Decorator called with empty group: %v", p.Nums) + return DecorateResult{Nums: []int{999}} // Add a value + }) - Value []int `group:"val,flatten"` - } + // Consume the decorated empty group + type SliceConsumer struct { + dig.In + Nums []int `group:"numbers"` + } + err := c.Invoke(func(s SliceConsumer) { + t.Logf("Got: %v", s.Nums) + t.Logf("Decorator was called: %v", decoratorCalled) + assert.ElementsMatch(t, []int{999}, s.Nums) + }) - provide := func(i []int) { - c.RequireProvide(func() out { - return out{Value: i} + require.NoError(t, err) + }) + + t.Run("cross-scope decoration", func(t *testing.T) { + c := digtest.New(t) + + // Parent scope: provide values and decorator + c.RequireProvide(func() int { return 10 }, dig.Name("ten"), dig.Group("numbers")) + + // Parent scope: map decorator + type MapDecorateParams struct { + dig.In + NumMap map[string]int `group:"numbers"` + } + type MapDecorateResult struct { + dig.Out + NumMap map[string]int `group:"numbers"` + } + c.RequireDecorate(func(p MapDecorateParams) MapDecorateResult { + t.Logf("Parent decorator called with: %v", p.NumMap) + result := make(map[string]int) + for k, v := range p.NumMap { + result[k] = v * 100 // Multiply by 100 + } + return MapDecorateResult{NumMap: result} }) - } - provide([]int{1, 2}) - provide([]int{3, 4}) + // Child scope + child := c.Scope("child") + child.RequireProvide(func() int { return 20 }, dig.Name("twenty"), dig.Group("numbers")) - type in struct { - dig.In + // Child scope consumption - should get decorated values from parent + child values + type MapConsumer struct { + dig.In + NumMap map[string]int `group:"numbers"` + } + err := child.Invoke(func(m MapConsumer) { + t.Logf("Child got map: %v", m.NumMap) - Values []int `group:"val"` - } + // Let's observe what actually happens + if val, hasParent := m.NumMap["ten"]; hasParent { + t.Logf("Parent value 'ten': %d (expected 1000 if decorated)", val) + } + if val, hasChild := m.NumMap["twenty"]; hasChild { + t.Logf("Child value 'twenty': %d (what should this be?)", val) + } - c.RequireInvoke(func(i in) { - assert.Equal(t, []int{2, 3, 4, 1}, i.Values) + // Document the actual behavior we observe + t.Logf("Total values in child scope: %d", len(m.NumMap)) + }) + + require.NoError(t, err) }) - }) - t.Run("flatten via option", func(t *testing.T) { - c := digtest.New(t, dig.SetRand(rand.New(rand.NewSource(0)))) - c.RequireProvide(func() []int { - return []int{1, 2, 3} - }, dig.Group("val,flatten")) + t.Run("cross-scope decoration with child decorator", func(t *testing.T) { + c := digtest.New(t) - type in struct { - dig.In + // Parent scope: provide values + c.RequireProvide(func() int { return 10 }, dig.Name("ten"), dig.Group("numbers")) - Values []int `group:"val"` - } + // Parent scope: map decorator + type MapDecorateParams struct { + dig.In + NumMap map[string]int `group:"numbers"` + } + type MapDecorateResult struct { + dig.Out + NumMap map[string]int `group:"numbers"` + } + parentDecorator := func(p MapDecorateParams) MapDecorateResult { + t.Logf("Parent decorator called with: %v", p.NumMap) + result := make(map[string]int) + for k, v := range p.NumMap { + result[k] = v * 100 // Multiply by 100 + } + return MapDecorateResult{NumMap: result} + } + c.RequireDecorate(parentDecorator) + + // Child scope + child := c.Scope("child") + child.RequireProvide(func() int { return 20 }, dig.Name("twenty"), dig.Group("numbers")) + + // Child scope: ALSO add a decorator + childDecorator := func(p MapDecorateParams) MapDecorateResult { + t.Logf("Child decorator called with: %v", p.NumMap) + result := make(map[string]int) + for k, v := range p.NumMap { + result[k] = v + 1000 // Add 1000 + } + return MapDecorateResult{NumMap: result} + } + child.RequireDecorate(childDecorator) - c.RequireInvoke(func(i in) { - assert.Equal(t, []int{2, 3, 1}, i.Values) + // Test: child consumption should get both parent values + child values, all decorated + type MapConsumer struct { + dig.In + NumMap map[string]int `group:"numbers"` + } + err := child.Invoke(func(m MapConsumer) { + t.Logf("Child got map: %v", m.NumMap) + + // Based on existing slice tests, we should get: + // - Parent values decorated by both parent + child decorators + // - Child values decorated by child decorator only + t.Logf("Total values in child scope: %d", len(m.NumMap)) + }) + + require.NoError(t, err) }) - }) - t.Run("flatten via option error if not a slice", func(t *testing.T) { - c := digtest.New(t, dig.SetRand(rand.New(rand.NewSource(0)))) - err := c.Provide(func() int { return 1 }, dig.Group("val,flatten")) - require.Error(t, err, "failed to provide") - assert.Contains(t, err.Error(), "flatten can be applied to slices only") - }) + t.Run("cross-scope slice decoration behavior (baseline)", func(t *testing.T) { + c := digtest.New(t) - t.Run("a soft value group provider is not called when only that value group is consumed", func(t *testing.T) { - type Param struct { - dig.In + // Parent scope: provide values + c.RequireProvide(func() int { return 10 }, dig.Group("numbers")) - Values []string `group:"foo,soft"` - } - type Result struct { - dig.Out + // Parent scope: slice decorator + type SliceDecorateParams struct { + dig.In + Nums []int `group:"numbers"` + } + type SliceDecorateResult struct { + dig.Out + Nums []int `group:"numbers"` + } + parentDecorator := func(p SliceDecorateParams) SliceDecorateResult { + t.Logf("Parent slice decorator called with: %v", p.Nums) + result := make([]int, len(p.Nums)) + for i, v := range p.Nums { + result[i] = v * 100 // Multiply by 100 + } + return SliceDecorateResult{Nums: result} + } + c.RequireDecorate(parentDecorator) - Value string `group:"foo"` - } - c := digtest.New(t) + // Child scope + child := c.Scope("child") + child.RequireProvide(func() int { return 20 }, dig.Group("numbers")) - c.RequireProvide(func() (Result, int) { - require.FailNow(t, "this function should not be called") - return Result{Value: "sad times"}, 20 - }) + // Test: does child get parent values + child values? + type SliceConsumer struct { + dig.In + Nums []int `group:"numbers"` + } + err := child.Invoke(func(s SliceConsumer) { + t.Logf("Child got slice: %v", s.Nums) + t.Logf("Total values in child scope: %d", len(s.Nums)) - c.RequireInvoke(func(p Param) { - assert.ElementsMatch(t, []string{}, p.Values) + // Document what actually happens with slices: + // Does the child get BOTH decorated parent values AND undecorated child values? + // Or does it only get decorated parent values like maps? + }) + + require.NoError(t, err) }) - }) - t.Run("soft value group is provided", func(t *testing.T) { - type Param1 struct { - dig.In + t.Run("cross-scope slice with child decorator (baseline)", func(t *testing.T) { + c := digtest.New(t) - Values []string `group:"foo,soft"` - } - type Result struct { - dig.Out + // Parent scope: provide values + c.RequireProvide(func() int { return 10 }, dig.Group("numbers")) - Value1 string `group:"foo"` - Value2 int - } + // Parent scope: slice decorator + type SliceDecorateParams struct { + dig.In + Nums []int `group:"numbers"` + } + type SliceDecorateResult struct { + dig.Out + Nums []int `group:"numbers"` + } + parentDecorator := func(p SliceDecorateParams) SliceDecorateResult { + t.Logf("Parent slice decorator called with: %v", p.Nums) + result := make([]int, len(p.Nums)) + for i, v := range p.Nums { + result[i] = v * 100 // Multiply by 100 + } + return SliceDecorateResult{Nums: result} + } + c.RequireDecorate(parentDecorator) + + // Child scope + child := c.Scope("child") + child.RequireProvide(func() int { return 20 }, dig.Group("numbers")) + + // Child scope: ALSO add a decorator + childDecorator := func(p SliceDecorateParams) SliceDecorateResult { + t.Logf("Child slice decorator called with: %v", p.Nums) + result := make([]int, len(p.Nums)) + for i, v := range p.Nums { + result[i] = v + 1000 // Add 1000 + } + return SliceDecorateResult{Nums: result} + } + child.RequireDecorate(childDecorator) - c := digtest.New(t) - c.RequireProvide(func() Result { return Result{Value1: "a", Value2: 2} }) - c.RequireProvide(func() string { return "b" }, dig.Group("foo")) + // Test: child consumption behavior with child decorator + type SliceConsumer struct { + dig.In + Nums []int `group:"numbers"` + } + err := child.Invoke(func(s SliceConsumer) { + t.Logf("Child got slice: %v", s.Nums) + t.Logf("Total values in child scope: %d", len(s.Nums)) - // The only value that must be in the group is the one that's provided - // because it would be provided anyways by another dependency, in - // this case we need an int, so the first constructor is called, and - // this provides a string, which is the one in the group - c.RequireInvoke(func(p2 int, p1 Param1) { - assert.ElementsMatch(t, []string{"a"}, p1.Values) + // Compare this to our map behavior: + // Map: parent decorator sees only parent values, child sees decorated parent only + // Slice: what does each decorator see? + }) + + require.NoError(t, err) }) - }) - t.Run("two soft group values provided by one constructor", func(t *testing.T) { - type param struct { - dig.In + t.Run("existing test pattern - all values in parent scope", func(t *testing.T) { + c := digtest.New(t) - Value1 []string `group:"foo,soft"` - Value2 []int `group:"bar,soft"` - Value3 float32 - } + // ALL values provided in parent scope (like existing tests) + c.RequireProvide(func() int { return 10 }, dig.Group("numbers")) + c.RequireProvide(func() int { return 20 }, dig.Group("numbers")) - type result struct { - dig.Out + // Parent scope: slice decorator + type SliceDecorateParams struct { + dig.In + Nums []int `group:"numbers"` + } + type SliceDecorateResult struct { + dig.Out + Nums []int `group:"numbers"` + } + parentDecorator := func(p SliceDecorateParams) SliceDecorateResult { + t.Logf("Parent slice decorator called with: %v", p.Nums) + result := make([]int, len(p.Nums)) + for i, v := range p.Nums { + result[i] = v + 1 // Add 1 + } + return SliceDecorateResult{Nums: result} + } + c.RequireDecorate(parentDecorator) - Result1 []string `group:"foo,flatten"` - Result2 int `group:"bar"` - } - c := digtest.New(t) + // Child scope - no new values, just add decorator + child := c.Scope("child") - c.RequireProvide(func() result { - return result{ - Result1: []string{"a", "b", "c"}, - Result2: 4, + // Child scope: add decorator + childDecorator := func(p SliceDecorateParams) SliceDecorateResult { + t.Logf("Child slice decorator called with: %v", p.Nums) + result := make([]int, len(p.Nums)) + for i, v := range p.Nums { + result[i] = v + 1 // Add 1 again + } + return SliceDecorateResult{Nums: result} } - }) - c.RequireProvide(func() float32 { return 3.1416 }) + child.RequireDecorate(childDecorator) - c.RequireInvoke(func(p param) { - assert.ElementsMatch(t, []string{}, p.Value1) - assert.ElementsMatch(t, []int{}, p.Value2) - assert.Equal(t, float32(3.1416), p.Value3) + // Child consumption - this should match existing test pattern + type SliceConsumer struct { + dig.In + Nums []int `group:"numbers"` + } + err := child.Invoke(func(s SliceConsumer) { + t.Logf("Child got slice: %v", s.Nums) + t.Logf("Total values in child scope: %d", len(s.Nums)) + + // This should match the existing test: {10, 20} +1 +1 = {12, 22} + // Confirms the pattern: child gets ALL parent values with BOTH decorators applied + }) + + require.NoError(t, err) }) - }) - t.Run("soft in a result value group", func(t *testing.T) { - c := digtest.New(t) - err := c.Provide(func() int { return 10 }, dig.Group("foo,soft")) - require.Error(t, err, "failed to privide") - assert.Contains(t, err.Error(), "cannot use soft with result value groups") - }) - t.Run("value group provided after a hard dependency is provided", func(t *testing.T) { - type Param struct { - dig.In - Value []string `group:"foo,soft"` - } + t.Run("decorator adds new values", func(t *testing.T) { + c := digtest.New(t) - type Result struct { - dig.Out + // Provide base values + c.RequireProvide(func() int { return 10 }, dig.Name("base"), dig.Group("numbers")) - Value1 string `group:"foo"` - } + // Decorator that adds new values to the group + type DecorateParams struct { + dig.In + NumMap map[string]int `group:"numbers"` + } + type DecorateResult struct { + dig.Out + NumMap map[string]int `group:"numbers"` + } + c.RequireDecorate(func(p DecorateParams) DecorateResult { + t.Logf("Decorator called with: %v", p.NumMap) + result := make(map[string]int) + for k, v := range p.NumMap { + result[k] = v + } + result["decorated"] = 999 // Add a new entry + return DecorateResult{NumMap: result} + }) - c := digtest.New(t) - c.Provide(func() (Result, int) { return Result{Value1: "a"}, 2 }) + // Consumer that wants the group + type Consumer struct { + dig.In + NumMap map[string]int `group:"numbers"` + } + err := c.Invoke(func(con Consumer) { + t.Logf("Consumer got: %v", con.NumMap) + assert.Equal(t, 10, con.NumMap["base"]) + assert.Equal(t, 999, con.NumMap["decorated"]) + }) - c.RequireInvoke(func(param Param) { - assert.ElementsMatch(t, []string{}, param.Value) - }) - c.RequireInvoke(func(int) {}) - c.RequireInvoke(func(param Param) { - assert.ElementsMatch(t, []string{"a"}, param.Value) + require.NoError(t, err) }) }) + } // --- END OF END TO END TESTS @@ -2219,21 +3492,6 @@ func TestAsExpectingOriginalType(t *testing.T) { }) } -func TestProvideIncompatibleOptions(t *testing.T) { - t.Parallel() - - t.Run("group and name", func(t *testing.T) { - c := digtest.New(t) - err := c.Provide(func() io.Reader { - t.Fatal("this function must not be called") - return nil - }, dig.Group("foo"), dig.Name("bar")) - require.Error(t, err) - assert.Contains(t, err.Error(), "cannot use named values with value groups: "+ - `name:"bar" provided with group:"foo"`) - }) -} - type testStruct struct{} func (testStruct) TestMethod(x int) float64 { return float64(x) } @@ -2780,6 +4038,80 @@ func testProvideFailures(t *testing.T, dryRun bool) { ) }) + t.Run("provide multiple instances with the same name and same group", func(t *testing.T) { + c := digtest.New(t, dig.DryRun(dryRun)) + type A struct{} + type ret1 struct { + dig.Out + *A `name:"foo" group:"foos"` + } + type ret2 struct { + dig.Out + *A `name:"foo" group:"foos"` + } + c.RequireProvide(func() ret1 { + return ret1{A: &A{}} + }) + + err := c.Provide(func() ret2 { + return ret2{A: &A{}} + }) + require.Error(t, err, "expected error on the second provide") + dig.AssertErrorMatches(t, err, + `cannot provide function "go.uber.org/dig_test".testProvideFailures\S+`, + `dig_test.go:\d+`, // file:line + `cannot provide \*dig_test.A\[name="foo"\] from \[0\].A:`, + `already provided by "go.uber.org/dig_test".testProvideFailures\S+`, + ) + }) + + t.Run("provide multiple instances with the same name and same group using options", func(t *testing.T) { + c := digtest.New(t, dig.DryRun(dryRun)) + type A struct{} + + c.RequireProvide(func() *A { + return &A{} + }, dig.Group("foos"), dig.Name("foo")) + + err := c.Provide(func() *A { + return &A{} + }, dig.Group("foos"), dig.Name("foo")) + require.Error(t, err, "expected error on the second provide") + dig.AssertErrorMatches(t, err, + `cannot provide function "go.uber.org/dig_test".testProvideFailures\S+`, + `dig_test.go:\d+`, // file:line + `cannot provide \*dig_test.A\[name="foo"\] from \[1\]:`, + `already provided by "go.uber.org/dig_test".testProvideFailures\S+`, + ) + }) + + t.Run("provide multiple instances with the same name and type but different group", func(t *testing.T) { + c := digtest.New(t, dig.DryRun(dryRun)) + type A struct{} + type ret1 struct { + dig.Out + *A `name:"foo" group:"foos"` + } + type ret2 struct { + dig.Out + *A `name:"foo" group:"foosss"` + } + c.RequireProvide(func() ret1 { + return ret1{A: &A{}} + }) + + err := c.Provide(func() ret2 { + return ret2{A: &A{}} + }) + require.Error(t, err, "expected error on the second provide") + dig.AssertErrorMatches(t, err, + `cannot provide function "go.uber.org/dig_test".testProvideFailures\S+`, + `dig_test.go:\d+`, // file:line + `cannot provide \*dig_test.A\[name="foo"\] from \[0\].A:`, + `already provided by "go.uber.org/dig_test".testProvideFailures\S+`, + ) + }) + t.Run("out with unexported field should error", func(t *testing.T) { c := digtest.New(t, dig.DryRun(dryRun)) diff --git a/doc.go b/doc.go index 77b01d04..916900ad 100644 --- a/doc.go +++ b/doc.go @@ -345,4 +345,73 @@ // Handler []int `group:"server"` // [][]int from dig.In // Handler []int `group:"server,flatten"` // []int from dig.In // } +// +// # Map Value Groups +// +// Added in Dig 1.20. +// +// For named value groups, dig supports consuming values as maps in addition +// to slices. This allows accessing individual values by their names while +// still providing the convenience of working with the entire collection. +// +// To use map value groups, values must be provided with both a name and a +// group. This can be done by combining dig.Name() with dig.Group(): +// +// c.Provide(func() int { return 42 }, dig.Name("answer"), dig.Group("numbers")) +// c.Provide(func() int { return 100 }, dig.Name("perfect"), dig.Group("numbers")) +// +// Or by using result struct tags: +// +// type NumberResult struct { +// dig.Out +// +// Answer int `name:"answer" group:"numbers"` +// Perfect int `name:"perfect" group:"numbers"` +// } +// +// Named value groups can be consumed as maps where the names become keys: +// +// type Params struct { +// dig.In +// +// NumberMap map[string]int `group:"numbers"` // {"answer": 42, "perfect": 100} +// NumberSlice []int `group:"numbers"` // [42, 100] (order unspecified) +// Answer int `name:"answer"` // 42 +// } +// +// Map value groups provide the same flexibility as slice value groups: +// values can be consumed individually by name, as a slice for iteration, +// or as a map for direct key-based access. +// +// Note that only string-keyed maps (map[string]T) are supported, and all +// values in a map value group must have names. +// +// # Decorator Compatibility +// +// Slice decorators (func([]T) []T) cannot be used with named value groups +// because they lose the key information needed to reconstruct maps. +// Attempting to use slice decorators with named value groups will fail with: +// "cannot use slice decoration for value group: group contains named values, +// use map[string]T decorator instead". +// +// This is not a breaking change because named value groups are a new feature - +// previously, dig.Name() and dig.Group() were mutually exclusive. +// +// Use map decorators for named value groups: +// +// type MapDecorator struct { +// dig.In +// Numbers map[string]int `group:"numbers"` +// } +// +// type MapResult struct { +// dig.Out +// Numbers map[string]int `group:"numbers"` +// } +// +// func DecorateNumbers(p MapDecorator) MapResult { +// // Modify the map and return +// return MapResult{Numbers: modifiedMap} +// } +// package dig // import "go.uber.org/dig" diff --git a/graph.go b/graph.go index e08f1f54..c52e78f7 100644 --- a/graph.go +++ b/graph.go @@ -28,7 +28,7 @@ type graphNode struct { } // graphHolder is the dependency graph of the container. -// It saves constructorNodes and paramGroupedSlice (value groups) +// It saves constructorNodes and paramGroupedCollection (value groups) // as nodes in the graph. // It implements the graph interface defined by internal/graph. // It has 1-1 correspondence with the Scope whose graph it represents. @@ -68,7 +68,7 @@ func (gh *graphHolder) EdgesFrom(u int) []int { for _, param := range w.paramList.Params { orders = append(orders, getParamOrder(gh, param)...) } - case *paramGroupedSlice: + case *paramGroupedCollection: providers := gh.s.getAllGroupProviders(w.Group, w.Type.Elem()) for _, provider := range providers { orders = append(orders, provider.Order(gh.s)) diff --git a/param.go b/param.go index 07f72e09..baa6ddff 100644 --- a/param.go +++ b/param.go @@ -39,10 +39,12 @@ import ( // paramSingle An explicitly requested type. // paramObject dig.In struct where each field in the struct can be another // param. -// paramGroupedSlice -// A slice consuming a value group. This will receive all +// paramGroupedCollection +// A slice or map consuming a value group. This will receive all // values produced with a `group:".."` tag with the same name -// as a slice. +// as a slice or map. For a map, every value produced with the +// same group name MUST have a name which will form the map key. + type param interface { fmt.Stringer @@ -60,7 +62,7 @@ var ( _ param = paramSingle{} _ param = paramObject{} _ param = paramList{} - _ param = paramGroupedSlice{} + _ param = paramGroupedCollection{} ) // newParam builds a param from the given type. If the provided type is a @@ -343,7 +345,7 @@ func getParamOrder(gh *graphHolder, param param) []int { for _, provider := range providers { orders = append(orders, provider.Order(gh.s)) } - case paramGroupedSlice: + case paramGroupedCollection: // value group parameters have nodes of their own. // We can directly return that here. orders = append(orders, p.orders[gh.s]) @@ -402,7 +404,7 @@ func (po paramObject) Build(c containerStore) (reflect.Value, error) { var softGroupsQueue []paramObjectField var fields []paramObjectField for _, f := range po.Fields { - if p, ok := f.Param.(paramGroupedSlice); ok && p.Soft { + if p, ok := f.Param.(paramGroupedCollection); ok && p.Soft { softGroupsQueue = append(softGroupsQueue, f) continue } @@ -452,7 +454,7 @@ func newParamObjectField(idx int, f reflect.StructField, c containerStore) (para case f.Tag.Get(_groupTag) != "": var err error - p, err = newParamGroupedSlice(f, c) + p, err = newParamGroupedCollection(f, c) if err != nil { return pof, err } @@ -489,13 +491,13 @@ func (pof paramObjectField) Build(c containerStore) (reflect.Value, error) { return v, nil } -// paramGroupedSlice is a param which produces a slice of values with the same +// paramGroupedCollection is a param which produces a slice or map of values with the same // group name. -type paramGroupedSlice struct { +type paramGroupedCollection struct { // Name of the group as specified in the `group:".."` tag. Group string - // Type of the slice. + // Type of the map or slice. Type reflect.Type // Soft is used to denote a soft dependency between this param and its @@ -503,15 +505,17 @@ type paramGroupedSlice struct { // provide another value requested in the graph Soft bool + isMap bool orders map[*Scope]int } -func (pt paramGroupedSlice) String() string { +func (pt paramGroupedCollection) String() string { // io.Reader[group="foo"] refers to a group of io.Readers called 'foo' return fmt.Sprintf("%v[group=%q]", pt.Type.Elem(), pt.Group) + // JQTODO, different string for map } -func (pt paramGroupedSlice) DotParam() []*dot.Param { +func (pt paramGroupedCollection) DotParam() []*dot.Param { return []*dot.Param{ { Node: &dot.Node{ @@ -522,18 +526,21 @@ func (pt paramGroupedSlice) DotParam() []*dot.Param { } } -// newParamGroupedSlice builds a paramGroupedSlice from the provided type with +// newParamGroupedCollection builds a paramGroupedCollection from the provided type with // the given name. // -// The type MUST be a slice type. -func newParamGroupedSlice(f reflect.StructField, c containerStore) (paramGroupedSlice, error) { +// The type MUST be a slice or map[string]T type. +func newParamGroupedCollection(f reflect.StructField, c containerStore) (paramGroupedCollection, error) { g, err := parseGroupString(f.Tag.Get(_groupTag)) if err != nil { - return paramGroupedSlice{}, err + return paramGroupedCollection{}, err } - pg := paramGroupedSlice{ + isMap := f.Type.Kind() == reflect.Map && f.Type.Key().Kind() == reflect.String + isSlice := f.Type.Kind() == reflect.Slice + pg := paramGroupedCollection{ Group: g.Name, Type: f.Type, + isMap: isMap, orders: make(map[*Scope]int), Soft: g.Soft, } @@ -541,9 +548,9 @@ func newParamGroupedSlice(f reflect.StructField, c containerStore) (paramGrouped name := f.Tag.Get(_nameTag) optional, _ := isFieldOptional(f) switch { - case f.Type.Kind() != reflect.Slice: + case !isMap && !isSlice: return pg, newErrInvalidInput( - fmt.Sprintf("value groups may be consumed as slices only: field %q (%v) is not a slice", f.Name, f.Type), nil) + fmt.Sprintf("value groups may be consumed as slices or string-keyed maps only: field %q (%v) is not a slice or string-keyed map", f.Name, f.Type), nil) case g.Flatten: return pg, newErrInvalidInput( fmt.Sprintf("cannot use flatten in parameter value groups: field %q (%v) specifies flatten", f.Name, f.Type), nil) @@ -561,7 +568,7 @@ func newParamGroupedSlice(f reflect.StructField, c containerStore) (paramGrouped // any of the parent Scopes. In the case where there are multiple scopes that // are decorating the same type, the closest scope in effect will be replacing // any decorated value groups provided in further scopes. -func (pt paramGroupedSlice) getDecoratedValues(c containerStore) (reflect.Value, bool) { +func (pt paramGroupedCollection) getDecoratedValues(c containerStore) (reflect.Value, bool) { for _, c := range c.storesToRoot() { if items, ok := c.getDecoratedValueGroup(pt.Group, pt.Type); ok { return items, true @@ -570,13 +577,40 @@ func (pt paramGroupedSlice) getDecoratedValues(c containerStore) (reflect.Value, return _noValue, false } +// groupHasNamedValues checks if this group contains any named values +func (pt paramGroupedCollection) groupHasNamedValues(c containerStore) bool { + stores := c.storesToRoot() + for _, store := range stores { + kgvs := store.getValueGroup(pt.Group, pt.Type.Elem()) + for _, kgv := range kgvs { + if kgv.key != "" { + return true // Found a named value + } + } + } + return false +} + +// hasSliceDecorator checks if there's a slice-type decorator for this group +func (pt paramGroupedCollection) hasSliceDecorator(c containerStore) bool { + stores := c.storesToRoot() + elementType := pt.Type.Elem() + + for _, store := range stores { + if _, ok := store.getGroupDecorator(pt.Group, elementType); ok { + return true + } + } + return false +} + // search the given container and its parents for matching group decorators // and call them to commit values. If any decorators return an error, // that error is returned immediately. If all decorators succeeds, nil is returned. // The order in which the decorators are invoked is from the top level scope to // the current scope, to account for decorators that decorate values that were // already decorated. -func (pt paramGroupedSlice) callGroupDecorators(c containerStore) error { +func (pt paramGroupedCollection) callGroupDecorators(c containerStore) error { stores := c.storesToRoot() for i := len(stores) - 1; i >= 0; i-- { c := stores[i] @@ -601,7 +635,7 @@ func (pt paramGroupedSlice) callGroupDecorators(c containerStore) error { // search the given container and its parent for matching group providers and // call them to commit values. If an error is encountered, return the number // of providers called and a non-nil error from the first provided. -func (pt paramGroupedSlice) callGroupProviders(c containerStore) (int, error) { +func (pt paramGroupedCollection) callGroupProviders(c containerStore) (int, error) { itemCount := 0 for _, c := range c.storesToRoot() { providers := c.getGroupProviders(pt.Group, pt.Type.Elem()) @@ -619,7 +653,7 @@ func (pt paramGroupedSlice) callGroupProviders(c containerStore) (int, error) { return itemCount, nil } -func (pt paramGroupedSlice) Build(c containerStore) (reflect.Value, error) { +func (pt paramGroupedCollection) Build(c containerStore) (reflect.Value, error) { // do not call this if we are already inside a decorator since // it will result in an infinite recursion. (i.e. decorate -> params.BuildList() -> Decorate -> params.BuildList...) // this is safe since a value can be decorated at most once in a given scope. @@ -629,9 +663,26 @@ func (pt paramGroupedSlice) Build(c containerStore) (reflect.Value, error) { // Check if we have decorated values if decoratedItems, ok := pt.getDecoratedValues(c); ok { + // Validate: if we found decorated values but we're trying to build a slice + // and the group has named values, this is the problematic pattern + if !pt.isMap && pt.groupHasNamedValues(c) { + return _noValue, newErrInvalidInput( + fmt.Sprintf("cannot use slice decoration for value group %q: "+ + "group contains named values, use map[string]T decorator instead", + pt.Group), nil) + } return decoratedItems, nil } + // Check if we have a slice decorator for a group with named values - this is always wrong + // Only block if there's actually a decorator AND named values + if !pt.isMap && pt.hasSliceDecorator(c) && pt.groupHasNamedValues(c) { + return _noValue, newErrInvalidInput( + fmt.Sprintf("cannot use slice decoration for value group %q: "+ + "group contains named values, use map[string]T decorator instead", + pt.Group), nil) + } + // If we do not have any decorated values and the group isn't soft, // find the providers and call them. itemCount := 0 @@ -644,9 +695,28 @@ func (pt paramGroupedSlice) Build(c containerStore) (reflect.Value, error) { } stores := c.storesToRoot() + if pt.isMap { + result := reflect.MakeMapWithSize(pt.Type, itemCount) + for _, c := range stores { + kgvs := c.getValueGroup(pt.Group, pt.Type.Elem()) + for _, kgv := range kgvs { + if kgv.key == "" { + return _noValue, newErrInvalidInput( + fmt.Sprintf("every entry in a map value groups must have a name, group \"%v\" is missing a name", pt.Group), + nil, + ) + } + result.SetMapIndex(reflect.ValueOf(kgv.key), kgv.value) + } + } + return result, nil + } result := reflect.MakeSlice(pt.Type, 0, itemCount) for _, c := range stores { - result = reflect.Append(result, c.getValueGroup(pt.Group, pt.Type.Elem())...) + kgvs := c.getValueGroup(pt.Group, pt.Type.Elem()) + for _, kgv := range kgvs { + result = reflect.Append(result, kgv.value) + } } return result, nil } diff --git a/param_test.go b/param_test.go index 507529aa..3bffa914 100644 --- a/param_test.go +++ b/param_test.go @@ -177,21 +177,31 @@ func TestParamObjectFailure(t *testing.T) { }) } -func TestParamGroupSliceErrors(t *testing.T) { +func TestParamGroupCollectionErrors(t *testing.T) { tests := []struct { desc string shape interface{} wantErr string }{ { - desc: "non-slice type are disallowed", + desc: "non-slice or string-keyed map type are disallowed (slice)", shape: struct { In Foo string `group:"foo"` }{}, - wantErr: "value groups may be consumed as slices only: " + - `field "Foo" (string) is not a slice`, + wantErr: "value groups may be consumed as slices or string-keyed maps only: " + + `field "Foo" (string) is not a slice or string-keyed map`, + }, + { + desc: "non-slice or string-keyed map type are disallowed (string-keyed map)", + shape: struct { + In + + Foo map[int]int `group:"foo"` + }{}, + wantErr: "value groups may be consumed as slices or string-keyed maps only: " + + `field "Foo" (map[int]int) is not a slice or string-keyed map`, }, { desc: "cannot provide name for a group", diff --git a/provide.go b/provide.go index 9e47b6db..86ab5b1d 100644 --- a/provide.go +++ b/provide.go @@ -48,12 +48,6 @@ type provideOptions struct { } func (o *provideOptions) Validate() error { - if len(o.Group) > 0 { - if len(o.Name) > 0 { - return newErrInvalidInput( - fmt.Sprintf("cannot use named values with value groups: name:%q provided with group:%q", o.Name, o.Group), nil) - } - } // Names must be representable inside a backquoted string. The only // limitation for raw string literals as per diff --git a/result.go b/result.go index 369cd218..24eb7d8d 100644 --- a/result.go +++ b/result.go @@ -66,7 +66,7 @@ type resultOptions struct { } // newResult builds a result from the given type. -func newResult(t reflect.Type, opts resultOptions) (result, error) { +func newResult(t reflect.Type, opts resultOptions, noGroup bool) (result, error) { switch { case IsIn(t) || (t.Kind() == reflect.Ptr && IsIn(t.Elem())) || embedsType(t, _inPtrType): return nil, newErrInvalidInput(fmt.Sprintf( @@ -81,13 +81,13 @@ func newResult(t reflect.Type, opts resultOptions) (result, error) { case t.Kind() == reflect.Ptr && IsOut(t.Elem()): return nil, newErrInvalidInput(fmt.Sprintf( "cannot return a pointer to a result object, use a value instead: %v is a pointer to a struct that embeds dig.Out", t), nil) - case len(opts.Group) > 0: + case len(opts.Group) > 0 && !noGroup: g, err := parseGroupString(opts.Group) if err != nil { return nil, newErrInvalidInput( fmt.Sprintf("cannot parse group %q", opts.Group), err) } - rg := resultGrouped{Type: t, Group: g.Name, Flatten: g.Flatten} + rg := resultGrouped{Type: t, Key: opts.Name, Group: g.Name, Flatten: g.Flatten} if len(opts.As) > 0 { var asTypes []reflect.Type for _, as := range opts.As { @@ -176,7 +176,9 @@ func walkResult(r result, v resultVisitor) { w := v for _, f := range res.Fields { if v := w.AnnotateWithField(f); v != nil { - walkResult(f.Result, v) + for _, r := range f.Results { + walkResult(r, v) + } } } case resultList: @@ -200,7 +202,7 @@ type resultList struct { // For each item at index i returned by the constructor, resultIndexes[i] // is the index in .Results for the corresponding result object. // resultIndexes[i] is -1 for errors returned by constructors. - resultIndexes []int + resultIndexes [][]int } func (rl resultList) DotResult() []*dot.Result { @@ -216,25 +218,45 @@ func newResultList(ctype reflect.Type, opts resultOptions) (resultList, error) { rl := resultList{ ctype: ctype, Results: make([]result, 0, numOut), - resultIndexes: make([]int, numOut), + resultIndexes: make([][]int, numOut), } resultIdx := 0 for i := 0; i < numOut; i++ { t := ctype.Out(i) if isError(t) { - rl.resultIndexes[i] = -1 + rl.resultIndexes[i] = append(rl.resultIndexes[i], -1) continue } - r, err := newResult(t, opts) - if err != nil { - return rl, newErrInvalidInput(fmt.Sprintf("bad result %d", i+1), err) + addResult := func(nogroup bool) error { + r, err := newResult(t, opts, nogroup) + if err != nil { + return newErrInvalidInput(fmt.Sprintf("bad result %d", i+1), err) + } + + rl.Results = append(rl.Results, r) + rl.resultIndexes[i] = append(rl.resultIndexes[i], resultIdx) + resultIdx++ + return nil + } + + // special case, its added as a group and a name using options alone + if len(opts.Name) > 0 && len(opts.Group) > 0 && !IsOut(t) { + // add as a group + if err := addResult(false); err != nil { + return rl, err + } + // add as single + err := addResult(true) + return rl, err + } + + // add as normal + if err := addResult(false); err != nil { + return rl, err } - rl.Results = append(rl.Results, r) - rl.resultIndexes[i] = resultIdx - resultIdx++ } return rl, nil @@ -246,8 +268,14 @@ func (resultList) Extract(containerWriter, bool, reflect.Value) { func (rl resultList) ExtractList(cw containerWriter, decorated bool, values []reflect.Value) error { for i, v := range values { - if resultIdx := rl.resultIndexes[i]; resultIdx >= 0 { - rl.Results[resultIdx].Extract(cw, decorated, v) + isNonErrorResult := false + for _, resultIdx := range rl.resultIndexes[i] { + if resultIdx >= 0 { + rl.Results[resultIdx].Extract(cw, decorated, v) + isNonErrorResult = true + } + } + if isNonErrorResult { continue } @@ -384,7 +412,9 @@ func newResultObject(t reflect.Type, opts resultOptions) (resultObject, error) { func (ro resultObject) Extract(cw containerWriter, decorated bool, v reflect.Value) { for _, f := range ro.Fields { - f.Result.Extract(cw, decorated, v.Field(f.FieldIndex)) + for _, r := range f.Results { + r.Extract(cw, decorated, v.Field(f.FieldIndex)) + } } } @@ -399,12 +429,16 @@ type resultObjectField struct { // map to results. FieldIndex int - // Result produced by this field. - Result result + // Results produced by this field. + Results []result } func (rof resultObjectField) DotResult() []*dot.Result { - return rof.Result.DotResult() + results := make([]*dot.Result, 0, len(rof.Results)) + for _, r := range rof.Results { + results = append(results, r.DotResult()...) + } + return results } // newResultObjectField(i, f, opts) builds a resultObjectField from the field @@ -414,7 +448,11 @@ func newResultObjectField(idx int, f reflect.StructField, opts resultOptions) (r FieldName: f.Name, FieldIndex: idx, } - + name := f.Tag.Get(_nameTag) + if len(name) > 0 { + // can modify in-place because options are passed-by-value. + opts.Name = name + } var r result switch { case f.PkgPath != "": @@ -427,20 +465,21 @@ func newResultObjectField(idx int, f reflect.StructField, opts resultOptions) (r if err != nil { return rof, err } + rof.Results = append(rof.Results, r) + if len(name) == 0 { + break + } + fallthrough default: var err error - if name := f.Tag.Get(_nameTag); len(name) > 0 { - // can modify in-place because options are passed-by-value. - opts.Name = name - } - r, err = newResult(f.Type, opts) + r, err = newResult(f.Type, opts, false) if err != nil { return rof, err } + rof.Results = append(rof.Results, r) } - rof.Result = r return rof, nil } @@ -452,6 +491,9 @@ type resultGrouped struct { // Name of the group as specified in the `group:".."` tag. Group string + // Key if a name tag or option was provided, for populating maps + Key string + // Type of value produced. Type reflect.Type @@ -488,12 +530,13 @@ func newResultGrouped(f reflect.StructField) (resultGrouped, error) { if err != nil { return resultGrouped{}, err } + name := f.Tag.Get(_nameTag) rg := resultGrouped{ Group: g.Name, + Key: name, Flatten: g.Flatten, Type: f.Type, } - name := f.Tag.Get(_nameTag) optional, _ := isFieldOptional(f) switch { case g.Flatten && f.Type.Kind() != reflect.Slice: @@ -502,9 +545,6 @@ func newResultGrouped(f reflect.StructField) (resultGrouped, error) { case g.Soft: return rg, newErrInvalidInput(fmt.Sprintf( "cannot use soft with result value groups: soft was used with group %q", rg.Group), nil) - case name != "": - return rg, newErrInvalidInput(fmt.Sprintf( - "cannot use named values with value groups: name:%q provided with group:%q", name, rg.Group), nil) case optional: return rg, newErrInvalidInput("value groups cannot be optional", nil) } @@ -518,18 +558,19 @@ func newResultGrouped(f reflect.StructField) (resultGrouped, error) { func (rt resultGrouped) Extract(cw containerWriter, decorated bool, v reflect.Value) { // Decorated values are always flattened. if !decorated && !rt.Flatten { - cw.submitGroupedValue(rt.Group, rt.Type, v) + cw.submitGroupedValue(rt.Group, rt.Key, rt.Type, v) for _, asType := range rt.As { - cw.submitGroupedValue(rt.Group, asType, v) + cw.submitGroupedValue(rt.Group, rt.Key, asType, v) } return } if decorated { - cw.submitDecoratedGroupedValue(rt.Group, rt.Type, v) + cw.submitDecoratedGroupedValue(rt.Group, rt.Key, rt.Type, v) return } + // it's not possible to provide a key for the flattening case for i := 0; i < v.Len(); i++ { - cw.submitGroupedValue(rt.Group, rt.Type, v.Index(i)) + cw.submitGroupedValue(rt.Group, "", rt.Type, v.Index(i)) } } diff --git a/result_test.go b/result_test.go index c19db20d..974e9d8c 100644 --- a/result_test.go +++ b/result_test.go @@ -108,7 +108,7 @@ func TestNewResultErrors(t *testing.T) { for _, tt := range tests { give := reflect.TypeOf(tt.give) t.Run(fmt.Sprint(give), func(t *testing.T) { - _, err := newResult(give, resultOptions{}) + _, err := newResult(give, resultOptions{}, false) require.Error(t, err) assert.Contains(t, err.Error(), tt.err) }) @@ -139,12 +139,12 @@ func TestNewResultObject(t *testing.T) { { FieldName: "Reader", FieldIndex: 1, - Result: resultSingle{Type: typeOfReader}, + Results: []result{resultSingle{Type: typeOfReader}}, }, { FieldName: "Writer", FieldIndex: 2, - Result: resultSingle{Type: typeOfWriter}, + Results: []result{resultSingle{Type: typeOfWriter}}, }, }, }, @@ -160,12 +160,12 @@ func TestNewResultObject(t *testing.T) { { FieldName: "A", FieldIndex: 1, - Result: resultSingle{Name: "stream-a", Type: typeOfWriter}, + Results: []result{resultSingle{Name: "stream-a", Type: typeOfWriter}}, }, { FieldName: "B", FieldIndex: 2, - Result: resultSingle{Name: "stream-b", Type: typeOfWriter}, + Results: []result{resultSingle{Name: "stream-b", Type: typeOfWriter}}, }, }, }, @@ -180,7 +180,25 @@ func TestNewResultObject(t *testing.T) { { FieldName: "Writer", FieldIndex: 1, - Result: resultGrouped{Group: "writers", Type: typeOfWriter}, + Results: []result{resultGrouped{Group: "writers", Type: typeOfWriter}}, + }, + }, + }, + { + desc: "group and name tag", + give: struct { + Out + + Writer io.Writer `name:"writer1" group:"writers"` + }{}, + wantFields: []resultObjectField{ + { + FieldName: "Writer", + FieldIndex: 1, + Results: []result{ + resultGrouped{Group: "writers", Key: "writer1", Type: typeOfWriter}, + resultSingle{Name: "writer1", Type: typeOfWriter}, + }, }, }, }, @@ -229,16 +247,6 @@ func TestNewResultObjectErrors(t *testing.T) { }{}, err: `bad field "Nested"`, }, - { - desc: "group with name should fail", - give: struct { - Out - - Foo string `group:"foo" name:"bar"` - }{}, - err: "cannot use named values with value groups: " + - `name:"bar" provided with group:"foo"`, - }, { desc: "group marked as optional", give: struct { @@ -414,31 +422,31 @@ func TestWalkResult(t *testing.T) { { AnnotateWithField: &ro.Fields[0], Return: fakeResultVisits{ - {Visit: ro.Fields[0].Result}, + {Visit: ro.Fields[0].Results[0]}, }, }, { AnnotateWithField: &ro.Fields[1], Return: fakeResultVisits{ - {Visit: ro.Fields[1].Result}, + {Visit: ro.Fields[1].Results[0]}, }, }, { AnnotateWithField: &ro.Fields[2], Return: fakeResultVisits{ { - Visit: ro.Fields[2].Result, + Visit: ro.Fields[2].Results[0], Return: fakeResultVisits{ { - AnnotateWithField: &ro.Fields[2].Result.(resultObject).Fields[0], + AnnotateWithField: &ro.Fields[2].Results[0].(resultObject).Fields[0], Return: fakeResultVisits{ - {Visit: ro.Fields[2].Result.(resultObject).Fields[0].Result}, + {Visit: ro.Fields[2].Results[0].(resultObject).Fields[0].Results[0]}, }, }, { - AnnotateWithField: &ro.Fields[2].Result.(resultObject).Fields[1], + AnnotateWithField: &ro.Fields[2].Results[0].(resultObject).Fields[1], Return: fakeResultVisits{ - {Visit: ro.Fields[2].Result.(resultObject).Fields[1].Result}, + {Visit: ro.Fields[2].Results[0].(resultObject).Fields[1].Results[0]}, }, }, }, diff --git a/scope.go b/scope.go index b03b5087..357b7220 100644 --- a/scope.go +++ b/scope.go @@ -37,6 +37,11 @@ type ScopeOption interface { noScopeOption() // yet } +type keyedGroupValue struct { + key string + value reflect.Value +} + // Scope is a scoped DAG of types and their dependencies. // A Scope may also have one or more child Scopes that inherit // from it. @@ -63,10 +68,10 @@ type Scope struct { values map[key]reflect.Value // Values groups that generated directly in the Scope. - groups map[key][]reflect.Value + groups map[key][]keyedGroupValue // Values groups that generated via decoraters in the Scope. - decoratedGroups map[key]reflect.Value + decoratedGroups map[key]keyedGroupValue // Source of randomness. rand *rand.Rand @@ -103,8 +108,8 @@ func newScope() *Scope { decorators: make(map[key]*decoratorNode), values: make(map[key]reflect.Value), decoratedValues: make(map[key]reflect.Value), - groups: make(map[key][]reflect.Value), - decoratedGroups: make(map[key]reflect.Value), + groups: make(map[key][]keyedGroupValue), + decoratedGroups: make(map[key]keyedGroupValue), invokerFn: defaultInvoker, rand: rand.New(rand.NewSource(time.Now().UnixNano())), clockSrc: digclock.System, @@ -202,7 +207,7 @@ func (s *Scope) setDecoratedValue(name string, t reflect.Type, v reflect.Value) s.decoratedValues[key{name: name, t: t}] = v } -func (s *Scope) getValueGroup(name string, t reflect.Type) []reflect.Value { +func (s *Scope) getValueGroup(name string, t reflect.Type) []keyedGroupValue { items := s.groups[key{group: name, t: t}] // shuffle the list so users don't rely on the ordering of grouped values return shuffledCopy(s.rand, items) @@ -210,17 +215,17 @@ func (s *Scope) getValueGroup(name string, t reflect.Type) []reflect.Value { func (s *Scope) getDecoratedValueGroup(name string, t reflect.Type) (reflect.Value, bool) { items, ok := s.decoratedGroups[key{group: name, t: t}] - return items, ok + return items.value, ok } -func (s *Scope) submitGroupedValue(name string, t reflect.Type, v reflect.Value) { +func (s *Scope) submitGroupedValue(name, mapKey string, t reflect.Type, v reflect.Value) { k := key{group: name, t: t} - s.groups[k] = append(s.groups[k], v) + s.groups[k] = append(s.groups[k], keyedGroupValue{key: mapKey, value: v}) } -func (s *Scope) submitDecoratedGroupedValue(name string, t reflect.Type, v reflect.Value) { +func (s *Scope) submitDecoratedGroupedValue(name, mapKey string, t reflect.Type, v reflect.Value) { k := key{group: name, t: t} - s.decoratedGroups[k] = v + s.decoratedGroups[k] = keyedGroupValue{key: mapKey, value: v} } func (s *Scope) getValueProviders(name string, t reflect.Type) []provider { @@ -326,9 +331,9 @@ func (s *Scope) String() string { for k, v := range s.values { fmt.Fprintln(b, "\t", k, "=>", v) } - for k, vs := range s.groups { - for _, v := range vs { - fmt.Fprintln(b, "\t", k, "=>", v) + for k, kgvs := range s.groups { + for _, kgv := range kgvs { + fmt.Fprintln(b, "\t", k, "=>", kgv.value) } } fmt.Fprintln(b, "}")