Skip to content

Commit 660e806

Browse files
committed
Support simultanous name and group tags
As per Dig issue: #380 In order to support Fx feature requests uber-go/fx#998 uber-go/fx#1036 We need to be able to drop the restriction, both in terms of options dig.Name and dig.Group and dig.Out struct annotations on `name` and `group` being mutually exclusive. In a future PR, this can then be exploited to populate value group maps where the 'name' tag becomes the key of a map[string][T]
1 parent 7f9f0b8 commit 660e806

File tree

5 files changed

+217
-74
lines changed

5 files changed

+217
-74
lines changed

decorate.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ func findResultKeys(r resultList) ([]key, error) {
288288
keys = append(keys, key{t: innerResult.Type.Elem(), group: innerResult.Group})
289289
case resultObject:
290290
for _, f := range innerResult.Fields {
291-
q = append(q, f.Result)
291+
q = append(q, f.Results...)
292292
}
293293
case resultList:
294294
q = append(q, innerResult.Results...)

dig_test.go

Lines changed: 119 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -749,6 +749,53 @@ func TestEndToEndSuccess(t *testing.T) {
749749
assert.ElementsMatch(t, actualStrs, expectedStrs, "list of strings provided must match")
750750
})
751751

752+
t.Run("multiple As with Group and Name", func(t *testing.T) {
753+
c := digtest.New(t)
754+
expectedNames := []string{"inst1", "inst2"}
755+
expectedStrs := []string{"foo", "bar"}
756+
for i, s := range expectedStrs {
757+
s := s
758+
c.RequireProvide(func() *bytes.Buffer {
759+
return bytes.NewBufferString(s)
760+
}, dig.Group("buffs"), dig.Name(expectedNames[i]),
761+
dig.As(new(io.Reader), new(io.Writer)))
762+
}
763+
764+
type in struct {
765+
dig.In
766+
767+
Reader1 io.Reader `name:"inst1"`
768+
Reader2 io.Reader `name:"inst2"`
769+
Readers []io.Reader `group:"buffs"`
770+
Writers []io.Writer `group:"buffs"`
771+
}
772+
773+
var actualStrs []string
774+
var actualStrsName []string
775+
776+
c.RequireInvoke(func(got in) {
777+
require.Len(t, got.Readers, 2)
778+
buf := make([]byte, 3)
779+
for i, r := range got.Readers {
780+
_, err := r.Read(buf)
781+
require.NoError(t, err)
782+
actualStrs = append(actualStrs, string(buf))
783+
// put the text back
784+
got.Writers[i].Write(buf)
785+
}
786+
_, err := got.Reader1.Read(buf)
787+
require.NoError(t, err)
788+
actualStrsName = append(actualStrsName, string(buf))
789+
_, err = got.Reader2.Read(buf)
790+
require.NoError(t, err)
791+
actualStrsName = append(actualStrsName, string(buf))
792+
require.Len(t, got.Writers, 2)
793+
})
794+
795+
assert.ElementsMatch(t, actualStrs, expectedStrs, "list of strings provided must match")
796+
assert.ElementsMatch(t, actualStrsName, expectedStrs, "names: list of strings provided must match")
797+
})
798+
752799
t.Run("As same interface", func(t *testing.T) {
753800
c := digtest.New(t)
754801
c.RequireProvide(func() io.Reader {
@@ -1098,6 +1145,48 @@ func TestGroups(t *testing.T) {
10981145
})
10991146
})
11001147

1148+
t.Run("values are provided; coexist with name", func(t *testing.T) {
1149+
c := digtest.New(t, dig.SetRand(rand.New(rand.NewSource(0))))
1150+
1151+
type out struct {
1152+
dig.Out
1153+
1154+
Value int `group:"val"`
1155+
}
1156+
1157+
type out2 struct {
1158+
dig.Out
1159+
1160+
Value int `name:"inst1" group:"val"`
1161+
}
1162+
1163+
provide := func(i int) {
1164+
c.RequireProvide(func() out {
1165+
return out{Value: i}
1166+
})
1167+
}
1168+
1169+
provide(1)
1170+
provide(2)
1171+
provide(3)
1172+
1173+
c.RequireProvide(func() out2 {
1174+
return out2{Value: 4}
1175+
})
1176+
1177+
type in struct {
1178+
dig.In
1179+
1180+
SingleValue int `name:"inst1"`
1181+
Values []int `group:"val"`
1182+
}
1183+
1184+
c.RequireInvoke(func(i in) {
1185+
assert.Equal(t, []int{1, 2, 3, 4}, i.Values)
1186+
assert.Equal(t, 4, i.SingleValue)
1187+
})
1188+
})
1189+
11011190
t.Run("groups are provided via option", func(t *testing.T) {
11021191
c := digtest.New(t, dig.SetRand(rand.New(rand.NewSource(0))))
11031192

@@ -1122,6 +1211,36 @@ func TestGroups(t *testing.T) {
11221211
})
11231212
})
11241213

