Skip to content

Commit dbdf19f

Browse files
authored
Merge pull request #1605 from MLedIum/container-types
Supported of list, set and struct for unmarshall using `sugar.Unmarshall...`
2 parents 0e361e6 + 14bb201 commit dbdf19f

File tree

4 files changed

+293
-0
lines changed

4 files changed

+293
-0
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
* Supported of list, set and struct for unmarshall using `sugar.Unmarshall...`
2+
13
## v3.95.6
24
* Fixed panic on span reporting in `xsql/Tx`
35

internal/value/value.go

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1284,6 +1284,33 @@ func (v *listValue) castTo(dst any) error {
12841284
case *driver.Value:
12851285
*dstValue = v
12861286

1287+
return nil
1288+
case interface{}:
1289+
ptr := reflect.ValueOf(dstValue)
1290+
1291+
inner := reflect.Indirect(ptr)
1292+
if inner.Kind() != reflect.Slice && inner.Kind() != reflect.Array {
1293+
return xerrors.WithStackTrace(fmt.Errorf(
1294+
"%w '%s(%+v)' to '%T' destination",
1295+
ErrCannotCast, v.Type().Yql(), v, dstValue,
1296+
))
1297+
}
1298+
1299+
targetType := inner.Type().Elem()
1300+
valueInner := reflect.ValueOf(v.ListItems())
1301+
1302+
newSlice := reflect.MakeSlice(reflect.SliceOf(targetType), valueInner.Len(), valueInner.Cap())
1303+
inner.Set(newSlice)
1304+
1305+
for i, item := range v.ListItems() {
1306+
if err := item.castTo(inner.Index(i).Addr().Interface()); err != nil {
1307+
return xerrors.WithStackTrace(fmt.Errorf(
1308+
"%w '%s(%+v)' to '%T' destination",
1309+
ErrCannotCast, v.Type().Yql(), v, dstValue,
1310+
))
1311+
}
1312+
}
1313+
12871314
return nil
12881315
default:
12891316
return xerrors.WithStackTrace(fmt.Errorf(
@@ -1391,6 +1418,33 @@ func (v *setValue) castTo(dst any) error {
13911418
case *driver.Value:
13921419
*dstValue = v
13931420

1421+
return nil
1422+
case interface{}:
1423+
ptr := reflect.ValueOf(dstValue)
1424+
1425+
inner := reflect.Indirect(ptr)
1426+
if inner.Kind() != reflect.Slice && inner.Kind() != reflect.Array {
1427+
return xerrors.WithStackTrace(fmt.Errorf(
1428+
"%w '%s(%+v)' to '%T' destination",
1429+
ErrCannotCast, v.Type().Yql(), v, dstValue,
1430+
))
1431+
}
1432+
1433+
targetType := inner.Type().Elem()
1434+
valueInner := reflect.ValueOf(v.items)
1435+
1436+
newSlice := reflect.MakeSlice(reflect.SliceOf(targetType), valueInner.Len(), valueInner.Cap())
1437+
inner.Set(newSlice)
1438+
1439+
for i, item := range v.items {
1440+
if err := item.castTo(inner.Index(i).Addr().Interface()); err != nil {
1441+
return xerrors.WithStackTrace(fmt.Errorf(
1442+
"%w '%s(%+v)' to '%T' destination",
1443+
ErrCannotCast, v.Type().Yql(), v, dstValue,
1444+
))
1445+
}
1446+
}
1447+
13941448
return nil
13951449
default:
13961450
return xerrors.WithStackTrace(fmt.Errorf(
@@ -1574,6 +1628,27 @@ func (v *structValue) castTo(dst any) error {
15741628
case *driver.Value:
15751629
*dstValue = v
15761630

1631+
return nil
1632+
case interface{}:
1633+
ptr := reflect.ValueOf(dst)
1634+
1635+
inner := reflect.Indirect(ptr)
1636+
if inner.Kind() != reflect.Struct {
1637+
return xerrors.WithStackTrace(fmt.Errorf(
1638+
"%w '%s(%+v)' to '%T' destination",
1639+
ErrCannotCast, v.Type().Yql(), v, dstValue,
1640+
))
1641+
}
1642+
1643+
for i, field := range v.fields {
1644+
if err := field.V.castTo(inner.Field(i).Addr().Interface()); err != nil {
1645+
return xerrors.WithStackTrace(fmt.Errorf(
1646+
"scan error on struct field name '%s': %w",
1647+
field.Name, err,
1648+
))
1649+
}
1650+
}
1651+
15771652
return nil
15781653
default:
15791654
return xerrors.WithStackTrace(fmt.Errorf(

internal/value/value_test.go

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1142,6 +1142,145 @@ func TestCastNumbers(t *testing.T) {
11421142
}
11431143
}
11441144

1145+
func TestCastList(t *testing.T) {
1146+
for _, tt := range []struct {
1147+
v Value
1148+
dst interface{}
1149+
result interface{}
1150+
error bool
1151+
}{
1152+
{
1153+
v: ListValue(Int32Value(12), Int32Value(21), Int32Value(56)),
1154+
dst: func(v []int64) *[]int64 { return &v }([]int64{}),
1155+
result: func(v []int64) *[]int64 { return &v }([]int64{12, 21, 56}),
1156+
error: false,
1157+
},
1158+
{
1159+
v: ListValue(Int32Value(12), Int32Value(21), Int32Value(56)),
1160+
dst: func(v []int64) *[]int64 { return &v }([]int64{17}),
1161+
result: func(v []int64) *[]int64 { return &v }([]int64{12, 21, 56}),
1162+
error: false,
1163+
},
1164+
{
1165+
v: ListValue(BytesValue([]byte("test")), BytesValue([]byte("test2"))),
1166+
dst: func(v []string) *[]string { return &v }([]string{}),
1167+
result: func(v []string) *[]string { return &v }([]string{"test", "test2"}),
1168+
error: false,
1169+
},
1170+
{
1171+
v: ListValue(BytesValue([]byte("test")), BytesValue([]byte("test2"))),
1172+
dst: func(v []string) *[]string { return &v }([]string{"list"}),
1173+
result: func(v []string) *[]string { return &v }([]string{"test", "test2"}),
1174+
error: false,
1175+
},
1176+
} {
1177+
t.Run(fmt.Sprintf("%s→%v", tt.v.Type().Yql(), reflect.ValueOf(tt.dst).Type().Elem()),
1178+
func(t *testing.T) {
1179+
if err := CastTo(tt.v, tt.dst); (err != nil) != tt.error {
1180+
t.Errorf("castTo() error = %v, want %v", err, tt.error)
1181+
} else if !reflect.DeepEqual(tt.dst, tt.result) {
1182+
t.Errorf("castTo() result = %+v, want %+v",
1183+
reflect.ValueOf(tt.dst).Elem(),
1184+
reflect.ValueOf(tt.result).Elem(),
1185+
)
1186+
}
1187+
},
1188+
)
1189+
}
1190+
}
1191+
1192+
func TestCastSet(t *testing.T) {
1193+
for _, tt := range []struct {
1194+
v Value
1195+
dst interface{}
1196+
result interface{}
1197+
error bool
1198+
}{
1199+
{
1200+
v: SetValue(Int32Value(12), Int32Value(21), Int32Value(56)),
1201+
dst: func(v []int64) *[]int64 { return &v }([]int64{}),
1202+
result: func(v []int64) *[]int64 { return &v }([]int64{12, 21, 56}),
1203+
error: false,
1204+
},
1205+
{
1206+
v: SetValue(Int32Value(12), Int32Value(21), Int32Value(56)),
1207+
dst: func(v []int64) *[]int64 { return &v }([]int64{17}),
1208+
result: func(v []int64) *[]int64 { return &v }([]int64{12, 21, 56}),
1209+
error: false,
1210+
},
1211+
{
1212+
v: SetValue(BytesValue([]byte("test")), BytesValue([]byte("test2"))),
1213+
dst: func(v []string) *[]string { return &v }([]string{}),
1214+
result: func(v []string) *[]string { return &v }([]string{"test", "test2"}),
1215+
error: false,
1216+
},
1217+
{
1218+
v: SetValue(BytesValue([]byte("test")), BytesValue([]byte("test2"))),
1219+
dst: func(v []string) *[]string { return &v }([]string{"list"}),
1220+
result: func(v []string) *[]string { return &v }([]string{"test", "test2"}),
1221+
error: false,
1222+
},
1223+
} {
1224+
t.Run(fmt.Sprintf("%s→%v", tt.v.Type().Yql(), reflect.ValueOf(tt.dst).Type().Elem()),
1225+
func(t *testing.T) {
1226+
if err := CastTo(tt.v, tt.dst); (err != nil) != tt.error {
1227+
t.Errorf("castTo() error = %v, want %v", err, tt.error)
1228+
} else if !reflect.DeepEqual(tt.dst, tt.result) {
1229+
t.Errorf("castTo() result = %+v, want %+v",
1230+
reflect.ValueOf(tt.dst).Elem(),
1231+
reflect.ValueOf(tt.result).Elem(),
1232+
)
1233+
}
1234+
},
1235+
)
1236+
}
1237+
}
1238+
1239+
func TestCastStruct(t *testing.T) {
1240+
type defaultStruct struct {
1241+
ID int32 `sql:"id"`
1242+
Str string `sql:"myStr"`
1243+
}
1244+
for _, tt := range []struct {
1245+
v Value
1246+
dst interface{}
1247+
result interface{}
1248+
error bool
1249+
}{
1250+
{
1251+
v: StructValue(
1252+
StructValueField{Name: "id", V: Int32Value(123)},
1253+
StructValueField{Name: "myStr", V: BytesValue([]byte("myStr123"))},
1254+
),
1255+
dst: func(v defaultStruct) *defaultStruct { return &v }(defaultStruct{1, "myStr1"}),
1256+
result: func(v defaultStruct) *defaultStruct { return &v }(defaultStruct{123, "myStr123"}),
1257+
error: false,
1258+
},
1259+
{
1260+
v: StructValue(
1261+
StructValueField{Name: "id", V: Int32Value(12)},
1262+
StructValueField{Name: "myStr", V: BytesValue([]byte("myStr12"))},
1263+
),
1264+
dst: func(v defaultStruct) *defaultStruct { return &v }(defaultStruct{}),
1265+
result: func(v defaultStruct) *defaultStruct { return &v }(defaultStruct{12, "myStr12"}),
1266+
error: false,
1267+
},
1268+
} {
1269+
t.Run(fmt.Sprintf("%s→%v", tt.v.Type().Yql(), reflect.ValueOf(tt.dst).Type().Elem()),
1270+
func(t *testing.T) {
1271+
if err := CastTo(tt.v, tt.dst); (err != nil) != tt.error {
1272+
t.Errorf("castTo() error = %v, want %v", err, tt.error)
1273+
} else if !reflect.DeepEqual(tt.dst, tt.result) {
1274+
t.Errorf("castTo() result = %+v, want %+v",
1275+
reflect.ValueOf(tt.dst).Elem(),
1276+
reflect.ValueOf(tt.result).Elem(),
1277+
)
1278+
}
1279+
},
1280+
)
1281+
}
1282+
}
1283+
11451284
func TestCastOtherTypes(t *testing.T) {
11461285
for _, tt := range []struct {
11471286
v Value

tests/integration/sugar_unmarhall_result_set_test.go

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,83 @@ func TestSugarUnmarshallResultSet(t *testing.T) {
4747
require.EqualValues(t, 43, many[1].ID)
4848
require.EqualValues(t, "myStr43", many[1].Str)
4949
})
50+
t.Run("ListField", func(t *testing.T) {
51+
type myStruct struct {
52+
Ids []int32 `sql:"ids"`
53+
}
54+
55+
rows, err := db.Query().QueryResultSet(ctx, `SELECT AsList(42, 43) as ids`)
56+
57+
many, err := sugar.UnmarshallResultSet[myStruct](rows)
58+
require.NoError(t, err)
59+
require.Len(t, many, 1)
60+
require.NotNil(t, many[0])
61+
require.Len(t, many[0].Ids, 2)
62+
require.EqualValues(t, 42, many[0].Ids[0])
63+
require.EqualValues(t, 43, many[0].Ids[1])
64+
})
65+
t.Run("SetField", func(t *testing.T) {
66+
type myStruct struct {
67+
Ids []int32 `sql:"ids"`
68+
}
69+
70+
rows, err := db.Query().QueryResultSet(ctx, `SELECT AsSet(42, 43) as ids`)
71+
72+
many, err := sugar.UnmarshallResultSet[myStruct](rows)
73+
require.NoError(t, err)
74+
require.Len(t, many, 1)
75+
require.NotNil(t, many[0])
76+
require.Len(t, many[0].Ids, 2)
77+
require.EqualValues(t, 42, many[0].Ids[0])
78+
require.EqualValues(t, 43, many[0].Ids[1])
79+
})
80+
t.Run("StructField", func(t *testing.T) {
81+
type myStructField struct {
82+
ID int32 `sql:"id"`
83+
Str string `sql:"myStr"`
84+
}
85+
type myStruct struct {
86+
ID int32 `sql:"id"`
87+
Str string `sql:"myStr"`
88+
StructField myStructField `sql:"structColumn"`
89+
}
90+
91+
rows, err := db.Query().QueryResultSet(ctx, `
92+
SELECT 42 as id, "myStr42" as myStr, AsStruct(22 as id, "myStr22" as myStr) as structColumn
93+
`)
94+
95+
many, err := sugar.UnmarshallResultSet[myStruct](rows)
96+
require.NoError(t, err)
97+
require.Len(t, many, 1)
98+
require.NotNil(t, many[0])
99+
require.EqualValues(t, 42, many[0].ID)
100+
require.EqualValues(t, "myStr42", many[0].Str)
101+
require.EqualValues(t, 22, many[0].StructField.ID)
102+
require.EqualValues(t, "myStr22", many[0].StructField.Str)
103+
})
104+
t.Run("ListOfStructsField", func(t *testing.T) {
105+
type myStructField struct {
106+
ID int32 `sql:"id"`
107+
Str string `sql:"myStr"`
108+
}
109+
type myStruct struct {
110+
Values []myStructField `sql:"values"`
111+
}
112+
113+
rows, err := db.Query().QueryResultSet(ctx,
114+
`SELECT AsList(AsStruct(22 as id, "myStr22" as myStr), AsStruct(42 as id, "myStr42" as myStr)) as values`,
115+
)
116+
117+
many, err := sugar.UnmarshallResultSet[myStruct](rows)
118+
require.NoError(t, err)
119+
require.Len(t, many, 1)
120+
require.NotNil(t, many[0])
121+
require.Len(t, many[0].Values, 2)
122+
require.EqualValues(t, 22, many[0].Values[0].ID)
123+
require.EqualValues(t, "myStr22", many[0].Values[0].Str)
124+
require.EqualValues(t, 42, many[0].Values[1].ID)
125+
require.EqualValues(t, "myStr42", many[0].Values[1].Str)
126+
})
50127
t.Run("UnexpectedColumn", func(t *testing.T) {
51128
type myStruct struct {
52129
ID int32 `sql:"id"`

0 commit comments

Comments
 (0)