diff --git a/auth/reauth_credentials_listener.go b/auth/reauth_credentials_listener.go index 40076a0b1..f4b319838 100644 --- a/auth/reauth_credentials_listener.go +++ b/auth/reauth_credentials_listener.go @@ -44,4 +44,4 @@ func NewReAuthCredentialsListener(reAuth func(credentials Credentials) error, on } // Ensure ReAuthCredentialsListener implements the CredentialsListener interface. -var _ CredentialsListener = (*ReAuthCredentialsListener)(nil) +var _ CredentialsListener = (*ReAuthCredentialsListener)(nil) \ No newline at end of file diff --git a/error.go b/error.go index 8013de44a..7273313b5 100644 --- a/error.go +++ b/error.go @@ -108,10 +108,12 @@ func isRedisError(err error) bool { func isBadConn(err error, allowTimeout bool, addr string) bool { switch err { - case nil: - return false - case context.Canceled, context.DeadlineExceeded: - return true + case nil: + return false + case context.Canceled, context.DeadlineExceeded: + return true + case pool.ErrConnUnusableTimeout: + return true } if isRedisError(err) { diff --git a/internal/auth/streaming/conn_reauth_credentials_listener.go b/internal/auth/streaming/conn_reauth_credentials_listener.go new file mode 100644 index 000000000..d0ac8a841 --- /dev/null +++ b/internal/auth/streaming/conn_reauth_credentials_listener.go @@ -0,0 +1,68 @@ +package streaming + +import ( + "github.com/redis/go-redis/v9/auth" + "github.com/redis/go-redis/v9/internal/pool" +) + +// ConnReAuthCredentialsListener is a struct that implements the CredentialsListener interface. +// It is used to re-authenticate the credentials when they are updated. +// It holds reference to the connection to re-authenticate and will pass it to the reAuth and onErr callbacks. +// It contains: +// - reAuth: a function that takes the new credentials and returns an error if any. +// - onErr: a function that takes an error and handles it. +// - conn: the connection to re-authenticate. +type ConnReAuthCredentialsListener struct { + // reAuth is called when the credentials are updated. + reAuth func(conn *pool.Conn, credentials auth.Credentials) error + // onErr is called when an error occurs. + onErr func(conn *pool.Conn, err error) + // conn is the connection to re-authenticate. + conn *pool.Conn + + manager *Manager +} + +// OnNext is called when the credentials are updated. +// It calls the reAuth function with the new credentials. +// If the reAuth function returns an error, it calls the onErr function with the error. +func (c *ConnReAuthCredentialsListener) OnNext(credentials auth.Credentials) { + if c.conn.IsClosed() { + return + } + + if c.reAuth == nil { + return + } + + // Always use async reauth to avoid complex pool semaphore issues + // The synchronous path can cause deadlocks in the pool's semaphore mechanism + // when called from the Subscribe goroutine, especially with small pool sizes. + // The connection pool hook will re-authenticate the connection when it is + // returned to the pool in a clean, idle state. + c.manager.MarkForReAuth(c.conn, func(err error) { + if err != nil { + c.OnError(err) + return + } + err = c.reAuth(c.conn, credentials) + if err != nil { + c.OnError(err) + return + } + }) + +} + +// OnError is called when an error occurs. +// It can be called from both the credentials provider and the reAuth function. +func (c *ConnReAuthCredentialsListener) OnError(err error) { + if c.onErr == nil { + return + } + + c.onErr(c.conn, err) +} + +// Ensure ConnReAuthCredentialsListener implements the CredentialsListener interface. 
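// Illustrative sketch (not part of the patch): the listener above is an observer —
// a credentials provider pushes rotated credentials via OnNext, and each listener
// reacts on behalf of a single connection. A minimal stand-alone version of that
// pattern, with simplified stand-in types (the real interfaces live in
// github.com/redis/go-redis/v9/auth) and assuming `import "sync"`:
//
//	type credentials struct{ username, password string }
//
//	type listener interface {
//		OnNext(c credentials) // new credentials arrived; re-authenticate
//		OnError(err error)    // provider or re-auth failure
//	}
//
//	// provider fans rotated credentials out to every subscribed listener.
//	type provider struct {
//		mu        sync.RWMutex
//		listeners []listener
//	}
//
//	func (p *provider) Subscribe(l listener) {
//		p.mu.Lock()
//		defer p.mu.Unlock()
//		p.listeners = append(p.listeners, l)
//	}
//
//	func (p *provider) Rotate(c credentials) {
//		p.mu.RLock()
//		defer p.mu.RUnlock()
//		for _, l := range p.listeners {
//			l.OnNext(c) // each listener schedules re-auth for its own connection
//		}
//	}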
+var _ auth.CredentialsListener = (*ConnReAuthCredentialsListener)(nil) diff --git a/internal/auth/streaming/cred_listeners.go b/internal/auth/streaming/cred_listeners.go new file mode 100644 index 000000000..de6d6a6b3 --- /dev/null +++ b/internal/auth/streaming/cred_listeners.go @@ -0,0 +1,44 @@ +package streaming + +import ( + "sync" + + "github.com/redis/go-redis/v9/auth" +) + +type CredentialsListeners struct { + // connid -> listener + listeners map[uint64]auth.CredentialsListener + lock sync.RWMutex +} + +func NewCredentialsListeners() *CredentialsListeners { + return &CredentialsListeners{ + listeners: make(map[uint64]auth.CredentialsListener), + } +} + +func (c *CredentialsListeners) Add(connID uint64, listener auth.CredentialsListener) { + c.lock.Lock() + defer c.lock.Unlock() + if c.listeners == nil { + c.listeners = make(map[uint64]auth.CredentialsListener) + } + c.listeners[connID] = listener +} + +func (c *CredentialsListeners) Get(connID uint64) (auth.CredentialsListener, bool) { + c.lock.RLock() + defer c.lock.RUnlock() + if len(c.listeners) == 0 { + return nil, false + } + listener, ok := c.listeners[connID] + return listener, ok +} + +func (c *CredentialsListeners) Remove(connID uint64) { + c.lock.Lock() + defer c.lock.Unlock() + delete(c.listeners, connID) +} diff --git a/internal/auth/streaming/manager.go b/internal/auth/streaming/manager.go new file mode 100644 index 000000000..3f529d151 --- /dev/null +++ b/internal/auth/streaming/manager.go @@ -0,0 +1,56 @@ +package streaming + +import ( + "errors" + "time" + + "github.com/redis/go-redis/v9/auth" + "github.com/redis/go-redis/v9/internal/pool" +) + +type Manager struct { + credentialsListeners *CredentialsListeners + pool pool.Pooler + poolHookRef *ReAuthPoolHook +} + +func NewManager(pl pool.Pooler, reAuthTimeout time.Duration) *Manager { + return &Manager{ + pool: pl, + poolHookRef: NewReAuthPoolHook(pl.Size(), reAuthTimeout), + credentialsListeners: NewCredentialsListeners(), + } +} + +func (m *Manager) PoolHook() pool.PoolHook { + return m.poolHookRef +} + +func (m *Manager) Listener( + poolCn *pool.Conn, + reAuth func(*pool.Conn, auth.Credentials) error, + onErr func(*pool.Conn, error), +) (auth.CredentialsListener, error) { + if poolCn == nil { + return nil, errors.New("poolCn cannot be nil") + } + connID := poolCn.GetID() + listener, ok := m.credentialsListeners.Get(connID) + if !ok || listener == nil { + newCredListener := &ConnReAuthCredentialsListener{ + conn: poolCn, + reAuth: reAuth, + onErr: onErr, + manager: m, + } + + m.credentialsListeners.Add(connID, newCredListener) + listener = newCredListener + } + return listener, nil +} + +func (m *Manager) MarkForReAuth(poolCn *pool.Conn, reAuthFn func(error)) { + connID := poolCn.GetID() + m.poolHookRef.MarkForReAuth(connID, reAuthFn) +} diff --git a/internal/auth/streaming/manager_test.go b/internal/auth/streaming/manager_test.go new file mode 100644 index 000000000..e4ff813ed --- /dev/null +++ b/internal/auth/streaming/manager_test.go @@ -0,0 +1,101 @@ +package streaming + +import ( + "context" + "testing" + "time" + + "github.com/redis/go-redis/v9/auth" + "github.com/redis/go-redis/v9/internal/pool" +) + +// Test that Listener returns the newly created listener, not nil +func TestManager_Listener_ReturnsNewListener(t *testing.T) { + // Create a mock pool + mockPool := &mockPooler{} + + // Create manager + manager := NewManager(mockPool, time.Second) + + // Create a mock connection + conn := &pool.Conn{} + + // Mock functions + reAuth := func(cn *pool.Conn, 
creds auth.Credentials) error { + return nil + } + + onErr := func(cn *pool.Conn, err error) { + } + + // Get listener - this should create a new one + listener, err := manager.Listener(conn, reAuth, onErr) + + // Verify no error + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + + // Verify listener is not nil (this was the bug!) + if listener == nil { + t.Fatal("Expected listener to be non-nil, but got nil") + } + + // Verify it's the correct type + if _, ok := listener.(*ConnReAuthCredentialsListener); !ok { + t.Fatalf("Expected listener to be *ConnReAuthCredentialsListener, got %T", listener) + } + + // Get the same listener again - should return the existing one + listener2, err := manager.Listener(conn, reAuth, onErr) + if err != nil { + t.Fatalf("Expected no error on second call, got: %v", err) + } + + if listener2 == nil { + t.Fatal("Expected listener2 to be non-nil") + } + + // Should be the same instance + if listener != listener2 { + t.Error("Expected to get the same listener instance on second call") + } +} + +// Test that Listener returns error when conn is nil +func TestManager_Listener_NilConn(t *testing.T) { + mockPool := &mockPooler{} + manager := NewManager(mockPool, time.Second) + + listener, err := manager.Listener(nil, nil, nil) + + if err == nil { + t.Fatal("Expected error when conn is nil, got nil") + } + + if listener != nil { + t.Error("Expected listener to be nil when error occurs") + } + + expectedErr := "poolCn cannot be nil" + if err.Error() != expectedErr { + t.Errorf("Expected error message %q, got %q", expectedErr, err.Error()) + } +} + +// Mock pooler for testing +type mockPooler struct{} + +func (m *mockPooler) NewConn(ctx context.Context) (*pool.Conn, error) { return nil, nil } +func (m *mockPooler) CloseConn(*pool.Conn) error { return nil } +func (m *mockPooler) Get(ctx context.Context) (*pool.Conn, error) { return nil, nil } +func (m *mockPooler) Put(ctx context.Context, conn *pool.Conn) {} +func (m *mockPooler) Remove(ctx context.Context, conn *pool.Conn, reason error) {} +func (m *mockPooler) Len() int { return 0 } +func (m *mockPooler) IdleLen() int { return 0 } +func (m *mockPooler) Stats() *pool.Stats { return &pool.Stats{} } +func (m *mockPooler) Size() int { return 10 } +func (m *mockPooler) AddPoolHook(hook pool.PoolHook) {} +func (m *mockPooler) RemovePoolHook(hook pool.PoolHook) {} +func (m *mockPooler) Close() error { return nil } + diff --git a/internal/auth/streaming/pool_hook.go b/internal/auth/streaming/pool_hook.go new file mode 100644 index 000000000..39bf83522 --- /dev/null +++ b/internal/auth/streaming/pool_hook.go @@ -0,0 +1,145 @@ +package streaming + +import ( + "context" + "sync" + "time" + + "github.com/redis/go-redis/v9/internal/pool" +) + +type ReAuthPoolHook struct { + // conn id -> func() reauth func with error handling + shouldReAuth map[uint64]func(error) + shouldReAuthLock sync.RWMutex + workers chan struct{} + reAuthTimeout time.Duration + // conn id -> bool + scheduledReAuth map[uint64]bool + scheduledLock sync.RWMutex +} + +func NewReAuthPoolHook(poolSize int, reAuthTimeout time.Duration) *ReAuthPoolHook { + workers := make(chan struct{}, poolSize) + // Initialize the workers channel with tokens (semaphore pattern) + for i := 0; i < poolSize; i++ { + workers <- struct{}{} + } + + return &ReAuthPoolHook{ + shouldReAuth: make(map[uint64]func(error)), + scheduledReAuth: make(map[uint64]bool), + workers: workers, + reAuthTimeout: reAuthTimeout, + } + +} + +func (r *ReAuthPoolHook) MarkForReAuth(connID 
uint64, reAuthFn func(error)) { + r.shouldReAuthLock.Lock() + defer r.shouldReAuthLock.Unlock() + r.shouldReAuth[connID] = reAuthFn +} + +func (r *ReAuthPoolHook) ClearReAuthMark(connID uint64) { + r.shouldReAuthLock.Lock() + defer r.shouldReAuthLock.Unlock() + delete(r.shouldReAuth, connID) +} + +func (r *ReAuthPoolHook) OnGet(_ context.Context, conn *pool.Conn, _ bool) (accept bool, err error) { + r.shouldReAuthLock.RLock() + _, ok := r.shouldReAuth[conn.GetID()] + r.shouldReAuthLock.RUnlock() + // This connection was marked for reauth while in the pool, + // reject the connection + if ok { + // simply reject the connection, it will be re-authenticated in OnPut + return false, nil + } + r.scheduledLock.RLock() + hasScheduled, ok := r.scheduledReAuth[conn.GetID()] + r.scheduledLock.RUnlock() + // has scheduled reauth, reject the connection + if ok && hasScheduled { + // simply reject the connection, it will be re-authenticated in OnPut + return false, nil + } + return true, nil +} + +func (r *ReAuthPoolHook) OnPut(_ context.Context, conn *pool.Conn) (bool, bool, error) { + // Check if reauth is needed and get the function with proper locking + r.shouldReAuthLock.RLock() + reAuthFn, ok := r.shouldReAuth[conn.GetID()] + r.shouldReAuthLock.RUnlock() + + if ok { + r.scheduledLock.Lock() + r.scheduledReAuth[conn.GetID()] = true + r.scheduledLock.Unlock() + // Clear the mark immediately to prevent duplicate reauth attempts + r.ClearReAuthMark(conn.GetID()) + go func() { + <-r.workers + defer func() { + r.scheduledLock.Lock() + delete(r.scheduledReAuth, conn.GetID()) + r.scheduledLock.Unlock() + r.workers <- struct{}{} + }() + + var err error + timeout := time.After(r.reAuthTimeout) + + // Try to acquire the connection + // We need to ensure the connection is both Usable and not Used + // to prevent data races with concurrent operations + acquired := false + for !acquired { + select { + case <-timeout: + // Timeout occurred, cannot acquire connection + err = pool.ErrConnUnusableTimeout + reAuthFn(err) + return + default: + // Try to acquire: set Usable=false, then check Used + if conn.Usable.CompareAndSwap(true, false) { + if !conn.Used.Load() { + acquired = true + } else { + // Release Usable and retry + conn.Usable.Store(true) + time.Sleep(time.Millisecond) + } + } else { + time.Sleep(time.Millisecond) + } + } + } + + // Successfully acquired the connection, perform reauth + reAuthFn(nil) + + // Release the connection + conn.Usable.Store(true) + }() + } + + // the reauth will happen in background, as far as the pool is concerned: + // pool the connection, don't remove it, no error + return true, false, nil +} + +func (r *ReAuthPoolHook) OnRemove(_ context.Context, conn *pool.Conn, _ error) { + r.scheduledLock.Lock() + delete(r.scheduledReAuth, conn.GetID()) + r.scheduledLock.Unlock() + r.shouldReAuthLock.Lock() + delete(r.shouldReAuth, conn.GetID()) + r.shouldReAuthLock.Unlock() + r.ClearReAuthMark(conn.GetID()) +} + +var _ pool.PoolHook = (*ReAuthPoolHook)(nil) diff --git a/internal/pool/buffer_size_test.go b/internal/pool/buffer_size_test.go index 71223d708..bffe495cd 100644 --- a/internal/pool/buffer_size_test.go +++ b/internal/pool/buffer_size_test.go @@ -3,7 +3,6 @@ package pool_test import ( "bufio" "context" - "net" "unsafe" . "github.com/bsm/ginkgo/v2" @@ -124,20 +123,26 @@ var _ = Describe("Buffer Size Configuration", func() { }) // Helper functions to extract buffer sizes using unsafe pointers +// The struct layout must match pool.Conn exactly to avoid checkptr violations. 
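// Illustrative sketch (not part of the patch), referring back to the worker channel
// seeded in NewReAuthPoolHook above: pre-filling a buffered channel with poolSize
// tokens gives a counting semaphore, so at most poolSize background re-auth
// goroutines run at once. A minimal stand-alone version, assuming `import "sync"`:
//
//	type semaphore chan struct{}
//
//	func newSemaphore(n int) semaphore {
//		s := make(semaphore, n)
//		for i := 0; i < n; i++ {
//			s <- struct{}{} // pre-fill with n tokens
//		}
//		return s
//	}
//
//	func (s semaphore) acquire() { <-s }             // blocks while all tokens are taken
//	func (s semaphore) release() { s <- struct{}{} } // hands the token back
//
//	// runBounded executes the jobs with at most n running concurrently.
//	func runBounded(jobs []func(), n int) {
//		sem := newSemaphore(n)
//		var wg sync.WaitGroup
//		for _, job := range jobs {
//			wg.Add(1)
//			sem.acquire()
//			go func(j func()) {
//				defer wg.Done()
//				defer sem.release()
//				j()
//			}(job)
//		}
//		wg.Wait()
//	}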
+// checkptr is Go's pointer safety checker, which ensures that unsafe pointer +// conversions are valid. If the struct layouts do not match exactly, this can +// cause runtime panics or incorrect memory access due to invalid pointer dereferencing. func getWriterBufSizeUnsafe(cn *pool.Conn) int { cnPtr := (*struct { - usedAt int64 - netConn net.Conn - rd *proto.Reader - bw *bufio.Writer - wr *proto.Writer - // ... other fields + id uint64 // First field in pool.Conn + usedAt int64 // Second field (atomic) + netConnAtomic interface{} // atomic.Value (interface{} has same size) + rd *proto.Reader + bw *bufio.Writer + wr *proto.Writer + // We only need fields up to bw, so we can stop here })(unsafe.Pointer(cn)) if cnPtr.bw == nil { return -1 } + // bufio.Writer internal structure bwPtr := (*struct { err error buf []byte @@ -150,18 +155,20 @@ func getWriterBufSizeUnsafe(cn *pool.Conn) int { func getReaderBufSizeUnsafe(cn *pool.Conn) int { cnPtr := (*struct { - usedAt int64 - netConn net.Conn - rd *proto.Reader - bw *bufio.Writer - wr *proto.Writer - // ... other fields + id uint64 // First field in pool.Conn + usedAt int64 // Second field (atomic) + netConnAtomic interface{} // atomic.Value (interface{} has same size) + rd *proto.Reader + bw *bufio.Writer + wr *proto.Writer + // We only need fields up to rd, so we can stop here })(unsafe.Pointer(cn)) if cnPtr.rd == nil { return -1 } + // proto.Reader internal structure rdPtr := (*struct { rd *bufio.Reader })(unsafe.Pointer(cnPtr.rd)) @@ -170,6 +177,7 @@ func getReaderBufSizeUnsafe(cn *pool.Conn) int { return -1 } + // bufio.Reader internal structure bufReaderPtr := (*struct { buf []byte rd interface{} diff --git a/internal/pool/conn.go b/internal/pool/conn.go index e47805464..903fbedab 100644 --- a/internal/pool/conn.go +++ b/internal/pool/conn.go @@ -40,6 +40,9 @@ func generateConnID() uint64 { } type Conn struct { + // Connection identifier for unique tracking + id uint64 // Unique numeric identifier for this connection + usedAt int64 // atomic // Lock-free netConn access using atomic.Value @@ -54,7 +57,34 @@ type Conn struct { // Only used for the brief period during SetNetConn and HasBufferedData/PeekReplyTypeSafe readerMu sync.RWMutex - Inited atomic.Bool + // Design note: + // Why have both Usable and Used? + // _Usable_ is used to mark a connection as safe for use by clients, the connection can still + // be in the pool but not Usable at the moment (e.g. handoff in progress). + // _Used_ is used to mark a connection as used when a command is going to be processed on that connection. + // this is going to happen once the connection is picked from the pool. + // + // If a background operation needs to use the connection, it will mark it as Not Usable and only use it when it + // is not in use. That way, the connection won't be used to send multiple commands at the same time and + // potentially corrupt the command stream. + + // Usable flag to mark connection as safe for use + // It is false before initialization and after a handoff is marked + // It will be false during other background operations like re-authentication + Usable atomic.Bool + + // Used flag to mark connection as used when a command is going to be + // processed on that connection. This is used to prevent a race condition with + // background operations that may execute commands, like re-authentication. 
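// Illustrative sketch (not part of the patch) of the acquisition protocol the design
// note above describes: a background task flips Usable true->false with a CAS so the
// pool stops handing the connection out, then waits for Used to report false before
// touching the command stream. The names mirror the fields on this struct; the helper
// and its timeout parameter are assumptions for the example (imports "sync/atomic",
// "time"):
//
//	// acquireForBackgroundWork reports whether the connection was exclusively acquired.
//	func acquireForBackgroundWork(usable, used *atomic.Bool, timeout time.Duration) bool {
//		deadline := time.Now().Add(timeout)
//		for time.Now().Before(deadline) {
//			if usable.CompareAndSwap(true, false) { // stop new checkouts
//				if !used.Load() { // no command is in flight
//					return true
//				}
//				usable.Store(true) // still in use: release and retry
//			}
//			time.Sleep(time.Millisecond)
//		}
//		return false // callers treat this like an unusable-connection timeout
//	}
//
//	// The caller must store true back into usable once its background work is done.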
+ Used atomic.Bool + + // Inited flag to mark connection as initialized, this is almost the same as usable + // but it is used to make sure we don't initialize a network connection twice + // On handoff, the network connection is replaced, but the Conn struct is reused + // this flag will be set to false when the network connection is replaced and + // set to true after the new network connection is initialized + Inited atomic.Bool + pooled bool pubsub bool closed atomic.Bool @@ -75,11 +105,7 @@ type Conn struct { // Connection initialization function for reconnections initConnFunc func(context.Context, *Conn) error - // Connection identifier for unique tracking - id uint64 // Unique numeric identifier for this connection - // Handoff state - using atomic operations for lock-free access - usableAtomic atomic.Bool // Connection usability state handoffRetriesAtomic atomic.Uint32 // Retry counter for handoff attempts // Atomic handoff state to prevent race conditions @@ -116,7 +142,7 @@ func NewConnWithBufferSize(netConn net.Conn, readBufSize, writeBufSize int) *Con cn.netConnAtomic.Store(&atomicNetConn{conn: netConn}) // Initialize atomic state - cn.usableAtomic.Store(false) // false initially, set to true after initialization + cn.Usable.Store(false) // false initially, set to true after initialization cn.handoffRetriesAtomic.Store(0) // 0 initially // Initialize handoff state atomically @@ -162,12 +188,12 @@ func (cn *Conn) setNetConn(netConn net.Conn) { // isUsable returns true if the connection is safe to use (lock-free). func (cn *Conn) isUsable() bool { - return cn.usableAtomic.Load() + return cn.Usable.Load() } // setUsable sets the usable flag atomically (lock-free). func (cn *Conn) setUsable(usable bool) { - cn.usableAtomic.Store(usable) + cn.Usable.Store(usable) } // getHandoffState returns the current handoff state atomically (lock-free). @@ -455,9 +481,27 @@ func (cn *Conn) MarkQueuedForHandoff() error { const maxRetries = 50 const baseDelay = time.Microsecond + connAcquired := false for attempt := 0; attempt < maxRetries; attempt++ { - currentState := cn.getHandoffState() + // If CAS failed, add exponential backoff to reduce contention + // the delay will be 1, 2, 4... 
up to 512 microseconds + // Moving this to the top of the loop to avoid "continue" without delay + if attempt > 0 && attempt < maxRetries-1 { + delay := baseDelay * time.Duration(1<= getAttempts { internal.Logger.Printf(ctx, "redis: connection pool: was not able to get a healthy connection after %d attempts", attempts) @@ -454,17 +466,19 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) { } // Process connection using the hooks system - p.hookManagerMu.RLock() - hookManager := p.hookManager - p.hookManagerMu.RUnlock() - if hookManager != nil { - if err := hookManager.ProcessOnGet(ctx, cn, false); err != nil { + acceptConn, err := hookManager.ProcessOnGet(ctx, cn, false) + if err != nil { internal.Logger.Printf(ctx, "redis: connection pool: failed to process idle connection by hook: %v", err) - // Failed to process connection, discard it _ = p.CloseConn(cn) continue } + if !acceptConn { + internal.Logger.Printf(ctx, "redis: connection pool: conn[%d] rejected by hook, returning to pool", cn.GetID()) + p.Put(ctx, cn) + cn = nil + continue + } } atomic.AddUint32(&p.stats.Hits, 1) @@ -480,14 +494,13 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) { } // Process connection using the hooks system - p.hookManagerMu.RLock() - hookManager := p.hookManager - p.hookManagerMu.RUnlock() - if hookManager != nil { - if err := hookManager.ProcessOnGet(ctx, newcn, true); err != nil { + acceptConn, err := hookManager.ProcessOnGet(ctx, newcn, true) + // both errors and accept=false mean a hook rejected the connection + // this should not happen with a new connection, but we handle it gracefully + if err != nil || !acceptConn { // Failed to process connection, discard it - internal.Logger.Printf(ctx, "redis: connection pool: failed to process new connection by hook: %v", err) + internal.Logger.Printf(ctx, "redis: connection pool: failed to process new connection conn[%d] by hook: accept=%v, err=%v", newcn.GetID(), acceptConn, err) _ = p.CloseConn(newcn) return nil, err } @@ -568,8 +581,10 @@ func (p *ConnPool) popIdle() (*Conn, error) { attempts++ if cn.IsUsable() { - p.idleConnsLen.Add(-1) - break + if cn.Used.CompareAndSwap(false, true) { + p.idleConnsLen.Add(-1) + break + } } // Connection is not usable, put it back in the pool @@ -664,6 +679,12 @@ func (p *ConnPool) Put(ctx context.Context, cn *Conn) { shouldCloseConn = true } + // Mark connection as not used only + // if it's not being closed + if !shouldCloseConn { + cn.Used.Store(false) + } + p.freeTurn() if shouldCloseConn { @@ -671,7 +692,15 @@ func (p *ConnPool) Put(ctx context.Context, cn *Conn) { } } -func (p *ConnPool) Remove(_ context.Context, cn *Conn, reason error) { +func (p *ConnPool) Remove(ctx context.Context, cn *Conn, reason error) { + p.hookManagerMu.RLock() + hookManager := p.hookManager + p.hookManagerMu.RUnlock() + + if hookManager != nil { + hookManager.ProcessOnRemove(ctx, cn, reason) + } + p.removeConnWithLock(cn) p.freeTurn() @@ -733,6 +762,10 @@ func (p *ConnPool) IdleLen() int { return int(n) } +func (p *ConnPool) Size() int { + return int(p.cfg.PoolSize) +} + func (p *ConnPool) Stats() *Stats { return &Stats{ Hits: atomic.LoadUint32(&p.stats.Hits), diff --git a/internal/pool/pool_single.go b/internal/pool/pool_single.go index 136d6f2dd..bce7cf0f0 100644 --- a/internal/pool/pool_single.go +++ b/internal/pool/pool_single.go @@ -2,6 +2,7 @@ package pool import ( "context" + "time" ) type SingleConnPool struct { @@ -31,12 +32,26 @@ func (p *SingleConnPool) Get(ctx context.Context) (*Conn, error) { 
if p.stickyErr != nil { return nil, p.stickyErr } + if p.cn == nil { + return nil, ErrClosed + } + p.cn.Used.Store(true) + p.cn.SetUsedAt(time.Now()) return p.cn, nil } -func (p *SingleConnPool) Put(ctx context.Context, cn *Conn) {} +func (p *SingleConnPool) Put(ctx context.Context, cn *Conn) { + if p.cn == nil { + return + } + if p.cn != cn { + return + } + p.cn.Used.Store(false) +} func (p *SingleConnPool) Remove(ctx context.Context, cn *Conn, reason error) { + cn.Used.Store(false) p.cn = nil p.stickyErr = reason } @@ -55,6 +70,8 @@ func (p *SingleConnPool) IdleLen() int { return 0 } +func (p *SingleConnPool) Size() int { return 1 } + func (p *SingleConnPool) Stats() *Stats { return &Stats{} } diff --git a/internal/pool/pool_sticky.go b/internal/pool/pool_sticky.go index dc4266a4f..335e74129 100644 --- a/internal/pool/pool_sticky.go +++ b/internal/pool/pool_sticky.go @@ -196,6 +196,8 @@ func (p *StickyConnPool) IdleLen() int { return len(p.ch) } +func (p *StickyConnPool) Size() int { return 1 } + func (p *StickyConnPool) Stats() *Stats { return &Stats{} } diff --git a/internal/pool/pool_test.go b/internal/pool/pool_test.go index ef1ed5f9b..6aa6dc091 100644 --- a/internal/pool/pool_test.go +++ b/internal/pool/pool_test.go @@ -497,9 +497,14 @@ func TestDialerRetryConfiguration(t *testing.T) { } // Should have attempted 5 times (default DialerRetries = 5) + // Note: There may be one additional attempt from tryDial() goroutine + // which is launched when dialErrorsNum reaches PoolSize finalAttempts := atomic.LoadInt64(&attempts) - if finalAttempts != 5 { - t.Errorf("Expected 5 dial attempts (default), got %d", finalAttempts) + if finalAttempts < 5 { + t.Errorf("Expected at least 5 dial attempts (default), got %d", finalAttempts) + } + if finalAttempts > 6 { + t.Errorf("Expected around 5 dial attempts, got %d (too many)", finalAttempts) } }) } diff --git a/internal/pool/pubsub.go b/internal/pool/pubsub.go index 73ee4b3ec..ed87d1bbc 100644 --- a/internal/pool/pubsub.go +++ b/internal/pool/pubsub.go @@ -24,6 +24,8 @@ type PubSubPool struct { stats PubSubStats } +// PubSubPool implements a pool for PubSub connections. 
+// It intentionally does not implement the Pooler interface func NewPubSubPool(opt *Options, netDialer func(ctx context.Context, network, addr string) (net.Conn, error)) *PubSubPool { return &PubSubPool{ opt: opt, diff --git a/maintnotifications/handoff_worker.go b/maintnotifications/handoff_worker.go index 61dc1e171..dce984c7c 100644 --- a/maintnotifications/handoff_worker.go +++ b/maintnotifications/handoff_worker.go @@ -378,8 +378,12 @@ func (hwm *handoffWorkerManager) performConnectionHandoff(ctx context.Context, c } // performHandoffInternal performs the actual handoff logic (extracted for circuit breaker integration) -func (hwm *handoffWorkerManager) performHandoffInternal(ctx context.Context, conn *pool.Conn, newEndpoint string, connID uint64) (shouldRetry bool, err error) { - +func (hwm *handoffWorkerManager) performHandoffInternal( + ctx context.Context, + conn *pool.Conn, + newEndpoint string, + connID uint64, +) (shouldRetry bool, err error) { retries := conn.IncrementAndGetHandoffRetries(1) internal.Logger.Printf(ctx, logs.HandoffRetryAttempt(connID, retries, newEndpoint, conn.RemoteAddr().String())) maxRetries := 3 // Default fallback @@ -438,9 +442,14 @@ func (hwm *handoffWorkerManager) performHandoffInternal(ctx context.Context, con } }() + // Clear handoff state will: + // - set the connection as usable again + // - clear the handoff state (shouldHandoff, endpoint, seqID) + // - reset the handoff retries to 0 conn.ClearHandoffState() internal.Logger.Printf(ctx, logs.HandoffSucceeded(connID, newEndpoint)) + // successfully completed the handoff, no retry needed and no error return false, nil } diff --git a/maintnotifications/pool_hook.go b/maintnotifications/pool_hook.go index 695c3a648..9fd24b4a7 100644 --- a/maintnotifications/pool_hook.go +++ b/maintnotifications/pool_hook.go @@ -116,22 +116,22 @@ func (ph *PoolHook) ResetCircuitBreakers() { } // OnGet is called when a connection is retrieved from the pool -func (ph *PoolHook) OnGet(ctx context.Context, conn *pool.Conn, _ bool) error { +func (ph *PoolHook) OnGet(_ context.Context, conn *pool.Conn, _ bool) (accept bool, err error) { // NOTE: There are two conditions to make sure we don't return a connection that should be handed off or is // in a handoff state at the moment. // Check if connection is usable (not in a handoff state) // Should not happen since the pool will not return a connection that is not usable. if !conn.IsUsable() { - return ErrConnectionMarkedForHandoff + return false, ErrConnectionMarkedForHandoff } // Check if connection is marked for handoff, which means it will be queued for handoff on put. 
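// Illustrative sketch (not part of the patch) of the accept/reject contract OnGet now
// follows: return (false, nil) to silently skip a connection (the pool re-pools it and
// tries another one), or return a non-nil error to surface the failure. The interfaces
// below are simplified stand-ins for pool.PoolHook and pool.Conn with the signatures
// used in this PR; skipBusyHook is a hypothetical example hook (import "context"):
//
//	type conn interface {
//		IsUsable() bool
//		ShouldHandoff() bool
//	}
//
//	type hook interface {
//		OnGet(ctx context.Context, cn conn, isNewConn bool) (accept bool, err error)
//		OnPut(ctx context.Context, cn conn) (shouldPool, shouldRemove bool, err error)
//		OnRemove(ctx context.Context, cn conn, reason error)
//	}
//
//	type skipBusyHook struct{}
//
//	func (skipBusyHook) OnGet(_ context.Context, cn conn, _ bool) (bool, error) {
//		if !cn.IsUsable() || cn.ShouldHandoff() {
//			// silent reject, as the ReAuthPoolHook in this PR does;
//			// the maintnotifications hook here returns an error instead.
//			return false, nil
//		}
//		return true, nil
//	}
//
//	func (skipBusyHook) OnPut(_ context.Context, _ conn) (bool, bool, error) {
//		return true, false, nil // keep pooling the connection, do not remove it
//	}
//
//	func (skipBusyHook) OnRemove(_ context.Context, _ conn, _ error) {}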
if conn.ShouldHandoff() { - return ErrConnectionMarkedForHandoff + return false, ErrConnectionMarkedForHandoff } - return nil + return true, nil } // OnPut is called when a connection is returned to the pool @@ -174,6 +174,10 @@ func (ph *PoolHook) OnPut(ctx context.Context, conn *pool.Conn) (shouldPool bool return true, false, nil } +func (ph *PoolHook) OnRemove(_ context.Context, _ *pool.Conn, _ error) { + // Not used +} + // Shutdown gracefully shuts down the processor, waiting for workers to complete func (ph *PoolHook) Shutdown(ctx context.Context) error { return ph.workerManager.shutdownWorkers(ctx) diff --git a/maintnotifications/pool_hook_test.go b/maintnotifications/pool_hook_test.go index c689179d7..51e73c3ec 100644 --- a/maintnotifications/pool_hook_test.go +++ b/maintnotifications/pool_hook_test.go @@ -92,6 +92,10 @@ func (mp *mockPool) Stats() *pool.Stats { return &pool.Stats{} } +func (mp *mockPool) Size() int { + return 0 +} + func (mp *mockPool) AddPoolHook(hook pool.PoolHook) { // Mock implementation - do nothing } @@ -356,10 +360,13 @@ func TestConnectionHook(t *testing.T) { conn := createMockPoolConnection() ctx := context.Background() - err := processor.OnGet(ctx, conn, false) + acceptCon, err := processor.OnGet(ctx, conn, false) if err != nil { t.Errorf("OnGet should not error for normal connection: %v", err) } + if !acceptCon { + t.Error("Connection should be accepted for normal connection") + } }) t.Run("OnGetWithPendingHandoff", func(t *testing.T) { @@ -381,10 +388,13 @@ func TestConnectionHook(t *testing.T) { conn.MarkQueuedForHandoff() // Mark as queued (sets usable=false) ctx := context.Background() - err := processor.OnGet(ctx, conn, false) + acceptCon, err := processor.OnGet(ctx, conn, false) if err != ErrConnectionMarkedForHandoff { t.Errorf("Expected ErrConnectionMarkedForHandoff, got %v", err) } + if acceptCon { + t.Error("Connection should not be accepted when marked for handoff") + } // Clean up processor.GetPendingMap().Delete(conn) @@ -412,10 +422,13 @@ func TestConnectionHook(t *testing.T) { // Test OnGet with pending handoff ctx := context.Background() - err := processor.OnGet(ctx, conn, false) + acceptCon, err := processor.OnGet(ctx, conn, false) if err != ErrConnectionMarkedForHandoff { t.Error("Should return ErrConnectionMarkedForHandoff for pending connection") } + if acceptCon { + t.Error("Should not accept connection with pending handoff") + } // Test removing from pending map and clearing handoff state processor.GetPendingMap().Delete(conn) @@ -428,10 +441,13 @@ func TestConnectionHook(t *testing.T) { conn.SetUsable(true) // Make connection usable again // Test OnGet without pending handoff - err = processor.OnGet(ctx, conn, false) + acceptCon, err = processor.OnGet(ctx, conn, false) if err != nil { t.Errorf("Should not return error for non-pending connection: %v", err) } + if !acceptCon { + t.Error("Should accept connection without pending handoff") + } }) t.Run("EventDrivenQueueOptimization", func(t *testing.T) { @@ -624,11 +640,15 @@ func TestConnectionHook(t *testing.T) { } // OnGet should succeed for usable connection - err := processor.OnGet(ctx, conn, false) + acceptConn, err := processor.OnGet(ctx, conn, false) if err != nil { t.Errorf("OnGet should succeed for usable connection: %v", err) } + if !acceptConn { + t.Error("Connection should be accepted when usable") + } + // Mark connection for handoff if err := conn.MarkForHandoff("new-endpoint:6379", 1); err != nil { t.Fatalf("Failed to mark connection for handoff: %v", err) @@ -648,13 
+668,17 @@ func TestConnectionHook(t *testing.T) { } // OnGet should fail for connection marked for handoff - err = processor.OnGet(ctx, conn, false) + acceptConn, err = processor.OnGet(ctx, conn, false) if err == nil { t.Error("OnGet should fail for connection marked for handoff") } + if err != ErrConnectionMarkedForHandoff { t.Errorf("Expected ErrConnectionMarkedForHandoff, got %v", err) } + if acceptConn { + t.Error("Connection should not be accepted when marked for handoff") + } // Process the connection to trigger handoff shouldPool, shouldRemove, err := processor.OnPut(ctx, conn) @@ -674,11 +698,15 @@ func TestConnectionHook(t *testing.T) { } // OnGet should succeed again - err = processor.OnGet(ctx, conn, false) + acceptConn, err = processor.OnGet(ctx, conn, false) if err != nil { t.Errorf("OnGet should succeed after handoff completion: %v", err) } + if !acceptConn { + t.Error("Connection should be accepted after handoff completion") + } + t.Logf("Usable flag behavior test completed successfully") }) diff --git a/pubsub.go b/pubsub.go index 5e02b0bd2..959a5c45b 100644 --- a/pubsub.go +++ b/pubsub.go @@ -465,7 +465,6 @@ func (c *PubSub) ReceiveTimeout(ctx context.Context, timeout time.Duration) (int } // Don't hold the lock to allow subscriptions and pings. - cn, err := c.connWithLock(ctx) if err != nil { return nil, err diff --git a/redis.go b/redis.go index b308263e2..7d4c903ce 100644 --- a/redis.go +++ b/redis.go @@ -11,6 +11,7 @@ import ( "github.com/redis/go-redis/v9/auth" "github.com/redis/go-redis/v9/internal" + "github.com/redis/go-redis/v9/internal/auth/streaming" "github.com/redis/go-redis/v9/internal/hscan" "github.com/redis/go-redis/v9/internal/pool" "github.com/redis/go-redis/v9/internal/proto" @@ -224,6 +225,9 @@ type baseClient struct { // Maintenance notifications manager maintNotificationsManager *maintnotifications.Manager maintNotificationsManagerLock sync.RWMutex + + // streamingCredentialsManager is used to manage streaming credentials + streamingCredentialsManager *streaming.Manager } func (c *baseClient) clone() *baseClient { @@ -232,11 +236,12 @@ func (c *baseClient) clone() *baseClient { c.maintNotificationsManagerLock.RUnlock() clone := &baseClient{ - opt: c.opt, - connPool: c.connPool, - onClose: c.onClose, - pushProcessor: c.pushProcessor, - maintNotificationsManager: maintNotificationsManager, + opt: c.opt, + connPool: c.connPool, + onClose: c.onClose, + pushProcessor: c.pushProcessor, + maintNotificationsManager: maintNotificationsManager, + streamingCredentialsManager: c.streamingCredentialsManager, } return clone } @@ -296,32 +301,30 @@ func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) { return cn, nil } -func (c *baseClient) newReAuthCredentialsListener(poolCn *pool.Conn) auth.CredentialsListener { - return auth.NewReAuthCredentialsListener( - c.reAuthConnection(poolCn), - c.onAuthenticationErr(poolCn), - ) -} - -func (c *baseClient) reAuthConnection(poolCn *pool.Conn) func(credentials auth.Credentials) error { - return func(credentials auth.Credentials) error { +func (c *baseClient) reAuthConnection() func(poolCn *pool.Conn, credentials auth.Credentials) error { + return func(poolCn *pool.Conn, credentials auth.Credentials) error { var err error username, password := credentials.BasicAuth() + + // Use background context - timeout is handled by ReadTimeout in WithReader/WithWriter ctx := context.Background() + connPool := pool.NewSingleConnPool(c.connPool, poolCn) - // hooksMixin are intentionally empty here - cn := newConn(c.opt, 
connPool, nil) + + // Pass hooks so that reauth commands are recorded/traced + cn := newConn(c.opt, connPool, &c.hooksMixin) if username != "" { err = cn.AuthACL(ctx, username, password).Err() } else { err = cn.Auth(ctx, password).Err() } + return err } } -func (c *baseClient) onAuthenticationErr(poolCn *pool.Conn) func(err error) { - return func(err error) { +func (c *baseClient) onAuthenticationErr() func(poolCn *pool.Conn, err error) { + return func(poolCn *pool.Conn, err error) { if err != nil { if isBadConn(err, false, c.opt.Addr) { // Close the connection to force a reconnection. @@ -372,13 +375,24 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { username, password := "", "" if c.opt.StreamingCredentialsProvider != nil { + credListener, err := c.streamingCredentialsManager.Listener( + cn, + c.reAuthConnection(), + c.onAuthenticationErr(), + ) + if err != nil { + return fmt.Errorf("failed to create credentials listener: %w", err) + } + credentials, unsubscribeFromCredentialsProvider, err := c.opt.StreamingCredentialsProvider. - Subscribe(c.newReAuthCredentialsListener(cn)) + Subscribe(credListener) if err != nil { return fmt.Errorf("failed to subscribe to streaming credentials: %w", err) } + c.onClose = c.wrappedOnClose(unsubscribeFromCredentialsProvider) cn.SetOnClose(unsubscribeFromCredentialsProvider) + username, password = credentials.BasicAuth() } else if c.opt.CredentialsProviderContext != nil { username, password, err = c.opt.CredentialsProviderContext(ctx) @@ -496,7 +510,10 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { } } + // mark the connection as usable and inited + // once returned to the pool as idle, this connection can be used by other clients cn.SetUsable(true) + cn.Used.Store(false) cn.Inited.Store(true) // Set the connection initialization function for potential reconnections @@ -952,6 +969,11 @@ func NewClient(opt *Options) *Client { panic(fmt.Errorf("redis: failed to create pubsub pool: %w", err)) } + if opt.StreamingCredentialsProvider != nil { + c.streamingCredentialsManager = streaming.NewManager(c.connPool, c.opt.PoolTimeout) + c.connPool.AddPoolHook(c.streamingCredentialsManager.PoolHook()) + } + // Initialize maintnotifications first if enabled and protocol is RESP3 if opt.MaintNotificationsConfig != nil && opt.MaintNotificationsConfig.Mode != maintnotifications.ModeDisabled && opt.Protocol == 3 { err := c.enableMaintNotificationsUpgrades() diff --git a/redis_test.go b/redis_test.go index 27b69ed14..0906d420b 100644 --- a/redis_test.go +++ b/redis_test.go @@ -854,24 +854,34 @@ var _ = Describe("Credentials Provider Priority", func() { credentials: initialCreds, updates: updatesChan, }, + PoolSize: 1, // Force single connection to ensure reauth is tested } client = redis.NewClient(opt) client.AddHook(recorder.Hook()) // wrongpass Expect(client.Ping(context.Background()).Err()).To(HaveOccurred()) + time.Sleep(10 * time.Millisecond) Expect(recorder.Contains("AUTH initial_user")).To(BeTrue()) // Update credentials opt.StreamingCredentialsProvider.(*mockStreamingProvider).updates <- updatedCreds - // wrongpass - Expect(client.Ping(context.Background()).Err()).To(HaveOccurred()) - Expect(recorder.Contains("AUTH updated_user")).To(BeTrue()) + + // Wait for reauth to complete and verify updated credentials are used + // We need to keep trying Ping until we see the updated AUTH command + // because the reauth happens asynchronously + Eventually(func() bool { + // wrongpass + _ = 
client.Ping(context.Background()).Err() + return recorder.Contains("AUTH updated_user") + }, "1s", "50ms").Should(BeTrue()) + close(updatesChan) }) }) type mockStreamingProvider struct { + mu sync.RWMutex credentials auth.Credentials err error updates chan auth.Credentials @@ -882,21 +892,50 @@ func (m *mockStreamingProvider) Subscribe(listener auth.CredentialsListener) (au return nil, nil, m.err } + if listener == nil { + return nil, nil, errors.New("listener cannot be nil") + } + + // Create a done channel to stop the goroutine + done := make(chan struct{}) + // Start goroutine to handle updates go func() { - for creds := range m.updates { - m.credentials = creds - listener.OnNext(creds) + defer func() { + if r := recover(); r != nil { + // this is just a mock: + // allow panics to be caught without crashing + } + }() + + for { + select { + case <-done: + return + case creds, ok := <-m.updates: + if !ok { + return + } + m.mu.Lock() + m.credentials = creds + m.mu.Unlock() + listener.OnNext(creds) + } } }() - return m.credentials, func() (err error) { + m.mu.RLock() + currentCreds := m.credentials + m.mu.RUnlock() + + return currentCreds, func() (err error) { defer func() { if r := recover(); r != nil { // this is just a mock: // allow multiple closes from multiple listeners } }() + close(done) return }, nil } diff --git a/sentinel_test.go b/sentinel_test.go index bfeb28161..f332822f5 100644 --- a/sentinel_test.go +++ b/sentinel_test.go @@ -410,7 +410,9 @@ var _ = Describe("SentinelAclAuth", func() { }) }) -func TestParseFailoverURL(t *testing.T) { +// renaming from TestParseFailoverURL to TestParseSentinelURL +// to be easier to find Failed tests in the test output +func TestParseSentinelURL(t *testing.T) { cases := []struct { url string o *redis.FailoverOptions
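// Closing illustrative sketch (not part of the patch): the mockStreamingProvider hunk
// above stops its update-forwarding goroutine by selecting on a dedicated done channel
// alongside the updates channel, and also treats a closed updates channel as a stop
// signal, so the goroutine cannot leak. A stand-alone version of that pattern with a
// generic payload type (the names here are assumptions for the example):
//
//	func forward[T any](updates <-chan T, done <-chan struct{}, deliver func(T)) {
//		go func() {
//			for {
//				select {
//				case <-done:
//					return // unsubscribed
//				case v, ok := <-updates:
//					if !ok {
//						return // producer closed the channel
//					}
//					deliver(v)
//				}
//			}
//		}()
//	}
//
//	// Unsubscribe by closing done exactly once, e.g. guarded with sync.Once.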