1214+
t.Run("groups are provided via option; coexist with name", func(t *testing.T) {
1215+
c := digtest.New(t, dig.SetRand(rand.New(rand.NewSource(0))))
1216+
1217+
provide := func(i int) {
1218+
c.RequireProvide(func() int {
1219+
return i
1220+
}, dig.Group("val"))
1221+
}
1222+
1223+
provide(1)
1224+
provide(2)
1225+
provide(3)
1226+
1227+
c.RequireProvide(func() int {
1228+
return 4
1229+
}, dig.Group("val"), dig.Name("inst1"))
1230+
1231+
type in struct {
1232+
dig.In
1233+
1234+
SingleValue int `name:"inst1"`
1235+
Values []int `group:"val"`
1236+
}
1237+
1238+
c.RequireInvoke(func(i in) {
1239+
assert.Equal(t, []int{1, 2, 3, 4}, i.Values)
1240+
assert.Equal(t, 4, i.SingleValue)
1241+
})
1242+
})
1243+
11251244
t.Run("different types may be grouped", func(t *testing.T) {
11261245
c := digtest.New(t, dig.SetRand(rand.New(rand.NewSource(0))))
11271246

@@ -1998,21 +2117,6 @@ func TestAsExpectingOriginalType(t *testing.T) {
19982117
})
19992118
}
20002119

2001-
func TestProvideIncompatibleOptions(t *testing.T) {
2002-
t.Parallel()
2003-
2004-
t.Run("group and name", func(t *testing.T) {
2005-
c := digtest.New(t)
2006-
err := c.Provide(func() io.Reader {
2007-
t.Fatal("this function must not be called")
2008-
return nil
2009-
}, dig.Group("foo"), dig.Name("bar"))
2010-
require.Error(t, err)
2011-
assert.Contains(t, err.Error(), "cannot use named values with value groups: "+
2012-
`name:"bar" provided with group:"foo"`)
2013-
})
2014-
}
2015-
20162120
type testStruct struct{}
20172121

20182122
func (testStruct) TestMethod(x int) float64 { return float64(x) }

provide.go

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,6 @@ type provideOptions struct {
4646
}
4747

