Skip to content

Commit 0c5e583

Browse files
committed
pubsub pool still outperforms track/untrack
1 parent a4090d4 commit 0c5e583

File tree

4 files changed

+133
-56
lines changed

4 files changed

+133
-56
lines changed

internal/pool/pool.go

Lines changed: 24 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package pool
33
import (
44
"context"
55
"errors"
6+
"fmt"
67
"log"
78
"net"
89
"sync"
@@ -51,6 +52,8 @@ type Stats struct {
5152
TotalConns uint32 // number of total connections in the pool
5253
IdleConns uint32 // number of idle connections in the pool
5354
StaleConns uint32 // number of stale connections removed from the pool
55+
56+
PubSubStats PubSubStats
5457
}
5558

5659
type Pooler interface {
@@ -104,7 +107,7 @@ type ConnPool struct {
104107
queue chan struct{}
105108

106109
connsMu sync.Mutex
107-
conns []*Conn
110+
conns map[uint64]*Conn
108111
idleConns []*Conn
109112

110113
poolSize atomic.Int32
@@ -127,13 +130,10 @@ func NewConnPool(opt *Options) *ConnPool {
127130
cfg: opt,
128131

129132
queue: make(chan struct{}, opt.PoolSize),
130-
conns: make([]*Conn, 0, opt.PoolSize),
133+
conns: make(map[uint64]*Conn),
131134
idleConns: make([]*Conn, 0, opt.PoolSize),
132135
}
133136

134-
// Initialize hooks system
135-
p.initializeHooks()
136-
137137
// Only create MinIdleConns if explicitly requested (> 0)
138138
// This avoids creating connections during pool initialization for tests
139139
if opt.MinIdleConns > 0 {
@@ -156,17 +156,18 @@ func (p *ConnPool) AddPoolHook(hook PoolHook) {
156156
p.initializeHooks()
157157
}
158158
p.hookManager.AddHook(hook)
159+
p.hookManager = nil
159160
}
160161

161162
// RemovePoolHook removes a pool hook from the pool.
162163
func (p *ConnPool) RemovePoolHook(hook PoolHook) {
163164
if p.hookManager != nil {
164165
p.hookManager.RemoveHook(hook)
165166
}
167+
p.hookManager = nil
166168
}
167169

168170
func (p *ConnPool) checkMinIdleConns() {
169-
170171
if !p.idleCheckInProgress.CompareAndSwap(false, true) {
171172
return
172173
}
@@ -186,10 +187,8 @@ func (p *ConnPool) checkMinIdleConns() {
186187
go func() {
187188
defer func() {
188189
if err := recover(); err != nil {
189-
p.connsMu.Lock()
190190
p.poolSize.Add(-1)
191191
p.idleConnsLen.Add(-1)
192-
p.connsMu.Unlock()
193192

194193
p.freeTurn()
195194
internal.Logger.Printf(context.Background(), "addIdleConn panic: %+v", err)
@@ -198,10 +197,8 @@ func (p *ConnPool) checkMinIdleConns() {
198197

199198
err := p.addIdleConn()
200199
if err != nil && err != ErrClosed {
201-
p.connsMu.Lock()
202200
p.poolSize.Add(-1)
203201
p.idleConnsLen.Add(-1)
204-
p.connsMu.Unlock()
205202
}
206203
p.freeTurn()
207204
}()
@@ -232,7 +229,7 @@ func (p *ConnPool) addIdleConn() error {
232229
return ErrClosed
233230
}
234231

235-
p.conns = append(p.conns, cn)
232+
p.conns[cn.GetID()] = cn
236233
p.idleConns = append(p.idleConns, cn)
237234
return nil
238235
}
@@ -250,12 +247,9 @@ func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) {
250247
return nil, ErrClosed
251248
}
252249

253-
p.connsMu.Lock()
254250
if p.cfg.MaxActiveConns > 0 && p.poolSize.Load() >= int32(p.cfg.MaxActiveConns) {
255-
p.connsMu.Unlock()
256251
return nil, ErrPoolExhausted
257252
}
258-
p.connsMu.Unlock()
259253

260254
cn, err := p.dialConn(ctx, pooled)
261255
if err != nil {
@@ -265,15 +259,14 @@ func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) {
265259
// This is essential for normal pool operations
266260
cn.SetUsable(true)
267261

268-
p.connsMu.Lock()
269-
defer p.connsMu.Unlock()
270-
271262
if p.cfg.MaxActiveConns > 0 && p.poolSize.Load() > int32(p.cfg.MaxActiveConns) {
272263
_ = cn.Close()
273264
return nil, ErrPoolExhausted
274265
}
275266

276-
p.conns = append(p.conns, cn)
267+
p.connsMu.Lock()
268+
p.conns[cn.GetID()] = cn
269+
defer p.connsMu.Unlock()
277270
if pooled {
278271
// If pool is full remove the cn on next Put.
279272
currentPoolSize := p.poolSize.Load()
@@ -307,6 +300,7 @@ func (p *ConnPool) dialConn(ctx context.Context, pooled bool) (*Conn, error) {
307300

308301
cn := NewConnWithBufferSize(netConn, p.cfg.ReadBufferSize, p.cfg.WriteBufferSize)
309302
cn.pooled = pooled
303+
fmt.Printf("New conn %d, pooled: %v\n", cn.GetID(), cn.pooled)
310304
if p.cfg.ConnMaxLifetime > 0 {
311305
cn.expiresAt = time.Now().Add(p.cfg.ConnMaxLifetime)
312306
} else {
@@ -372,7 +366,6 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) {
372366
now := time.Now()
373367
attempts := 0
374368
for {
375-
376369
if attempts >= getAttempts {
377370
log.Printf("redis: connection pool: failed to get an connection accepted by hook after %d attempts", attempts)
378371
break
@@ -477,6 +470,7 @@ func (p *ConnPool) popIdle() (*Conn, error) {
477470
if p.closed() {
478471
return nil, ErrClosed
479472
}
473+
480474
n := len(p.idleConns)
481475
if n == 0 {
482476
return nil, nil
@@ -570,28 +564,30 @@ func (p *ConnPool) Put(ctx context.Context, cn *Conn) {
570564

571565
var shouldCloseConn bool
572566

573-
p.connsMu.Lock()
574-
575567
if p.cfg.MaxIdleConns == 0 || p.idleConnsLen.Load() < p.cfg.MaxIdleConns {
576568
// unusable conns are expected to become usable at some point (background process is reconnecting them)
577569
// put them at the opposite end of the queue
578570
if !cn.IsUsable() {
579571
if p.cfg.PoolFIFO {
572+
p.connsMu.Lock()
580573
p.idleConns = append(p.idleConns, cn)
574+
p.connsMu.Unlock()
581575
} else {
576+
p.connsMu.Lock()
582577
p.idleConns = append([]*Conn{cn}, p.idleConns...)
578+
p.connsMu.Unlock()
583579
}
584580
} else {
581+
p.connsMu.Lock()
585582
p.idleConns = append(p.idleConns, cn)
583+
p.connsMu.Unlock()
586584
}
587585
p.idleConnsLen.Add(1)
588586
} else {
589-
p.removeConn(cn)
587+
p.removeConnWithLock(cn)
590588
shouldCloseConn = true
591589
}
592590

593-
p.connsMu.Unlock()
594-
595591
p.freeTurn()
596592

597593
if shouldCloseConn {
@@ -600,6 +596,7 @@ func (p *ConnPool) Put(ctx context.Context, cn *Conn) {
600596
}
601597

602598
func (p *ConnPool) Remove(_ context.Context, cn *Conn, reason error) {
599+
internal.Logger.Printf(context.Background(), "Removing connection %d from pool: %v", cn.GetID(), reason)
603600
p.removeConnWithLock(cn)
604601
p.freeTurn()
605602
_ = p.closeConn(cn)
@@ -617,17 +614,7 @@ func (p *ConnPool) removeConnWithLock(cn *Conn) {
617614
}
618615

619616
func (p *ConnPool) removeConn(cn *Conn) {
620-
for i, c := range p.conns {
621-
if c == cn {
622-
p.conns = append(p.conns[:i], p.conns[i+1:]...)
623-
if cn.pooled {
624-
p.poolSize.Add(-1)
625-
// Immediately check for minimum idle connections when a pooled connection is removed
626-
p.checkMinIdleConns()
627-
}
628-
break
629-
}
630-
}
617+
delete(p.conns, cn.GetID())
631618
atomic.AddUint32(&p.stats.StaleConns, 1)
632619
}
633620

@@ -743,18 +730,12 @@ func (p *ConnPool) isHealthyConn(cn *Conn, now time.Time) bool {
743730
func (p *ConnPool) TrackConn(cn *Conn) {
744731
p.connsMu.Lock()
745732
p.poolSize.Add(1)
746-
p.conns = append(p.conns, cn)
733+
p.conns[cn.GetID()] = cn
747734
p.connsMu.Unlock()
748735
}
749736

750737
func (p *ConnPool) UntrackConn(cn *Conn) {
751738
p.connsMu.Lock()
752-
for i, c := range p.conns {
753-
if c == cn {
754-
p.conns = append(p.conns[:i], p.conns[i+1:]...)
755-
p.poolSize.Add(-1)
756-
break
757-
}
758-
}
739+
delete(p.conns, cn.GetID())
759740
p.connsMu.Unlock()
760741
}

internal/pool/pubsub.go

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
package pool
2+
3+
import (
4+
"context"
5+
"errors"
6+
"net"
7+
"sync"
8+
"sync/atomic"
9+
)
10+
11+
type PubSubStats struct {
12+
Created uint32
13+
Untracked uint32
14+
Active uint32
15+
}
16+
17+
// PubSubPool manages a pool of PubSub connections.
18+
type PubSubPool struct {
19+
opt *Options
20+
netDialer func(ctx context.Context, network, addr string) (net.Conn, error)
21+
22+
// Map to track active PubSub connections
23+
activeConns sync.Map // map[uint64]*Conn (connID -> conn)
24+
closed atomic.Bool
25+
stats PubSubStats
26+
}
27+
28+
func NewPubSubPool(opt *Options, netDialer func(ctx context.Context, network, addr string) (net.Conn, error)) *PubSubPool {
29+
return &PubSubPool{
30+
opt: opt,
31+
netDialer: netDialer,
32+
}
33+
}
34+
35+
func (p *PubSubPool) NewConn(ctx context.Context, network string, addr string, channels []string) (*Conn, error) {
36+
if p.closed.Load() {
37+
return nil, errors.New("pubsub pool is closed")
38+
}
39+
40+
netConn, err := p.netDialer(ctx, network, addr)
41+
if err != nil {
42+
return nil, err
43+
}
44+
cn := NewConnWithBufferSize(netConn, p.opt.ReadBufferSize, p.opt.WriteBufferSize)
45+
atomic.AddUint32(&p.stats.Created, 1)
46+
return cn, nil
47+
48+
}
49+
50+
func (p *PubSubPool) TrackConn(cn *Conn) {
51+
atomic.AddUint32(&p.stats.Active, 1)
52+
p.activeConns.Store(cn.GetID(), cn)
53+
}
54+
55+
func (p *PubSubPool) UntrackConn(cn *Conn) {
56+
atomic.AddUint32(&p.stats.Active, ^uint32(0))
57+
atomic.AddUint32(&p.stats.Untracked, 1)
58+
p.activeConns.Delete(cn.GetID())
59+
}
60+
61+
func (p *PubSubPool) Close() error {
62+
p.closed.Store(true)
63+
p.activeConns.Range(func(key, value interface{}) bool {
64+
cn := value.(*Conn)
65+
_ = cn.Close()
66+
return true
67+
})
68+
return nil
69+
}
70+
71+
func (p *PubSubPool) Stats() *PubSubStats {
72+
// load stats atomically
73+
return &PubSubStats{
74+
Created: atomic.LoadUint32(&p.stats.Created),
75+
Untracked: atomic.LoadUint32(&p.stats.Untracked),
76+
Active: atomic.LoadUint32(&p.stats.Active),
77+
}
78+
}

options.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -660,3 +660,21 @@ func newConnPool(
660660
PushNotificationsEnabled: opt.Protocol == 3,
661661
})
662662
}
663+
664+
func newPubSubPool(opt *Options, dialer func(ctx context.Context, network, addr string) (net.Conn, error),
665+
) *pool.PubSubPool {
666+
return pool.NewPubSubPool(&pool.Options{
667+
PoolFIFO: opt.PoolFIFO,
668+
PoolSize: int32(opt.PoolSize),
669+
PoolTimeout: opt.PoolTimeout,
670+
DialTimeout: opt.DialTimeout,
671+
MinIdleConns: int32(opt.MinIdleConns),
672+
MaxIdleConns: int32(opt.MaxIdleConns),
673+
MaxActiveConns: int32(opt.MaxActiveConns),
674+
ConnMaxIdleTime: opt.ConnMaxIdleTime,
675+
ConnMaxLifetime: opt.ConnMaxLifetime,
676+
ReadBufferSize: 32 * 1024,
677+
WriteBufferSize: 32 * 1024,
678+
PushNotificationsEnabled: opt.Protocol == 3,
679+
}, dialer)
680+
}

0 commit comments

Comments
 (0)