Skip to content

Commit 49a0c8c

Browse files
committed
Add test for ring and cluster hooks
1 parent 2e3402d commit 49a0c8c

File tree

7 files changed

+433
-44
lines changed

7 files changed

+433
-44
lines changed

cluster.go

Lines changed: 55 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -779,7 +779,7 @@ func (c *ClusterClient) _process(ctx context.Context, cmd Cmder) error {
779779
_ = pipe.Close()
780780
ask = false
781781
} else {
782-
lastErr = node.Client._process(ctx, cmd)
782+
lastErr = node.Client.ProcessContext(ctx, cmd)
783783
}
784784

785785
// If there is no error - we are done.
@@ -840,6 +840,7 @@ func (c *ClusterClient) ForEachMaster(fn func(client *Client) error) error {
840840

841841
var wg sync.WaitGroup
842842
errCh := make(chan error, 1)
843+
843844
for _, master := range state.Masters {
844845
wg.Add(1)
845846
go func(node *clusterNode) {
@@ -853,6 +854,7 @@ func (c *ClusterClient) ForEachMaster(fn func(client *Client) error) error {
853854
}
854855
}(master)
855856
}
857+
856858
wg.Wait()
857859

858860
select {
@@ -873,6 +875,7 @@ func (c *ClusterClient) ForEachSlave(fn func(client *Client) error) error {
873875

874876
var wg sync.WaitGroup
875877
errCh := make(chan error, 1)
878+
876879
for _, slave := range state.Slaves {
877880
wg.Add(1)
878881
go func(node *clusterNode) {
@@ -886,6 +889,7 @@ func (c *ClusterClient) ForEachSlave(fn func(client *Client) error) error {
886889
}
887890
}(slave)
888891
}
892+
889893
wg.Wait()
890894

891895
select {
@@ -906,6 +910,7 @@ func (c *ClusterClient) ForEachNode(fn func(client *Client) error) error {
906910

907911
var wg sync.WaitGroup
908912
errCh := make(chan error, 1)
913+
909914
worker := func(node *clusterNode) {
910915
defer wg.Done()
911916
err := fn(node.Client)
@@ -927,6 +932,7 @@ func (c *ClusterClient) ForEachNode(fn func(client *Client) error) error {
927932
}
928933

929934
wg.Wait()
935+
930936
select {
931937
case err := <-errCh:
932938
return err
@@ -1068,18 +1074,7 @@ func (c *ClusterClient) _processPipeline(ctx context.Context, cmds []Cmder) erro
10681074
go func(node *clusterNode, cmds []Cmder) {
10691075
defer wg.Done()
10701076

1071-
err := node.Client.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error {
1072-
err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
1073-
return writeCmds(wr, cmds)
1074-
})
1075-
if err != nil {
1076-
return err
1077-
}
1078-
1079-
return cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
1080-
return c.pipelineReadCmds(node, rd, cmds, failedCmds)
1081-
})
1082-
})
1077+
err := c._processPipelineNode(ctx, node, cmds, failedCmds)
10831078
if err == nil {
10841079
return
10851080
}
@@ -1142,6 +1137,25 @@ func (c *ClusterClient) cmdsAreReadOnly(cmds []Cmder) bool {
11421137
return true
11431138
}
11441139

1140+
func (c *ClusterClient) _processPipelineNode(
1141+
ctx context.Context, node *clusterNode, cmds []Cmder, failedCmds *cmdsMap,
1142+
) error {
1143+
return node.Client.hooks.processPipeline(ctx, cmds, func(ctx context.Context, cmds []Cmder) error {
1144+
return node.Client.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error {
1145+
err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
1146+
return writeCmds(wr, cmds)
1147+
})
1148+
if err != nil {
1149+
return err
1150+
}
1151+
1152+
return cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
1153+
return c.pipelineReadCmds(node, rd, cmds, failedCmds)
1154+
})
1155+
})
1156+
})
1157+
}
1158+
11451159
func (c *ClusterClient) pipelineReadCmds(
11461160
node *clusterNode, rd *proto.Reader, cmds []Cmder, failedCmds *cmdsMap,
11471161
) error {
@@ -1243,26 +1257,7 @@ func (c *ClusterClient) _processTxPipeline(ctx context.Context, cmds []Cmder) er
12431257
go func(node *clusterNode, cmds []Cmder) {
12441258
defer wg.Done()
12451259

1246-
err := node.Client.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error {
1247-
err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
1248-
return txPipelineWriteMulti(wr, cmds)
1249-
})
1250-
if err != nil {
1251-
return err
1252-
}
1253-
1254-
return cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
1255-
err := c.txPipelineReadQueued(rd, cmds, failedCmds)
1256-
if err != nil {
1257-
moved, ask, addr := isMovedError(err)
1258-
if moved || ask {
1259-
return c.cmdsMoved(cmds, moved, ask, addr, failedCmds)
1260-
}
1261-
return err
1262-
}
1263-
return pipelineReadCmds(rd, cmds)
1264-
})
1265-
})
1260+
err := c._processTxPipelineNode(ctx, node, cmds, failedCmds)
12661261
if err == nil {
12671262
return
12681263
}
@@ -1296,6 +1291,33 @@ func (c *ClusterClient) mapCmdsBySlot(cmds []Cmder) map[int][]Cmder {
12961291
return cmdsMap
12971292
}
12981293

