Skip to content

Commit 218b17f

Browse files
committed
Include multi & exec in pipeline hook
1 parent 49a0c8c commit 218b17f

File tree

8 files changed

+85
-85
lines changed

8 files changed

+85
-85
lines changed

bench_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ func BenchmarkClusterPing(b *testing.B) {
254254
}
255255
defer stopCluster(cluster)
256256

257-
client := cluster.clusterClient(redisClusterOptions())
257+
client := cluster.newClusterClient(redisClusterOptions())
258258
defer client.Close()
259259

260260
b.ResetTimer()
@@ -280,7 +280,7 @@ func BenchmarkClusterSetString(b *testing.B) {
280280
}
281281
defer stopCluster(cluster)
282282

283-
client := cluster.clusterClient(redisClusterOptions())
283+
client := cluster.newClusterClient(redisClusterOptions())
284284
defer client.Close()
285285

286286
value := string(bytes.Repeat([]byte{'1'}, 10000))
@@ -308,7 +308,7 @@ func BenchmarkClusterReloadState(b *testing.B) {
308308
}
309309
defer stopCluster(cluster)
310310

311-
client := cluster.clusterClient(redisClusterOptions())
311+
client := cluster.newClusterClient(redisClusterOptions())
312312
defer client.Close()
313313

314314
b.ResetTimer()

cluster.go

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -773,7 +773,7 @@ func (c *ClusterClient) _process(ctx context.Context, cmd Cmder) error {
773773

774774
if ask {
775775
pipe := node.Client.Pipeline()
776-
_ = pipe.Process(NewCmd("ASKING"))
776+
_ = pipe.Process(NewCmd("asking"))
777777
_ = pipe.Process(cmd)
778778
_, lastErr = pipe.ExecContext(ctx)
779779
_ = pipe.Close()
@@ -1200,7 +1200,7 @@ func (c *ClusterClient) checkMovedErr(
12001200
}
12011201

12021202
if ask {
1203-
failedCmds.Add(node, NewCmd("ASKING"), cmd)
1203+
failedCmds.Add(node, NewCmd("asking"), cmd)
12041204
return true
12051205
}
12061206

@@ -1294,35 +1294,39 @@ func (c *ClusterClient) mapCmdsBySlot(cmds []Cmder) map[int][]Cmder {
12941294
func (c *ClusterClient) _processTxPipelineNode(
12951295
ctx context.Context, node *clusterNode, cmds []Cmder, failedCmds *cmdsMap,
12961296
) error {
1297-
return node.Client.hooks.processPipeline(ctx, cmds, func(ctx context.Context, cmds []Cmder) error {
1297+
return node.Client.hooks.processTxPipeline(ctx, cmds, func(ctx context.Context, cmds []Cmder) error {
12981298
return node.Client.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error {
12991299
err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
1300-
return txPipelineWriteMulti(wr, cmds)
1300+
return writeCmds(wr, cmds)
13011301
})
13021302
if err != nil {
13031303
return err
13041304
}
13051305

13061306
return cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
1307-
err := c.txPipelineReadQueued(rd, cmds, failedCmds)
1307+
statusCmd := cmds[0].(*StatusCmd)
1308+
// Trim multi and exec.
1309+
cmds = cmds[1 : len(cmds)-1]
1310+
1311+
err := c.txPipelineReadQueued(rd, statusCmd, cmds, failedCmds)
13081312
if err != nil {
13091313
moved, ask, addr := isMovedError(err)
13101314
if moved || ask {
13111315
return c.cmdsMoved(cmds, moved, ask, addr, failedCmds)
13121316
}
13131317
return err
13141318
}
1319+
13151320
return pipelineReadCmds(rd, cmds)
13161321
})
13171322
})
13181323
})
13191324
}
13201325

