Skip to content

Commit 2adef8a

Browse files
committed
Merge pull request #294 from russmatney/master
feat(encoding): allow for referential arrays
2 parents 9a23249 + eebb883 commit 2adef8a

File tree

2 files changed

+58
-21
lines changed

2 files changed

+58
-21
lines changed

encoding/encoder_test.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,3 +324,26 @@ func TestReferenceFieldInvalid(t *testing.T) {
324324
t.Errorf("expected non-nil error but got nil")
325325
}
326326
}
327+
328+
type RefE struct {
329+
ID string `gorethink:"id,omitempty"`
330+
FIDs *[]RefF `gorethink:"f_ids,reference" gorethink_ref:"id"`
331+
}
332+
333+
type RefF struct {
334+
ID string `gorethink:"id,omitempty"`
335+
Name string `gorethink:"name"`
336+
}
337+
338+
func TestReferenceFieldArray(t *testing.T) {
339+
input := RefE{"1", &[]RefF{RefF{"2", "Name2"}, RefF{"3", "Name3"}}}
340+
want := map[string]interface{}{"id": "1", "f_ids": []string{"2", "3"}}
341+
342+
out, err := Encode(input)
343+
if err != nil {
344+
t.Errorf("got error %v, expected nil", err)
345+
}
346+
if !jsonEqual(out, want) {
347+
t.Errorf("got %q, want %q", out, want)
348+
}
349+
}

encoding/encoder_types.go

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,6 @@ type structEncoder struct {
148148

149149
func (se *structEncoder) encode(v reflect.Value) interface{} {
150150
m := make(map[string]interface{})
151-
152151
for i, f := range se.fields {
153152
fv := fieldByIndex(v, f.index)
154153
if !fv.IsValid() || f.omitEmpty && se.isEmptyValue(fv) {
@@ -159,27 +158,8 @@ func (se *structEncoder) encode(v reflect.Value) interface{} {
159158

160159
// If this field is a referenced field then attempt to extract the value.
161160
if f.reference {
162-
refName := f.name
163-
if f.refName != "" {
164-
refName = f.refName
165-
}
166-
167-
// referenced fields can only handle maps so return an error if the
168-
// encoded field is of a different type
169-
m, ok := encField.(map[string]interface{})
170-
if !ok {
171-
err := fmt.Errorf("Error referencing field %s in %s, expected object but got %t", refName, f.name, encField)
172-
panic(&MarshalerError{v.Type(), err})
173-
}
174-
175-
refVal, ok := m[refName]
176-
if !ok {
177-
err := fmt.Errorf("Error referencing field %s in %s, could not find referenced field", refName, f.name)
178-
panic(&MarshalerError{v.Type(), err})
179-
}
180-
181161
// Override the encoded field with the referenced field
182-
encField = refVal
162+
encField = getReferenceField(f, v, encField)
183163
}
184164

185165
m[f.name] = encField
@@ -188,6 +168,40 @@ func (se *structEncoder) encode(v reflect.Value) interface{} {
188168
return m
189169
}
190170

171+
func getReferenceField(f field, v reflect.Value, encField interface{}) interface{} {
172+
refName := f.name
173+
if f.refName != "" {
174+
refName = f.refName
175+
}
176+
177+
encFields, isArray := encField.([]interface{})
178+
if isArray {
179+
refVals := make([]interface{}, len(encFields))
180+
for i, e := range encFields {
181+
refVals[i] = extractValue(e, v, f.name, refName)
182+
}
183+
return refVals
184+
}
185+
refVal := extractValue(encField, v, f.name, refName)
186+
return refVal
187+
}
188+
189+
func extractValue(encField interface{}, v reflect.Value, name string, refName string) interface{} {
190+
// referenced fields can only handle maps so return an error if the
191+
// encoded field is of a different type
192+
m, ok := encField.(map[string]interface{})
193+
if !ok {
194+
err := fmt.Errorf("Error refing field %s in %s, expected object but got %t", refName, name, encField)
195+
panic(&MarshalerError{v.Type(), err})
196+
}
197+
refVal, ok := m[refName]
198+
if !ok {
199+
err := fmt.Errorf("Error refing field %s in %s, could not find referenced field", refName, name)
200+
panic(&MarshalerError{v.Type(), err})
201+
}
202+
return refVal
203+
}
204+
191205
func (se *structEncoder) isEmptyValue(v reflect.Value) bool {
192206
if v.Type() == timeType {
193207
return v.Interface().(time.Time) == time.Time{}

0 commit comments

Comments
 (0)