1294+
func (c *ClusterClient) _processTxPipelineNode(
1295+
ctx context.Context, node *clusterNode, cmds []Cmder, failedCmds *cmdsMap,
1296+
) error {
1297+
return node.Client.hooks.processPipeline(ctx, cmds, func(ctx context.Context, cmds []Cmder) error {
1298+
return node.Client.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error {
1299+
err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
1300+
return txPipelineWriteMulti(wr, cmds)
1301+
})
1302+
if err != nil {
1303+
return err
1304+
}
1305+
1306+
return cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
1307+
err := c.txPipelineReadQueued(rd, cmds, failedCmds)
1308+
if err != nil {
1309+
moved, ask, addr := isMovedError(err)
1310+
if moved || ask {
1311+
return c.cmdsMoved(cmds, moved, ask, addr, failedCmds)
1312+
}
1313+
return err
1314+
}
1315+
return pipelineReadCmds(rd, cmds)
1316+
})
1317+
})
1318+
})
1319+
}
1320+
12991321
func (c *ClusterClient) txPipelineReadQueued(
13001322
rd *proto.Reader, cmds []Cmder, failedCmds *cmdsMap,
13011323
) error {

cluster_test.go

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -527,6 +527,184 @@ var _ = Describe("ClusterClient", func() {
527527
err := pubsub.Ping()
528528
Expect(err).NotTo(HaveOccurred())
529529
})
530+
531+
It("supports Process hook", func() {
532+
var masters []*redis.Client
533+
534+
err := client.Ping().Err()
535+
Expect(err).NotTo(HaveOccurred())
536+
537+
err = client.ForEachMaster(func(master *redis.Client) error {
538+
masters = append(masters, master)
539+
return master.Ping().Err()
540+
})
541+
Expect(err).NotTo(HaveOccurred())
542+
543+
var stack []string
544+
545+
clusterHook := &hook{
546+
beforeProcess: func(ctx context.Context, cmd redis.Cmder) (context.Context, error) {
547+
Expect(cmd.String()).To(Equal("ping: "))
548+
stack = append(stack, "cluster.BeforeProcess")
549+
return ctx, nil
550+
},
551+
afterProcess: func(ctx context.Context, cmd redis.Cmder) error {
552+
Expect(cmd.String()).To(Equal("ping: PONG"))
553+
stack = append(stack, "cluster.AfterProcess")
554+
return nil
555+
},
556+
}
557+
client.AddHook(clusterHook)
558+
559+
masterHook := &hook{
560+
beforeProcess: func(ctx context.Context, cmd redis.Cmder) (context.Context, error) {
561+
Expect(cmd.String()).To(Equal("ping: "))
562+
stack = append(stack, "shard.BeforeProcess")
563+
return ctx, nil
564+
},
565+
afterProcess: func(ctx context.Context, cmd redis.Cmder) error {
566+
Expect(cmd.String()).To(Equal("ping: PONG"))
567+
stack = append(stack, "shard.AfterProcess")
568+
return nil
569+
},
570+
}
571+
572+
for _, master := range masters {
573+
master.AddHook(masterHook)
574+
}
575+
576+
err = client.Ping().Err()
577+
Expect(err).NotTo(HaveOccurred())
578+
Expect(stack).To(Equal([]string{
579+
"cluster.BeforeProcess",
580+
"shard.BeforeProcess",
581+
"shard.AfterProcess",
582+
"cluster.AfterProcess",
583+
}))
584+
585+
clusterHook.beforeProcess = nil
586+
clusterHook.afterProcess = nil
587+
masterHook.beforeProcess = nil
588+
masterHook.afterProcess = nil
589+
})
590+
591+
It("supports Pipeline hook", func() {
592+
var masters []*redis.Client
593+
594+
err := client.Ping().Err()
595+
Expect(err).NotTo(HaveOccurred())
596+
597+
err = client.ForEachMaster(func(master *redis.Client) error {
598+
masters = append(masters, master)
599+
return master.Ping().Err()
600+
})
601+
Expect(err).NotTo(HaveOccurred())
602+
603+
var stack []string
604+
605+
client.AddHook(&hook{
606+
beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) {
607+
Expect(cmds).To(HaveLen(1))
608+
Expect(cmds[0].String()).To(Equal("ping: "))
609+
stack = append(stack, "cluster.BeforeProcessPipeline")
610+
return ctx, nil
611+
},
612+
afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error {
613+
Expect(cmds).To(HaveLen(1))
614+
Expect(cmds[0].String()).To(Equal("ping: PONG"))
615+
stack = append(stack, "cluster.AfterProcessPipeline")
616+
return nil
617+
},
618+
})
619+
620+
for _, master := range masters {
621+
master.AddHook(&hook{
622+
beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) {
623+
Expect(cmds).To(HaveLen(1))
624+
Expect(cmds[0].String()).To(Equal("ping: "))
625+
stack = append(stack, "shard.BeforeProcessPipeline")
626+
return ctx, nil
627+
},
628+
afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error {
629+
Expect(cmds).To(HaveLen(1))
630+
Expect(cmds[0].String()).To(Equal("ping: PONG"))
631+
stack = append(stack, "shard.AfterProcessPipeline")
632+
return nil
633+
},
634+
})
635+
}
636+
637+
_, err = client.Pipelined(func(pipe redis.Pipeliner) error {
638+
pipe.Ping()
639+
return nil
640+
})
641+
Expect(err).NotTo(HaveOccurred())
642+
Expect(stack).To(Equal([]string{
643+
"cluster.BeforeProcessPipeline",
644+
"shard.BeforeProcessPipeline",
645+
"shard.AfterProcessPipeline",
646+
"cluster.AfterProcessPipeline",
647+
}))
648+
})
649+
650+
It("supports TxPipeline hook", func() {
651+
var masters []*redis.Client
652+
653+
err := client.Ping().Err()
654+
Expect(err).NotTo(HaveOccurred())
655+
656+
err = client.ForEachMaster(func(master *redis.Client) error {
657+
masters = append(masters, master)
658+
return master.Ping().Err()
659+
})
660+
Expect(err).NotTo(HaveOccurred())
661+
662+
var stack []string
663+
664+
client.AddHook(&hook{
665+
beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) {
666+
Expect(cmds).To(HaveLen(1))
667+
Expect(cmds[0].String()).To(Equal("ping: "))
668+
stack = append(stack, "cluster.BeforeProcessPipeline")
669+
return ctx, nil
670+
},
671+
afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error {
672+
Expect(cmds).To(HaveLen(1))
673+
Expect(cmds[0].String()).To(Equal("ping: PONG"))
674+
stack = append(stack, "cluster.AfterProcessPipeline")
675+
return nil
676+
},
677+
})
678+
679+
for _, master := range masters {
680+
master.AddHook(&hook{
681+
beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) {
682+
Expect(cmds).To(HaveLen(1))
683+
Expect(cmds[0].String()).To(Equal("ping: "))
684+
stack = append(stack, "shard.BeforeProcessPipeline")
685+
return ctx, nil
686+
},
687+
afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error {
688+
Expect(cmds).To(HaveLen(1))
689+
Expect(cmds[0].String()).To(Equal("ping: PONG"))
690+
stack = append(stack, "shard.AfterProcessPipeline")
691+
return nil
692+
},
693+
})
694+
}
695+
696+
_, err = client.TxPipelined(func(pipe redis.Pipeliner) error {
697+
pipe.Ping()
698+
return nil
699+
})
700+
Expect(err).NotTo(HaveOccurred())
701+
Expect(stack).To(Equal([]string{
702+
"cluster.BeforeProcessPipeline",
703+
"shard.BeforeProcessPipeline",
704+
"shard.AfterProcessPipeline",
705+
"cluster.AfterProcessPipeline",
706+
}))
707+
})
530708
}
531709