13211326
func (c *ClusterClient) txPipelineReadQueued(
1322-
rd *proto.Reader, cmds []Cmder, failedCmds *cmdsMap,
1327+
rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder, failedCmds *cmdsMap,
13231328
) error {
13241329
// Parse queued replies.
1325-
var statusCmd StatusCmd
13261330
if err := statusCmd.readReply(rd); err != nil {
13271331
return err
13281332
}
@@ -1374,7 +1378,7 @@ func (c *ClusterClient) cmdsMoved(
13741378

13751379
if ask {
13761380
for _, cmd := range cmds {
1377-
failedCmds.Add(node, NewCmd("ASKING"), cmd)
1381+
failedCmds.Add(node, NewCmd("asking"), cmd)
13781382
}
13791383
return nil
13801384
}

cluster_test.go

Lines changed: 36 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -47,14 +47,14 @@ func (s *clusterScenario) addrs() []string {
4747
return addrs
4848
}
4949

50-
func (s *clusterScenario) clusterClientUnsafe(opt *redis.ClusterOptions) *redis.ClusterClient {
50+
func (s *clusterScenario) newClusterClientUnsafe(opt *redis.ClusterOptions) *redis.ClusterClient {
5151
opt.Addrs = s.addrs()
5252
return redis.NewClusterClient(opt)
5353

5454
}
5555

56-
func (s *clusterScenario) clusterClient(opt *redis.ClusterOptions) *redis.ClusterClient {
57-
client := s.clusterClientUnsafe(opt)
56+
func (s *clusterScenario) newClusterClient(opt *redis.ClusterOptions) *redis.ClusterClient {
57+
client := s.newClusterClientUnsafe(opt)
5858

5959
err := eventually(func() error {
6060
if opt.ClusterSlots != nil {
@@ -529,14 +529,11 @@ var _ = Describe("ClusterClient", func() {
529529
})
530530

531531
It("supports Process hook", func() {
532-
var masters []*redis.Client
533-
534532
err := client.Ping().Err()
535533
Expect(err).NotTo(HaveOccurred())
536534

537-
err = client.ForEachMaster(func(master *redis.Client) error {
538-
masters = append(masters, master)
539-
return master.Ping().Err()
535+
err = client.ForEachNode(func(node *redis.Client) error {
536+
return node.Ping().Err()
540537
})
541538
Expect(err).NotTo(HaveOccurred())
542539

@@ -556,7 +553,7 @@ var _ = Describe("ClusterClient", func() {
556553
}
557554
client.AddHook(clusterHook)
558555

559-
masterHook := &hook{
556+
nodeHook := &hook{
560557
beforeProcess: func(ctx context.Context, cmd redis.Cmder) (context.Context, error) {
561558
Expect(cmd.String()).To(Equal("ping: "))
562559
stack = append(stack, "shard.BeforeProcess")
@@ -569,9 +566,10 @@ var _ = Describe("ClusterClient", func() {
569566
},
570567
}
571568

572-
for _, master := range masters {
573-
master.AddHook(masterHook)
574-
}
569+
_ = client.ForEachNode(func(node *redis.Client) error {
570+
node.AddHook(nodeHook)
571+
return nil
572+
})
575573

576574
err = client.Ping().Err()
577575
Expect(err).NotTo(HaveOccurred())
@@ -584,19 +582,16 @@ var _ = Describe("ClusterClient", func() {
584582

585583
clusterHook.beforeProcess = nil
586584
clusterHook.afterProcess = nil
587-
masterHook.beforeProcess = nil
588-
masterHook.afterProcess = nil
585+
nodeHook.beforeProcess = nil
586+
nodeHook.afterProcess = nil
589587
})
590588

591589
It("supports Pipeline hook", func() {
592-
var masters []*redis.Client
593-
594590
err := client.Ping().Err()
595591
Expect(err).NotTo(HaveOccurred())
596592

597-
err = client.ForEachMaster(func(master *redis.Client) error {
598-
masters = append(masters, master)
599-
return master.Ping().Err()
593+
err = client.ForEachNode(func(node *redis.Client) error {
594+
return node.Ping().Err()
600595
})
601596
Expect(err).NotTo(HaveOccurred())
602597

@@ -617,8 +612,8 @@ var _ = Describe("ClusterClient", func() {
617612
},
618613
})
619614

620-
for _, master := range masters {
621-
master.AddHook(&hook{
615+
_ = client.ForEachNode(func(node *redis.Client) error {
616+
node.AddHook(&hook{
622617
beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) {
623618
Expect(cmds).To(HaveLen(1))
624619
Expect(cmds[0].String()).To(Equal("ping: "))
@@ -632,7 +627,8 @@ var _ = Describe("ClusterClient", func() {
632627
return nil
633628
},
634629
})
635-
}
630+
return nil
631+
})
636632

637633
_, err = client.Pipelined(func(pipe redis.Pipeliner) error {
638634
pipe.Ping()
@@ -648,14 +644,11 @@ var _ = Describe("ClusterClient", func() {
648644
})
649645

650646
It("supports TxPipeline hook", func() {
651-
var masters []*redis.Client
652-
653647
err := client.Ping().Err()
654648
Expect(err).NotTo(HaveOccurred())
655649

656-
err = client.ForEachMaster(func(master *redis.Client) error {
657-
masters = append(masters, master)
658-
return master.Ping().Err()
650+
err = client.ForEachNode(func(node *redis.Client) error {
651+
return node.Ping().Err()
659652
})
660653
Expect(err).NotTo(HaveOccurred())
661654

@@ -676,22 +669,23 @@ var _ = Describe("ClusterClient", func() {
676669
},
677670
})
678671

679-
for _, master := range masters {
680-
master.AddHook(&hook{
672+
_ = client.ForEachNode(func(node *redis.Client) error {
673+
node.AddHook(&hook{
681674
beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) {
682-
Expect(cmds).To(HaveLen(1))
683-
Expect(cmds[0].String()).To(Equal("ping: "))
675+
Expect(cmds).To(HaveLen(3))
676+
Expect(cmds[1].String()).To(Equal("ping: "))
684677
stack = append(stack, "shard.BeforeProcessPipeline")
685678
return ctx, nil
686679
},
687680
afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error {
688-
Expect(cmds).To(HaveLen(1))
689-
Expect(cmds[0].String()).To(Equal("ping: PONG"))
681+
Expect(cmds).To(HaveLen(3))
682+
Expect(cmds[1].String()).To(Equal("ping: PONG"))
690683
stack = append(stack, "shard.AfterProcessPipeline")
691684
return nil
692685
},
693686
})
694-
}
687+
return nil
688+
})
695689

696690
_, err = client.TxPipelined(func(pipe redis.Pipeliner) error {
697691
pipe.Ping()
@@ -710,7 +704,7 @@ var _ = Describe("ClusterClient", func() {
710704
Describe("ClusterClient", func() {
711705
BeforeEach(func() {
712706
opt = redisClusterOptions()
713-
client = cluster.clusterClient(opt)
707+
client = cluster.newClusterClient(opt)
714708

715709
err := client.ForEachMaster(func(master *redis.Client) error {
716710
return master.FlushDB().Err()
@@ -733,7 +727,7 @@ var _ = Describe("ClusterClient", func() {
733727
It("returns an error when there are no attempts left", func() {
734728
opt := redisClusterOptions()
735729
opt.MaxRedirects = -1
736-
client := cluster.clusterClient(opt)
730+
client := cluster.newClusterClient(opt)
737731

738732
Eventually(func() error {
739733
return client.SwapNodes("A")
@@ -885,7 +879,7 @@ var _ = Describe("ClusterClient", func() {
885879
opt = redisClusterOptions()
886880
opt.MinRetryBackoff = 250 * time.Millisecond
887881
opt.MaxRetryBackoff = time.Second
888-
client = cluster.clusterClient(opt)
882+
client = cluster.newClusterClient(opt)
889883

890884
err := client.ForEachMaster(func(master *redis.Client) error {
891885
return master.FlushDB().Err()
@@ -935,7 +929,7 @@ var _ = Describe("ClusterClient", func() {
935929
BeforeEach(func() {
936930
opt = redisClusterOptions()
937931
opt.RouteByLatency = true
938-
client = cluster.clusterClient(opt)
932+
client = cluster.newClusterClient(opt)
939933

940934
err := client.ForEachMaster(func(master *redis.Client) error {
941935
return master.FlushDB().Err()
@@ -991,7 +985,7 @@ var _ = Describe("ClusterClient", func() {
991985
}}
992986
return slots, nil
993987
}
994-
client = cluster.clusterClient(opt)
988+
client = cluster.newClusterClient(opt)
995989

996990
err := client.ForEachMaster(func(master *redis.Client) error {
997991
return master.FlushDB().Err()
@@ -1045,7 +1039,7 @@ var _ = Describe("ClusterClient", func() {
10451039
}}
10461040
return slots, nil
10471041
}
1048-
client = cluster.clusterClient(opt)
1042+
client = cluster.newClusterClient(opt)
10491043

10501044
err := client.ForEachMaster(func(master *redis.Client) error {
10511045
return master.FlushDB().Err()
@@ -1137,7 +1131,7 @@ var _ = Describe("ClusterClient with unavailable Cluster", func() {
11371131
opt.ReadTimeout = 250 * time.Millisecond
11381132
opt.WriteTimeout = 250 * time.Millisecond
11391133
opt.MaxRedirects = 1
1140-
client = cluster.clusterClientUnsafe(opt)
1134+
client = cluster.newClusterClientUnsafe(opt)
11411135
})
11421136

11431137
AfterEach(func() {
@@ -1206,7 +1200,7 @@ var _ = Describe("ClusterClient timeout", func() {
12061200
opt.ReadTimeout = 250 * time.Millisecond
12071201
opt.WriteTimeout = 250 * time.Millisecond
12081202
opt.MaxRedirects = 1
1209-
client = cluster.clusterClient(opt)
1203+
client = cluster.newClusterClient(opt)
12101204

12111205
err := client.ForEachNode(func(client *redis.Client) error {
12121206
return client.ClientPause(pause).Err()

race_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,7 @@ var _ = Describe("cluster races", func() {
299299

300300
BeforeEach(func() {
301301
opt := redisClusterOptions()
302-
client = cluster.clusterClient(opt)
302+
client = cluster.newClusterClient(opt)
303303

304304
C, N = 10, 1000
305305
if testing.Short() {

0 commit comments

Comments
 (0)