Skip to content

Commit 69005e6

Browse files
committed
up
1 parent 00d9848 commit 69005e6

File tree

2 files changed

+45
-0
lines changed

2 files changed

+45
-0
lines changed

internal/proto/writer.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"fmt"
66
"io"
77
"net"
8+
"reflect"
89
"strconv"
910
"time"
1011

@@ -140,6 +141,36 @@ func (w *Writer) WriteArg(v interface{}) error {
140141
return w.bytes(b)
141142
case net.IP:
142143
return w.bytes(v)
144+
default:
145+
return w.writeArgExtra(v)
146+
}
147+
}
148+
149+
func (w *Writer) writeArgExtra(v interface{}) error {
150+
var (
151+
rfValue = reflect.ValueOf(v)
152+
rfKind = rfValue.Kind()
153+
)
154+
155+
switch rfKind {
156+
case reflect.Bool:
157+
if rfValue.Bool() {
158+
return w.int(1)
159+
}
160+
return w.int(0)
161+
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
162+
return w.int(rfValue.Int())
163+
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
164+
return w.uint(rfValue.Uint())
165+
case reflect.Float32, reflect.Float64:
166+
return w.float(rfValue.Float())
167+
case reflect.String:
168+
return w.string(rfValue.String())
169+
case reflect.Slice:
170+
if rfValue.Type().Elem().Kind() == reflect.Uint8 {
171+
return w.bytes(rfValue.Bytes())
172+
}
173+
fallthrough
143174
default:
144175
return fmt.Errorf(
145176
"redis: can't marshal %T (implement encoding.BinaryMarshaler)", v)

redis_test.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,20 @@ var _ = Describe("Client", func() {
362362

363363
Expect(ip2).To(Equal(ip))
364364
})
365+
366+
It("should set and scan custom type", func() {
367+
type customString string
368+
369+
val := customString("hello")
370+
err := client.Set(ctx, "custom", val, 0).Err()
371+
Expect(err).NotTo(HaveOccurred())
372+
373+
var val2 customString
374+
err = client.Get(ctx, "custom").Scan(&val2)
375+
Expect(err).NotTo(HaveOccurred())
376+
377+
Expect(val2).To(Equal(val))
378+
})
365379
})
366380

367381
var _ = Describe("Client timeout", func() {

0 commit comments

Comments
 (0)