Skip to content

Commit bdcf7a4

Browse files
committed
Set an error returned from the hook on the Cmd
1 parent db45a82 commit bdcf7a4

File tree

2 files changed

+41
-4
lines changed

2 files changed

+41
-4
lines changed

redis.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,14 @@ func (hs hooks) process(
5151
) error {
5252
ctx, err := hs.beforeProcess(ctx, cmd)
5353
if err != nil {
54+
cmd.setErr(err)
5455
return err
5556
}
5657

5758
cmdErr := fn(ctx, cmd)
5859

59-
err = hs.afterProcess(ctx, cmd)
60-
if err != nil {
60+
if err := hs.afterProcess(ctx, cmd); err != nil {
61+
cmd.setErr(err)
6162
return err
6263
}
6364

@@ -91,13 +92,14 @@ func (hs hooks) processPipeline(
9192
) error {
9293
ctx, err := hs.beforeProcessPipeline(ctx, cmds)
9394
if err != nil {
95+
setCmdsErr(cmds, err)
9496
return err
9597
}
9698

9799
cmdsErr := fn(ctx, cmds)
98100

99-
err = hs.afterProcessPipeline(ctx, cmds)
100-
if err != nil {
101+
if err := hs.afterProcessPipeline(ctx, cmds); err != nil {
102+
setCmdsErr(cmds, err)
101103
return err
102104
}
103105

redis_test.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@ package redis_test
33
import (
44
"bytes"
55
"context"
6+
"errors"
67
"net"
8+
"testing"
79
"time"
810

911
"github.com/go-redis/redis/v7"
@@ -12,6 +14,39 @@ import (
1214
. "github.com/onsi/gomega"
1315
)
1416

17+
type redisHookError struct {
18+
redis.Hook
19+
}
20+
21+
var _ redis.Hook = redisHookError{}
22+
23+
func (redisHookError) BeforeProcess(ctx context.Context, cmd redis.Cmder) (context.Context, error) {
24+
return ctx, nil
25+
}
26+
27+
func (redisHookError) AfterProcess(ctx context.Context, cmd redis.Cmder) error {
28+
return errors.New("hook error")
29+
}
30+
31+
func TestHookError(t *testing.T) {
32+
rdb := redis.NewClient(&redis.Options{
33+
Addr: ":6379",
34+
})
35+
rdb.AddHook(redisHookError{})
36+
37+
err := rdb.Ping().Err()
38+
if err == nil {
39+
t.Fatalf("got nil, expected an error")
40+
}
41+
42+
wanted := "hook error"
43+
if err.Error() != wanted {
44+
t.Fatalf(`got %q, wanted %q`, err, wanted)
45+
}
46+
}
47+
48+
//------------------------------------------------------------------------------
49+
1550
var _ = Describe("Client", func() {
1651
var client *redis.Client
1752

0 commit comments

Comments
 (0)