4848
func (o *provideOptions) Validate() error {
49-
if len(o.Group) > 0 {
50-
if len(o.Name) > 0 {
51-
return newErrInvalidInput(
52-
fmt.Sprintf("cannot use named values with value groups: name:%q provided with group:%q", o.Name, o.Group), nil)
53-
}
54-
}
5549

5650
// Names must be representable inside a backquoted string. The only
5751
// limitation for raw string literals as per

result.go

Lines changed: 66 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ type resultOptions struct {
6666
}
6767

6868
// newResult builds a result from the given type.
69-
func newResult(t reflect.Type, opts resultOptions) (result, error) {
69+
func newResult(t reflect.Type, opts resultOptions, noGroup bool) (result, error) {
7070
switch {
7171
case IsIn(t) || (t.Kind() == reflect.Ptr && IsIn(t.Elem())) || embedsType(t, _inPtrType):
7272
return nil, newErrInvalidInput(fmt.Sprintf(
@@ -81,7 +81,7 @@ func newResult(t reflect.Type, opts resultOptions) (result, error) {
8181
case t.Kind() == reflect.Ptr && IsOut(t.Elem()):
8282
return nil, newErrInvalidInput(fmt.Sprintf(
8383
"cannot return a pointer to a result object, use a value instead: %v is a pointer to a struct that embeds dig.Out", t), nil)
84-
case len(opts.Group) > 0:
84+
case len(opts.Group) > 0 && !noGroup:
8585
g, err := parseGroupString(opts.Group)
8686
if err != nil {
8787
return nil, newErrInvalidInput(
@@ -176,7 +176,9 @@ func walkResult(r result, v resultVisitor) {
176176
w := v
177177
for _, f := range res.Fields {
178178
if v := w.AnnotateWithField(f); v != nil {
179-
walkResult(f.Result, v)
179+
for _, r := range f.Results {
180+
walkResult(r, v)
181+
}
180182
}
181183
}
182184
case resultList:
@@ -200,7 +202,7 @@ type resultList struct {
200202
// For each item at index i returned by the constructor, resultIndexes[i]
201203
// is the index in .Results for the corresponding result object.
202204
// resultIndexes[i] is -1 for errors returned by constructors.
203-
resultIndexes []int
205+
resultIndexes [][]int
204206
}
205207

206208
func (rl resultList) DotResult() []*dot.Result {
@@ -216,25 +218,47 @@ func newResultList(ctype reflect.Type, opts resultOptions) (resultList, error) {
216218
rl := resultList{
217219
ctype: ctype,
218220
Results: make([]result, 0, numOut),
219-
resultIndexes: make([]int, numOut),
221+
resultIndexes: make([][]int, numOut),
220222
}
221223

222224
resultIdx := 0
223225
for i := 0; i < numOut; i++ {
224226
t := ctype.Out(i)
225227
if isError(t) {
226-
rl.resultIndexes[i] = -1
228+
rl.resultIndexes[i] = append(rl.resultIndexes[i], -1)
227229
continue
228230
}
229231

230-
r, err := newResult(t, opts)
231-
if err != nil {
232-
return rl, newErrInvalidInput(fmt.Sprintf("bad result %d", i+1), err)
232+
addResult := func(nogroup bool) error {
233+
r, err := newResult(t, opts, nogroup)
234+
if err != nil {
235+
return newErrInvalidInput(fmt.Sprintf("bad result %d", i+1), err)
236+
}
237+
238+
rl.Results = append(rl.Results, r)
239+
rl.resultIndexes[i] = append(rl.resultIndexes[i], resultIdx)
240+
resultIdx++
241+
return nil
242+
}
243+
244+
// special case, its added as a group and a name using options alone
245+
if len(opts.Name) > 0 && len(opts.Group) > 0 && !IsOut(t) {
246+
// add as a group
247+
if err := addResult(false); err != nil {
248+
return rl, err
249+
}
250+
// add as single
251+
if err := addResult(true); err != nil {
252+
return rl, err
253+
}
254+
return rl, nil
255+
}
256+
257+
// add as normal
258+
if err := addResult(false); err != nil {
259+
return rl, err
233260
}
234261

235-
rl.Results = append(rl.Results, r)
236-
rl.resultIndexes[i] = resultIdx
237-
resultIdx++
238262
}
239263

240264
return rl, nil
@@ -246,8 +270,14 @@ func (resultList) Extract(containerWriter, bool, reflect.Value) {
246270

247271
func (rl resultList) ExtractList(cw containerWriter, decorated bool, values []reflect.Value) error {
248272
for i, v := range values {
249-
if resultIdx := rl.resultIndexes[i]; resultIdx >= 0 {
250-
rl.Results[resultIdx].Extract(cw, decorated, v)
273+
isNonErrorResult := false
274+
for _, resultIdx := range rl.resultIndexes[i] {
275+
if resultIdx >= 0 {
276+
rl.Results[resultIdx].Extract(cw, decorated, v)
277+
isNonErrorResult = true
278+
}
279+
}
280+
if isNonErrorResult {
251281
continue
252282
}
253283

@@ -384,7 +414,9 @@ func newResultObject(t reflect.Type, opts resultOptions) (resultObject, error) {
384414

385415
func (ro resultObject) Extract(cw containerWriter, decorated bool, v reflect.Value) {
386416
for _, f := range ro.Fields {
387-
f.Result.Extract(cw, decorated, v.Field(f.FieldIndex))
417+
for _, r := range f.Results {
418+
r.Extract(cw, decorated, v.Field(f.FieldIndex))
419+
}
388420
}
389421
}
390422

@@ -399,12 +431,16 @@ type resultObjectField struct {
399431
// map to results.
400432
FieldIndex int
401433

402-
// Result produced by this field.
403-
Result result
434+
// Results produced by this field.
435+
Results []result
404436
}
405437

406438
func (rof resultObjectField) DotResult() []*dot.Result {
407-
return rof.Result.DotResult()
439+
results := make([]*dot.Result, 0, len(rof.Results))
440+
for _, r := range rof.Results {
441+
results = append(results, r.DotResult()...)
442+
}
443+
return results
408444
}
409445

410446
// newResultObjectField(i, f, opts) builds a resultObjectField from the field
@@ -414,7 +450,11 @@ func newResultObjectField(idx int, f reflect.StructField, opts resultOptions) (r
414450
FieldName: f.Name,
415451
FieldIndex: idx,
416452
}
417-
453+
name := f.Tag.Get(_nameTag)
454+
if len(name) > 0 {
455+
// can modify in-place because options are passed-by-value.
456+
opts.Name = name
457+
}
418458
var r result
419459
switch {
420460
case f.PkgPath != "":
@@ -427,20 +467,21 @@ func newResultObjectField(idx int, f reflect.StructField, opts resultOptions) (r
427467
if err != nil {
428468
return rof, err
429469
}
470+
rof.Results = append(rof.Results, r)
471+
if len(name) == 0 {
472+
break
473+
}
474+
fallthrough
430475

431476
default:
432477
var err error
433-
if name := f.Tag.Get(_nameTag); len(name) > 0 {
434-
// can modify in-place because options are passed-by-value.
435-
opts.Name = name
436-
}
437-
r, err = newResult(f.Type, opts)
478+
r, err = newResult(f.Type, opts, false)
438479
if err != nil {
439480
return rof, err
440481
}
482+
rof.Results = append(rof.Results, r)
441483
}
442484

443-
rof.Result = r
444485
return rof, nil
445486
}
446487

@@ -493,7 +534,6 @@ func newResultGrouped(f reflect.StructField) (resultGrouped, error) {
493534
Flatten: g.Flatten,
494535
Type: f.Type,
495536
}
496-
name := f.Tag.Get(_nameTag)
497537
optional, _ := isFieldOptional(f)
498538
switch {
499539
case g.Flatten && f.Type.Kind() != reflect.Slice:
@@ -502,9 +542,6 @@ func newResultGrouped(f reflect.StructField) (resultGrouped, error) {
502542
case g.Soft:
503543
return rg, newErrInvalidInput(fmt.Sprintf(
504544
"cannot use soft with result value groups: soft was used with group %q", rg.Group), nil)
505-
case name != "":
506-
return rg, newErrInvalidInput(fmt.Sprintf(
507-
"cannot use named values with value groups: name:%q provided with group:%q", name, rg.Group), nil)
508545
case optional:
509546
return rg, newErrInvalidInput("value groups cannot be optional", nil)
510547
}

0 commit comments

Comments
 (0)