Skip to content

Commit 35f6ccd

Browse files
authored
Merge pull request #1631 from knadh/feat-decoder
Add redis.Scan() to scan results from redis maps into structs.
2 parents bf010a7 + 9527245 commit 35f6ccd

File tree

6 files changed

+496
-0
lines changed

6 files changed

+496
-0
lines changed

command.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"time"
99

1010
"github.com/go-redis/redis/v8/internal"
11+
"github.com/go-redis/redis/v8/internal/hscan"
1112
"github.com/go-redis/redis/v8/internal/proto"
1213
"github.com/go-redis/redis/v8/internal/util"
1314
)
@@ -371,6 +372,26 @@ func (cmd *SliceCmd) String() string {
371372
return cmdString(cmd, cmd.val)
372373
}
373374

375+
// Scan scans the results from the map into a destination struct. The map keys
376+
// are matched in the Redis struct fields by the `redis:"field"` tag.
377+
func (cmd *SliceCmd) Scan(dst interface{}) error {
378+
if cmd.err != nil {
379+
return cmd.err
380+
}
381+
382+
// Pass the list of keys and values.
383+
// Skip the first two args for: HMGET key
384+
var args []interface{}
385+
if cmd.args[0] == "hmget" {
386+
args = cmd.args[2:]
387+
} else {
388+
// Otherwise, it's: MGET field field ...
389+
args = cmd.args[1:]
390+
}
391+
392+
return hscan.Scan(dst, args, cmd.val)
393+
}
394+
374395
func (cmd *SliceCmd) readReply(rd *proto.Reader) error {
375396
v, err := rd.ReadArrayReply(sliceParser)
376397
if err != nil {
@@ -917,6 +938,27 @@ func (cmd *StringStringMapCmd) String() string {
917938
return cmdString(cmd, cmd.val)
918939
}
919940

941+
// Scan scans the results from the map into a destination struct. The map keys
942+
// are matched in the Redis struct fields by the `redis:"field"` tag.
943+
func (cmd *StringStringMapCmd) Scan(dst interface{}) error {
944+
if cmd.err != nil {
945+
return cmd.err
946+
}
947+
948+
strct, err := hscan.Struct(dst)
949+
if err != nil {
950+
return err
951+
}
952+
953+
for k, v := range cmd.val {
954+
if err := strct.Scan(k, v); err != nil {
955+
return err
956+
}
957+
}
958+
959+
return nil
960+
}
961+
920962
func (cmd *StringStringMapCmd) readReply(rd *proto.Reader) error {
921963
_, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) {
922964
cmd.val = make(map[string]string, n/2)

commands_test.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1134,6 +1134,22 @@ var _ = Describe("Commands", func() {
11341134
Expect(mGet.Val()).To(Equal([]interface{}{"hello1", "hello2", nil}))
11351135
})
11361136

1137+
It("should scan Mget", func() {
1138+
err := client.MSet(ctx, "key1", "hello1", "key2", 123).Err()
1139+
Expect(err).NotTo(HaveOccurred())
1140+
1141+
res := client.MGet(ctx, "key1", "key2", "_")
1142+
Expect(res.Err()).NotTo(HaveOccurred())
1143+
1144+
type data struct {
1145+
Key1 string `redis:"key1"`
1146+
Key2 int `redis:"key2"`
1147+
}
1148+
var d data
1149+
Expect(res.Scan(&d)).NotTo(HaveOccurred())
1150+
Expect(d).To(Equal(data{Key1: "hello1", Key2: 123}))
1151+
})
1152+
11371153
It("should MSetNX", func() {
11381154
mSetNX := client.MSetNX(ctx, "key1", "hello1", "key2", "hello2")
11391155
Expect(mSetNX.Err()).NotTo(HaveOccurred())
@@ -1375,6 +1391,22 @@ var _ = Describe("Commands", func() {
13751391
Expect(m).To(Equal(map[string]string{"key1": "hello1", "key2": "hello2"}))
13761392
})
13771393

1394+
It("should scan", func() {
1395+
err := client.HMSet(ctx, "hash", "key1", "hello1", "key2", 123).Err()
1396+
Expect(err).NotTo(HaveOccurred())
1397+
1398+
res := client.HGetAll(ctx, "hash")
1399+
Expect(res.Err()).NotTo(HaveOccurred())
1400+
1401+
type data struct {
1402+
Key1 string `redis:"key1"`
1403+
Key2 int `redis:"key2"`
1404+
}
1405+
var d data
1406+
Expect(res.Scan(&d)).NotTo(HaveOccurred())
1407+
Expect(d).To(Equal(data{Key1: "hello1", Key2: 123}))
1408+
})
1409+
13781410
It("should HIncrBy", func() {
13791411
hSet := client.HSet(ctx, "hash", "key", "5")
13801412
Expect(hSet.Err()).NotTo(HaveOccurred())

example_test.go

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,73 @@ func ExampleClient_ScanType() {
276276
// Output: found 33 keys
277277
}
278278

279+
// ExampleStringStringMapCmd_Scan shows how to scan the results of a map fetch
280+
// into a struct.
281+
func ExampleStringStringMapCmd_Scan() {
282+
rdb.FlushDB(ctx)
283+
err := rdb.HMSet(ctx, "map",
284+
"name", "hello",
285+
"count", 123,
286+
"correct", true).Err()
287+
if err != nil {
288+
panic(err)
289+
}
290+
291+
// Get the map. The same approach works for HmGet().
292+
res := rdb.HGetAll(ctx, "map")
293+
if res.Err() != nil {
294+
panic(err)
295+
}
296+
297+
type data struct {
298+
Name string `redis:"name"`
299+
Count int `redis:"count"`
300+
Correct bool `redis:"correct"`
301+
}
302+
303+
// Scan the results into the struct.
304+
var d data
305+
if err := res.Scan(&d); err != nil {
306+
panic(err)
307+
}
308+
309+
fmt.Println(d)
310+
// Output: {hello 123 true}
311+
}
312+
313+
// ExampleSliceCmd_Scan shows how to scan the results of a multi key fetch
314+
// into a struct.
315+
func ExampleSliceCmd_Scan() {
316+
rdb.FlushDB(ctx)
317+
err := rdb.MSet(ctx,
318+
"name", "hello",
319+
"count", 123,
320+
"correct", true).Err()
321+
if err != nil {
322+
panic(err)
323+
}
324+
325+
res := rdb.MGet(ctx, "name", "count", "empty", "correct")
326+
if res.Err() != nil {
327+
panic(err)
328+
}
329+
330+
type data struct {
331+
Name string `redis:"name"`
332+
Count int `redis:"count"`
333+
Correct bool `redis:"correct"`
334+
}
335+
336+
// Scan the results into the struct.
337+
var d data
338+
if err := res.Scan(&d); err != nil {
339+
panic(err)
340+
}
341+
342+
fmt.Println(d)
343+
// Output: {hello 123 true}
344+
}
345+
279346
func ExampleClient_Pipelined() {
280347
var incr *redis.IntCmd
281348
_, err := rdb.Pipelined(ctx, func(pipe redis.Pipeliner) error {

internal/hscan/hscan.go

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
package hscan
2+
3+
import (
4+
"errors"
5+
"fmt"
6+
"reflect"
7+
"strconv"
8+
)
9+
10+
// decoderFunc represents decoding functions for default built-in types.
11+
type decoderFunc func(reflect.Value, string) error
12+
13+
var (
14+
// List of built-in decoders indexed by their numeric constant values (eg: reflect.Bool = 1).
15+
decoders = []decoderFunc{
16+
reflect.Bool: decodeBool,
17+
reflect.Int: decodeInt,
18+
reflect.Int8: decodeInt,
19+
reflect.Int16: decodeInt,
20+
reflect.Int32: decodeInt,
21+
reflect.Int64: decodeInt,
22+
reflect.Uint: decodeUint,
23+
reflect.Uint8: decodeUint,
24+
reflect.Uint16: decodeUint,
25+
reflect.Uint32: decodeUint,
26+
reflect.Uint64: decodeUint,
27+
reflect.Float32: decodeFloat,
28+
reflect.Float64: decodeFloat,
29+
reflect.Complex64: decodeUnsupported,
30+
reflect.Complex128: decodeUnsupported,
31+
reflect.Array: decodeUnsupported,
32+
reflect.Chan: decodeUnsupported,
33+
reflect.Func: decodeUnsupported,
34+
reflect.Interface: decodeUnsupported,
35+
reflect.Map: decodeUnsupported,
36+
reflect.Ptr: decodeUnsupported,
37+
reflect.Slice: decodeSlice,
38+
reflect.String: decodeString,
39+
reflect.Struct: decodeUnsupported,
40+
reflect.UnsafePointer: decodeUnsupported,
41+
}
42+
43+
// Global map of struct field specs that is populated once for every new
44+
// struct type that is scanned. This caches the field types and the corresponding
45+
// decoder functions to avoid iterating through struct fields on subsequent scans.
46+
globalStructMap = newStructMap()
47+
)
48+
49+
func Struct(dst interface{}) (StructValue, error) {
50+
v := reflect.ValueOf(dst)
51+
52+
// The dstination to scan into should be a struct pointer.
53+
if v.Kind() != reflect.Ptr || v.IsNil() {
54+
return StructValue{}, fmt.Errorf("redis.Scan(non-pointer %T)", dst)
55+
}
56+
57+
v = v.Elem()
58+
if v.Kind() != reflect.Struct {
59+
return StructValue{}, fmt.Errorf("redis.Scan(non-struct %T)", dst)
60+
}
61+
62+
return StructValue{
63+
spec: globalStructMap.get(v.Type()),
64+
value: v,
65+
}, nil
66+
}
67+
68+
// Scan scans the results from a key-value Redis map result set to a destination struct.
69+
// The Redis keys are matched to the struct's field with the `redis` tag.
70+
func Scan(dst interface{}, keys []interface{}, vals []interface{}) error {
71+
if len(keys) != len(vals) {
72+
return errors.New("args should have the same number of keys and vals")
73+
}
74+
75+
strct, err := Struct(dst)
76+
if err != nil {
77+
return err
78+
}
79+
80+
// Iterate through the (key, value) sequence.
81+
for i := 0; i < len(vals); i++ {
82+
key, ok := keys[i].(string)
83+
if !ok {
84+
continue
85+
}
86+
87+
val, ok := vals[i].(string)
88+
if !ok {
89+
continue
90+
}
91+
92+
if err := strct.Scan(key, val); err != nil {
93+
return err
94+
}
95+
}
96+
97+
return nil
98+
}
99+
100+
func decodeBool(f reflect.Value, s string) error {
101+
b, err := strconv.ParseBool(s)
102+
if err != nil {
103+
return err
104+
}
105+
f.SetBool(b)
106+
return nil
107+
}
108+
109+
func decodeInt(f reflect.Value, s string) error {
110+
v, err := strconv.ParseInt(s, 10, 0)
111+
if err != nil {
112+
return err
113+
}
114+
f.SetInt(v)
115+
return nil
116+
}
117+
118+
func decodeUint(f reflect.Value, s string) error {
119+
v, err := strconv.ParseUint(s, 10, 0)
120+
if err != nil {
121+
return err
122+
}
123+
f.SetUint(v)
124+
return nil
125+
}
126+
127+
func decodeFloat(f reflect.Value, s string) error {
128+
v, err := strconv.ParseFloat(s, 0)
129+
if err != nil {
130+
return err
131+
}
132+
f.SetFloat(v)
133+
return nil
134+
}
135+
136+
func decodeString(f reflect.Value, s string) error {
137+
f.SetString(s)
138+
return nil
139+
}
140+
141+
func decodeSlice(f reflect.Value, s string) error {
142+
// []byte slice ([]uint8).
143+
if f.Type().Elem().Kind() == reflect.Uint8 {
144+
f.SetBytes([]byte(s))
145+
}
146+
return nil
147+
}
148+
149+
func decodeUnsupported(v reflect.Value, s string) error {
150+
return fmt.Errorf("redis.Scan(unsupported %s)", v.Type())
151+
}

0 commit comments

Comments
 (0)