532710
Describe("ClusterClient", func() {

command.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
type Cmder interface {
1616
Name() string
1717
Args() []interface{}
18+
String() string
1819
stringArg(int) string
1920

2021
readTimeout() *time.Duration
@@ -152,6 +153,10 @@ func NewCmd(args ...interface{}) *Cmd {
152153
}
153154
}
154155

156+
func (cmd *Cmd) String() string {
157+
return cmdString(cmd, cmd.val)
158+
}
159+
155160
func (cmd *Cmd) Val() interface{} {
156161
return cmd.val
157162
}
@@ -160,7 +165,7 @@ func (cmd *Cmd) Result() (interface{}, error) {
160165
return cmd.val, cmd.err
161166
}
162167

163-
func (cmd *Cmd) String() (string, error) {
168+
func (cmd *Cmd) Text() (string, error) {
164169
if cmd.err != nil {
165170
return "", cmd.err
166171
}

example_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,7 @@ func Example_customCommand() {
447447
}
448448

449449
func Example_customCommand2() {
450-
v, err := rdb.Do("get", "key_does_not_exist").String()
450+
v, err := rdb.Do("get", "key_does_not_exist").Text()
451451
fmt.Printf("%q %s", v, err)
452452
// Output: "" redis: nil
453453
}

0 commit comments

Comments
 (0)