Skip to content

Commit e03396e

Browse files
committed
lock inside the listeners collection
1 parent 011ef96 commit e03396e

File tree

3 files changed

+53
-24
lines changed

3 files changed

+53
-24
lines changed

auth/conn_reauth_credentials_listener.go renamed to internal/auth/conn_reauth_credentials_listener.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"runtime"
55
"time"
66

7+
auth2 "github.com/redis/go-redis/v9/auth"
78
"github.com/redis/go-redis/v9/internal/pool"
89
)
910

@@ -17,7 +18,7 @@ import (
1718
// - checkUsableTimeout: the timeout to wait for the connection to be usable - default is 1 second.
1819
type ConnReAuthCredentialsListener struct {
1920
// reAuth is called when the credentials are updated.
20-
reAuth func(conn *pool.Conn, credentials Credentials) error
21+
reAuth func(conn *pool.Conn, credentials auth2.Credentials) error
2122
// onErr is called when an error occurs.
2223
onErr func(conn *pool.Conn, err error)
2324
// conn is the connection to re-authenticate.
@@ -31,7 +32,7 @@ type ConnReAuthCredentialsListener struct {
3132
// OnNext is called when the credentials are updated.
3233
// It calls the reAuth function with the new credentials.
3334
// If the reAuth function returns an error, it calls the onErr function with the error.
34-
func (c *ConnReAuthCredentialsListener) OnNext(credentials Credentials) {
35+
func (c *ConnReAuthCredentialsListener) OnNext(credentials auth2.Credentials) {
3536
if c.conn.IsClosed() {
3637
return
3738
}
@@ -101,7 +102,7 @@ func (c *ConnReAuthCredentialsListener) SetCheckUsableTimeout(timeout time.Durat
101102

102103
// NewConnReAuthCredentialsListener creates a new ConnReAuthCredentialsListener.
103104
// Implements the auth.CredentialsListener interface.
104-
func NewConnReAuthCredentialsListener(conn *pool.Conn, reAuth func(conn *pool.Conn, credentials Credentials) error, onErr func(conn *pool.Conn, err error)) *ConnReAuthCredentialsListener {
105+
func NewConnReAuthCredentialsListener(conn *pool.Conn, reAuth func(conn *pool.Conn, credentials auth2.Credentials) error, onErr func(conn *pool.Conn, err error)) *ConnReAuthCredentialsListener {
105106
return &ConnReAuthCredentialsListener{
106107
conn: conn,
107108
reAuth: reAuth,
@@ -111,4 +112,4 @@ func NewConnReAuthCredentialsListener(conn *pool.Conn, reAuth func(conn *pool.Co
111112
}
112113

113114
// Ensure ConnReAuthCredentialsListener implements the CredentialsListener interface.
114-
var _ CredentialsListener = (*ConnReAuthCredentialsListener)(nil)
115+
var _ auth2.CredentialsListener = (*ConnReAuthCredentialsListener)(nil)

internal/auth/cred_listeners.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
package auth
2+
3+
import (
4+
"sync"
5+
6+
auth2 "github.com/redis/go-redis/v9/auth"
7+
"github.com/redis/go-redis/v9/internal/pool"
8+
)
9+
10+
type CredentialsListeners struct {
11+
listeners map[*pool.Conn]auth2.CredentialsListener
12+
lock sync.RWMutex
13+
}
14+
15+
func NewCredentialsListeners() *CredentialsListeners {
16+
return &CredentialsListeners{}
17+
}
18+
19+
func (c *CredentialsListeners) Add(poolCn *pool.Conn, listener auth2.CredentialsListener) {
20+
c.lock.Lock()
21+
defer c.lock.Unlock()
22+
if c.listeners == nil {
23+
c.listeners = make(map[*pool.Conn]auth2.CredentialsListener)
24+
}
25+
c.listeners[poolCn] = listener
26+
}
27+
28+
func (c *CredentialsListeners) Get(poolCn *pool.Conn) (auth2.CredentialsListener, bool) {
29+
c.lock.RLock()
30+
defer c.lock.RUnlock()
31+
listener, ok := c.listeners[poolCn]
32+
return listener, ok
33+
}
34+
35+
func (c *CredentialsListeners) Remove(poolCn *pool.Conn) {
36+
c.lock.Lock()
37+
defer c.lock.Unlock()
38+
delete(c.listeners, poolCn)
39+
}

redis.go

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111

1212
"github.com/redis/go-redis/v9/auth"
1313
"github.com/redis/go-redis/v9/internal"
14+
auth2 "github.com/redis/go-redis/v9/internal/auth"
1415
"github.com/redis/go-redis/v9/internal/hscan"
1516
"github.com/redis/go-redis/v9/internal/pool"
1617
"github.com/redis/go-redis/v9/internal/proto"
@@ -225,8 +226,8 @@ type baseClient struct {
225226
maintNotificationsManager *maintnotifications.Manager
226227
maintNotificationsManagerLock sync.RWMutex
227228

228-
credListeners map[*pool.Conn]auth.CredentialsListener
229-
credListenersLock sync.RWMutex
229+
// thread-safe map of pool connections to credentials listeners
230+
credListeners *auth2.CredentialsListeners
230231
}
231232

232233
func (c *baseClient) clone() *baseClient {
@@ -304,17 +305,13 @@ func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) {
304305
// The credentials listener is stored in a map, so that it can be reused for multiple connections.
305306
// The credentials listener is removed from the map when the connection is closed.
306307
func (c *baseClient) connReAuthCredentialsListener(poolCn *pool.Conn) (auth.CredentialsListener, func()) {
307-
c.credListenersLock.RLock()
308-
credListener, ok := c.credListeners[poolCn]
309-
c.credListenersLock.RUnlock()
308+
credListener, ok := c.credListeners.Get(poolCn)
310309
if ok {
311310
return credListener, func() {
312-
c.removeCredListener(poolCn)
311+
c.credListeners.Remove(poolCn)
313312
}
314313
}
315-
c.credListenersLock.Lock()
316-
defer c.credListenersLock.Unlock()
317-
newCredListener := auth.NewConnReAuthCredentialsListener(
314+
newCredListener := auth2.NewConnReAuthCredentialsListener(
318315
poolCn,
319316
c.reAuthConnection(),
320317
c.onAuthenticationErr(),
@@ -333,18 +330,12 @@ func (c *baseClient) connReAuthCredentialsListener(poolCn *pool.Conn) (auth.Cred
333330
} else {
334331
newCredListener.SetCheckUsableTimeout(c.opt.PoolTimeout)
335332
}
336-
c.credListeners[poolCn] = newCredListener
333+
c.credListeners.Add(poolCn, newCredListener)
337334
return newCredListener, func() {
338-
c.removeCredListener(poolCn)
335+
c.credListeners.Remove(poolCn)
339336
}
340337
}
341338

342-
func (c *baseClient) removeCredListener(poolCn *pool.Conn) {
343-
c.credListenersLock.Lock()
344-
defer c.credListenersLock.Unlock()
345-
delete(c.credListeners, poolCn)
346-
}
347-
348339
func (c *baseClient) reAuthConnection() func(poolCn *pool.Conn, credentials auth.Credentials) error {
349340
return func(poolCn *pool.Conn, credentials auth.Credentials) error {
350341
var err error
@@ -1005,9 +996,7 @@ func NewClient(opt *Options) *Client {
1005996
panic(fmt.Errorf("redis: failed to create pubsub pool: %w", err))
1006997
}
1007998

1008-
if c.opt.StreamingCredentialsProvider != nil {
1009-
c.credListeners = make(map[*pool.Conn]auth.CredentialsListener)
1010-
}
999+
c.credListeners = auth2.NewCredentialsListeners()
10111000

10121001
// Initialize maintnotifications first if enabled and protocol is RESP3
10131002
if opt.MaintNotificationsConfig != nil && opt.MaintNotificationsConfig.Mode != maintnotifications.ModeDisabled && opt.Protocol == 3 {

0 commit comments

Comments
 (0)