Skip to content

Commit 8608f93

Browse files
committed
pubsub pool
1 parent e9f32f0 commit 8608f93

15 files changed

+744
-188
lines changed

example/pubsub/main.go

Lines changed: 34 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,15 @@ import (
1111
)
1212

1313
var ctx = context.Background()
14-
var consStopped = false
1514

15+
// This example is not supposed to be run as is. It is just a test to see how pubsub behaves in relation to pool management.
16+
// It was used to find regressions in pool management in hitless mode.
17+
// Please don't use it as a reference for how to use pubsub.
1618
func main() {
1719
wg := &sync.WaitGroup{}
1820
rdb := redis.NewClient(&redis.Options{
19-
Addr: ":6379",
21+
Addr: ":6379",
22+
HitlessUpgrades: true,
2023
})
2124
_ = rdb.FlushDB(ctx).Err()
2225

@@ -30,21 +33,22 @@ func main() {
3033
if err != nil {
3134
panic(err)
3235
}
33-
if err := rdb.Set(ctx, "prods", "0", 0).Err(); err != nil {
36+
if err := rdb.Set(ctx, "publishers", "0", 0).Err(); err != nil {
3437
panic(err)
3538
}
36-
if err := rdb.Set(ctx, "cons", "0", 0).Err(); err != nil {
39+
if err := rdb.Set(ctx, "subscribers", "0", 0).Err(); err != nil {
3740
panic(err)
3841
}
39-
if err := rdb.Set(ctx, "cntr", "0", 0).Err(); err != nil {
42+
if err := rdb.Set(ctx, "published", "0", 0).Err(); err != nil {
4043
panic(err)
4144
}
42-
if err := rdb.Set(ctx, "recs", "0", 0).Err(); err != nil {
45+
if err := rdb.Set(ctx, "received", "0", 0).Err(); err != nil {
4346
panic(err)
4447
}
45-
fmt.Println("cntr", rdb.Get(ctx, "cntr").Val())
46-
fmt.Println("recs", rdb.Get(ctx, "recs").Val())
48+
fmt.Println("published", rdb.Get(ctx, "published").Val())
49+
fmt.Println("received", rdb.Get(ctx, "received").Val())
4750
subCtx, cancelSubCtx := context.WithCancel(ctx)
51+
pubCtx, cancelPublishers := context.WithCancel(ctx)
4852
for i := 0; i < 10; i++ {
4953
wg.Add(1)
5054
go subscribe(subCtx, rdb, "test", i, wg)
@@ -54,32 +58,39 @@ func main() {
5458
time.Sleep(time.Second)
5559
subCtx, cancelSubCtx = context.WithCancel(ctx)
5660
for i := 0; i < 10; i++ {
57-
if err := rdb.Incr(ctx, "prods").Err(); err != nil {
61+
if err := rdb.Incr(ctx, "publishers").Err(); err != nil {
5862
panic(err)
5963
}
6064
wg.Add(1)
61-
go floodThePool(subCtx, rdb, wg)
65+
go floodThePool(pubCtx, rdb, wg)
6266
}
6367

6468
for i := 0; i < 500; i++ {
65-
if err := rdb.Incr(ctx, "cons").Err(); err != nil {
69+
if err := rdb.Incr(ctx, "subscribers").Err(); err != nil {
6670
panic(err)
6771
}
6872
wg.Add(1)
6973
go subscribe(subCtx, rdb, "test2", i, wg)
7074
}
75+
time.Sleep(5 * time.Second)
76+
fmt.Println("canceling publishers")
77+
cancelPublishers()
7178
time.Sleep(10 * time.Second)
72-
fmt.Println("canceling")
79+
fmt.Println("canceling subscribers")
7380
cancelSubCtx()
7481
wg.Wait()
75-
cntr, err := rdb.Get(ctx, "cntr").Result()
76-
recs, err := rdb.Get(ctx, "recs").Result()
77-
prods, err := rdb.Get(ctx, "prods").Result()
78-
cons, err := rdb.Get(ctx, "cons").Result()
79-
fmt.Printf("cntr: %s\n", cntr)
80-
fmt.Printf("recs: %s\n", recs)
81-
fmt.Printf("prods: %s\n", prods)
82-
fmt.Printf("cons: %s\n", cons)
82+
published, err := rdb.Get(ctx, "published").Result()
83+
received, err := rdb.Get(ctx, "received").Result()
84+
publishers, err := rdb.Get(ctx, "publishers").Result()
85+
subscribers, err := rdb.Get(ctx, "subscribers").Result()
86+
fmt.Printf("publishers: %s\n", publishers)
87+
fmt.Printf("published: %s\n", published)
88+
fmt.Printf("subscribers: %s\n", subscribers)
89+
fmt.Printf("received: %s\n", received)
90+
publishedInt, err := rdb.Get(ctx, "published").Int()
91+
subscribersInt, err := rdb.Get(ctx, "subscribers").Int()
92+
fmt.Printf("if drained = published*subscribers: %d\n", publishedInt*subscribersInt)
93+
8394
time.Sleep(2 * time.Second)
8495
}
8596

@@ -88,8 +99,6 @@ func floodThePool(ctx context.Context, rdb *redis.Client, wg *sync.WaitGroup) {
8899
for {
89100
select {
90101
case <-ctx.Done():
91-
fmt.Println("floodThePool stopping")
92-
consStopped = true
93102
return
94103
default:
95104
}
@@ -99,7 +108,7 @@ func floodThePool(ctx context.Context, rdb *redis.Client, wg *sync.WaitGroup) {
99108
//log.Println("publish error:", err)
100109
}
101110

102-
err = rdb.Incr(ctx, "cntr").Err()
111+
err = rdb.Incr(ctx, "published").Err()
103112
if err != nil {
104113
// noop
105114
//log.Println("incr error:", err)
@@ -110,36 +119,24 @@ func floodThePool(ctx context.Context, rdb *redis.Client, wg *sync.WaitGroup) {
110119

111120
func subscribe(ctx context.Context, rdb *redis.Client, topic string, subscriberId int, wg *sync.WaitGroup) {
112121
defer wg.Done()
113-
defer fmt.Printf("subscriber %d stopping\n", subscriberId)
114122
rec := rdb.Subscribe(ctx, topic)
115123
recChan := rec.Channel()
116124
for {
117125
select {
118126
case <-ctx.Done():
119127
rec.Close()
120-
if subscriberId == 199 {
121-
fmt.Printf("subscriber %d done\n", subscriberId)
122-
}
123128
return
124129
default:
125130
select {
126131
case <-ctx.Done():
127132
rec.Close()
128-
if subscriberId == 199 {
129-
fmt.Printf("subscriber %d done\n", subscriberId)
130-
}
131133
return
132134
case msg := <-recChan:
133-
err := rdb.Incr(ctx, "recs").Err()
135+
err := rdb.Incr(ctx, "received").Err()
134136
if err != nil {
135137
log.Println("incr error:", err)
136138
}
137-
if consStopped {
138-
fmt.Printf("subscriber %d received %s\n", subscriberId, msg.Payload)
139-
}
140-
if subscriberId == 199 {
141-
fmt.Printf("subscriber %d received %s\n", subscriberId, msg.Payload)
142-
}
139+
_ = msg // Use the message to avoid unused variable warning
143140
}
144141
}
145142
}

hitless/config.go

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -314,13 +314,14 @@ func isPrivateIP(addr string) bool {
314314
// Simplified check for common private IP ranges
315315
// 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16
316316
// This is a simplified implementation; a full implementation would parse the IP properly
317-
if len(addr) >= 3 {
318-
if addr[:3] == "10." || addr[:8] == "192.168." {
319-
return true
320-
}
321-
if len(addr) >= 7 && addr[:7] == "172.16." {
322-
return true
323-
}
317+
if len(addr) >= 3 && addr[:3] == "10." {
318+
return true
319+
}
320+
if len(addr) >= 8 && addr[:8] == "192.168." {
321+
return true
322+
}
323+
if len(addr) >= 7 && addr[:7] == "172.16." {
324+
return true
324325
}
325326
return false
326327
}

hitless/redis_connection_processor_test.go

Lines changed: 87 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,6 @@ func (mp *mockPool) Get(ctx context.Context) (*pool.Conn, error) {
6060
return nil, errors.New("not implemented")
6161
}
6262

63-
func (mp *mockPool) GetPubSub(ctx context.Context) (*pool.Conn, error) {
64-
return nil, errors.New("not implemented")
65-
}
66-
6763
func (mp *mockPool) Put(ctx context.Context, conn *pool.Conn) {
6864
// Not implemented for testing
6965
}
@@ -115,10 +111,13 @@ func TestRedisConnectionProcessor(t *testing.T) {
115111
t.Fatalf("Failed to mark connection for handoff: %v", err)
116112
}
117113

118-
// Set a mock initialization function
119-
initConnCalled := false
114+
// Set a mock initialization function with synchronization
115+
initConnCalled := make(chan bool, 1)
120116
initConnFunc := func(ctx context.Context, cn *pool.Conn) error {
121-
initConnCalled = true
117+
select {
118+
case initConnCalled <- true:
119+
default:
120+
}
122121
return nil
123122
}
124123
conn.SetInitConnFunc(initConnFunc)
@@ -142,22 +141,44 @@ func TestRedisConnectionProcessor(t *testing.T) {
142141
t.Error("Connection should be in pending handoffs map")
143142
}
144143

145-
// Wait for handoff to complete
146-
time.Sleep(100 * time.Millisecond)
144+
// Wait for initialization to be called (indicates handoff started)
145+
select {
146+
case <-initConnCalled:
147+
// Good, initialization was called
148+
case <-time.After(1 * time.Second):
149+
t.Fatal("Timeout waiting for initialization function to be called")
150+
}
151+
152+
// Wait for handoff to complete with proper timeout and polling
153+
timeout := time.After(2 * time.Second)
154+
ticker := time.NewTicker(10 * time.Millisecond)
155+
defer ticker.Stop()
156+
157+
handoffCompleted := false
158+
for !handoffCompleted {
159+
select {
160+
case <-timeout:
161+
t.Fatal("Timeout waiting for handoff to complete")
162+
case <-ticker.C:
163+
if _, pending := processor.pending.Load(conn); !pending {
164+
handoffCompleted = true
165+
}
166+
}
167+
}
147168

148169
// Verify handoff completed (removed from pending map)
149170
if _, pending := processor.pending.Load(conn); pending {
150171
t.Error("Connection should be removed from pending map after handoff")
151172
}
152173

153-
// Verify handoff state was cleared
154-
if conn.ShouldHandoff() {
155-
t.Error("Connection should not be marked for handoff after successful handoff")
174+
// Verify connection is usable again
175+
if !conn.IsUsable() {
176+
t.Error("Connection should be usable after successful handoff")
156177
}
157178

158-
// Verify initialization was called
159-
if !initConnCalled {
160-
t.Error("InitConn should have been called")
179+
// Verify handoff state is cleared
180+
if conn.ShouldHandoff() {
181+
t.Error("Connection should not be marked for handoff after completion")
161182
}
162183
})
163184

@@ -236,8 +257,22 @@ func TestRedisConnectionProcessor(t *testing.T) {
236257
t.Error("Connection should not be removed when queuing handoff")
237258
}
238259

239-
// Wait for handoff to complete and fail
240-
time.Sleep(100 * time.Millisecond)
260+
// Wait for handoff to complete and fail with proper timeout and polling
261+
timeout := time.After(2 * time.Second)
262+
ticker := time.NewTicker(10 * time.Millisecond)
263+
defer ticker.Stop()
264+
265+
handoffCompleted := false
266+
for !handoffCompleted {
267+
select {
268+
case <-timeout:
269+
t.Fatal("Timeout waiting for failed handoff to complete")
270+
case <-ticker.C:
271+
if _, pending := processor.pending.Load(conn); !pending {
272+
handoffCompleted = true
273+
}
274+
}
275+
}
241276

242277
// Connection should be removed from pending map after failed handoff
243278
if _, pending := processor.pending.Load(conn); pending {
@@ -468,8 +503,22 @@ func TestRedisConnectionProcessor(t *testing.T) {
468503
t.Error("Connection should be pooled after handoff")
469504
}
470505

471-
// Wait for handoff to complete
472-
time.Sleep(50 * time.Millisecond)
506+
// Wait for handoff to complete with proper timeout and polling
507+
timeout := time.After(1 * time.Second)
508+
ticker := time.NewTicker(5 * time.Millisecond)
509+
defer ticker.Stop()
510+
511+
handoffCompleted := false
512+
for !handoffCompleted {
513+
select {
514+
case <-timeout:
515+
t.Fatal("Timeout waiting for handoff to complete")
516+
case <-ticker.C:
517+
if _, pending := processor.pending.Load(conn); !pending {
518+
handoffCompleted = true
519+
}
520+
}
521+
}
473522

474523
// Verify relaxed timeout is set with deadline
475524
if !conn.HasRelaxedTimeout() {
@@ -626,17 +675,15 @@ func TestRedisConnectionProcessor(t *testing.T) {
626675
}
627676
}
628677

629-
// Verify queue has items but capacity remains static
630-
currentQueueSize := len(processor.handoffQueue)
631-
if currentQueueSize == 0 {
632-
t.Error("Expected some items in queue after processing connections")
633-
}
634-
678+
// Verify queue capacity remains static (the main purpose of this test)
635679
finalCapacity := cap(processor.handoffQueue)
636680
if finalCapacity != 50 {
637681
t.Errorf("Queue capacity should remain static at 50, got %d", finalCapacity)
638682
}
639683

684+
// Note: We don't check queue size here because workers process items quickly
685+
// The important thing is that the capacity remains static regardless of pool size
686+
currentQueueSize := len(processor.handoffQueue)
640687
t.Logf("Static queue test completed - Capacity: %d, Current size: %d",
641688
finalCapacity, currentQueueSize)
642689
})
@@ -738,7 +785,21 @@ func TestRedisConnectionProcessor(t *testing.T) {
738785
}
739786

740787
// Wait for the handoff to complete (it happens asynchronously)
741-
time.Sleep(50 * time.Millisecond)
788+
timeout := time.After(1 * time.Second)
789+
ticker := time.NewTicker(5 * time.Millisecond)
790+
defer ticker.Stop()
791+
792+
handoffCompleted := false
793+
for !handoffCompleted {
794+
select {
795+
case <-timeout:
796+
t.Fatal("Timeout waiting for handoff to complete")
797+
case <-ticker.C:
798+
if _, pending := processor.pending.Load(conn); !pending {
799+
handoffCompleted = true
800+
}
801+
}
802+
}
742803

743804
// Verify that relaxed timeout was applied to the new connection
744805
if !conn.HasRelaxedTimeout() {

internal/pool/conn.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ type Conn struct {
4545

4646
Inited bool
4747
pooled bool
48-
isPubSub bool
4948
createdAt time.Time
5049
expiresAt time.Time
5150

0 commit comments

Comments
 (0)