Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
5fe0bfa
fix(pool): wip, pool reauth should not interfere with handoff
ndyakov Oct 14, 2025
d39da69
fix credListeners map
ndyakov Oct 14, 2025
8a629fb
fix race in tests
ndyakov Oct 14, 2025
6c54ab5
Merge branch 'master' into ndyakov/pool-reauth
ndyakov Oct 14, 2025
90bfdb3
better conn usable timeout
ndyakov Oct 14, 2025
07283ec
add design decision comment
ndyakov Oct 15, 2025
1bbf2e6
few small improvements
ndyakov Oct 15, 2025
1428068
update marked as queued
ndyakov Oct 15, 2025
e7dc339
add Used to clarify the state of the conn
ndyakov Oct 15, 2025
77c0c73
rename test
ndyakov Oct 15, 2025
011ef96
fix(test): fix flaky test
ndyakov Oct 15, 2025
e03396e
lock inside the listeners collection
ndyakov Oct 15, 2025
391b6c5
address pr comments
ndyakov Oct 15, 2025
6ad9a67
Update internal/auth/cred_listeners.go
ndyakov Oct 15, 2025
0c4f8fb
Update internal/pool/buffer_size_test.go
ndyakov Oct 15, 2025
acb55d8
wip refactor entraid
ndyakov Oct 16, 2025
d74671b
fix maintnotif pool hook
ndyakov Oct 17, 2025
0e10cd7
fix mocks
ndyakov Oct 17, 2025
afba8c2
fix nil listener
ndyakov Oct 17, 2025
4bc6d33
sync and async reauth based on conn lifecycle
ndyakov Oct 17, 2025
3020e3a
be able to reject connection OnGet
ndyakov Oct 17, 2025
f886775
pass hooks so the tests can observe reauth
ndyakov Oct 17, 2025
72cf74a
give some time for the background to execute commands
ndyakov Oct 17, 2025
c715185
fix tests
ndyakov Oct 17, 2025
f14095b
only async reauth
ndyakov Oct 17, 2025
e94cc9f
Merge branch 'master' into ndyakov/pool-reauth
ndyakov Oct 17, 2025
19f4080
Update internal/pool/pool.go
ndyakov Oct 17, 2025
4049d5e
Update internal/auth/streaming/pool_hook.go
ndyakov Oct 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion auth/reauth_credentials_listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,4 @@ func NewReAuthCredentialsListener(reAuth func(credentials Credentials) error, on
}

// Ensure ReAuthCredentialsListener implements the CredentialsListener interface.
var _ CredentialsListener = (*ReAuthCredentialsListener)(nil)
var _ CredentialsListener = (*ReAuthCredentialsListener)(nil)
10 changes: 6 additions & 4 deletions error.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,12 @@ func isRedisError(err error) bool {

func isBadConn(err error, allowTimeout bool, addr string) bool {
switch err {
case nil:
return false
case context.Canceled, context.DeadlineExceeded:
return true
case nil:
return false
case context.Canceled, context.DeadlineExceeded:
return true
case pool.ErrConnUnusableTimeout:
return true
}

if isRedisError(err) {
Expand Down
68 changes: 68 additions & 0 deletions internal/auth/streaming/conn_reauth_credentials_listener.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package streaming

import (
"github.com/redis/go-redis/v9/auth"
"github.com/redis/go-redis/v9/internal/pool"
)

// ConnReAuthCredentialsListener is a struct that implements the CredentialsListener interface.
// It is used to re-authenticate the credentials when they are updated.
// It holds reference to the connection to re-authenticate and will pass it to the reAuth and onErr callbacks.
// It contains:
// - reAuth: a function that takes the new credentials and returns an error if any.
// - onErr: a function that takes an error and handles it.
// - conn: the connection to re-authenticate.
type ConnReAuthCredentialsListener struct {
// reAuth is called when the credentials are updated.
reAuth func(conn *pool.Conn, credentials auth.Credentials) error
// onErr is called when an error occurs.
onErr func(conn *pool.Conn, err error)
// conn is the connection to re-authenticate.
conn *pool.Conn

manager *Manager
}

// OnNext is called when the credentials are updated.
// It calls the reAuth function with the new credentials.
// If the reAuth function returns an error, it calls the onErr function with the error.
func (c *ConnReAuthCredentialsListener) OnNext(credentials auth.Credentials) {
if c.conn.IsClosed() {
return
}

if c.reAuth == nil {
return
}

// Always use async reauth to avoid complex pool semaphore issues
// The synchronous path can cause deadlocks in the pool's semaphore mechanism
// when called from the Subscribe goroutine, especially with small pool sizes.
// The connection pool hook will re-authenticate the connection when it is
// returned to the pool in a clean, idle state.
c.manager.MarkForReAuth(c.conn, func(err error) {
if err != nil {
c.OnError(err)
return
}
err = c.reAuth(c.conn, credentials)
if err != nil {
c.OnError(err)
return
}
})

}

// OnError is called when an error occurs.
// It can be called from both the credentials provider and the reAuth function.
func (c *ConnReAuthCredentialsListener) OnError(err error) {
if c.onErr == nil {
return
}

c.onErr(c.conn, err)
}

// Ensure ConnReAuthCredentialsListener implements the CredentialsListener interface.
var _ auth.CredentialsListener = (*ConnReAuthCredentialsListener)(nil)
44 changes: 44 additions & 0 deletions internal/auth/streaming/cred_listeners.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package streaming

import (
"sync"

"github.com/redis/go-redis/v9/auth"
)

type CredentialsListeners struct {
// connid -> listener
listeners map[uint64]auth.CredentialsListener
lock sync.RWMutex
}

func NewCredentialsListeners() *CredentialsListeners {
return &CredentialsListeners{
listeners: make(map[uint64]auth.CredentialsListener),
}
}

func (c *CredentialsListeners) Add(connID uint64, listener auth.CredentialsListener) {
c.lock.Lock()
defer c.lock.Unlock()
if c.listeners == nil {
c.listeners = make(map[uint64]auth.CredentialsListener)
}
c.listeners[connID] = listener
}

func (c *CredentialsListeners) Get(connID uint64) (auth.CredentialsListener, bool) {
c.lock.RLock()
defer c.lock.RUnlock()
if len(c.listeners) == 0 {
return nil, false
}
listener, ok := c.listeners[connID]
return listener, ok
}

func (c *CredentialsListeners) Remove(connID uint64) {
c.lock.Lock()
defer c.lock.Unlock()
delete(c.listeners, connID)
}
56 changes: 56 additions & 0 deletions internal/auth/streaming/manager.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package streaming

import (
"errors"
"time"

"github.com/redis/go-redis/v9/auth"
"github.com/redis/go-redis/v9/internal/pool"
)

type Manager struct {
credentialsListeners *CredentialsListeners
pool pool.Pooler
poolHookRef *ReAuthPoolHook
}

func NewManager(pl pool.Pooler, reAuthTimeout time.Duration) *Manager {
return &Manager{
pool: pl,
poolHookRef: NewReAuthPoolHook(pl.Size(), reAuthTimeout),
credentialsListeners: NewCredentialsListeners(),
}
}

func (m *Manager) PoolHook() pool.PoolHook {
return m.poolHookRef
}

func (m *Manager) Listener(
poolCn *pool.Conn,
reAuth func(*pool.Conn, auth.Credentials) error,
onErr func(*pool.Conn, error),
) (auth.CredentialsListener, error) {
if poolCn == nil {
return nil, errors.New("poolCn cannot be nil")
}
connID := poolCn.GetID()
listener, ok := m.credentialsListeners.Get(connID)
if !ok || listener == nil {
newCredListener := &ConnReAuthCredentialsListener{
conn: poolCn,
reAuth: reAuth,
onErr: onErr,
manager: m,
}

m.credentialsListeners.Add(connID, newCredListener)
listener = newCredListener
}
return listener, nil
}

func (m *Manager) MarkForReAuth(poolCn *pool.Conn, reAuthFn func(error)) {
connID := poolCn.GetID()
m.poolHookRef.MarkForReAuth(connID, reAuthFn)
}
101 changes: 101 additions & 0 deletions internal/auth/streaming/manager_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
package streaming

import (
"context"
"testing"
"time"

"github.com/redis/go-redis/v9/auth"
"github.com/redis/go-redis/v9/internal/pool"
)

// Test that Listener returns the newly created listener, not nil
func TestManager_Listener_ReturnsNewListener(t *testing.T) {
// Create a mock pool
mockPool := &mockPooler{}

// Create manager
manager := NewManager(mockPool, time.Second)

// Create a mock connection
conn := &pool.Conn{}

// Mock functions
reAuth := func(cn *pool.Conn, creds auth.Credentials) error {
return nil
}

onErr := func(cn *pool.Conn, err error) {
}

// Get listener - this should create a new one
listener, err := manager.Listener(conn, reAuth, onErr)

// Verify no error
if err != nil {
t.Fatalf("Expected no error, got: %v", err)
}

// Verify listener is not nil (this was the bug!)
if listener == nil {
t.Fatal("Expected listener to be non-nil, but got nil")
}

// Verify it's the correct type
if _, ok := listener.(*ConnReAuthCredentialsListener); !ok {
t.Fatalf("Expected listener to be *ConnReAuthCredentialsListener, got %T", listener)
}

// Get the same listener again - should return the existing one
listener2, err := manager.Listener(conn, reAuth, onErr)
if err != nil {
t.Fatalf("Expected no error on second call, got: %v", err)
}

if listener2 == nil {
t.Fatal("Expected listener2 to be non-nil")
}

// Should be the same instance
if listener != listener2 {
t.Error("Expected to get the same listener instance on second call")
}
}

// Test that Listener returns error when conn is nil
func TestManager_Listener_NilConn(t *testing.T) {
mockPool := &mockPooler{}
manager := NewManager(mockPool, time.Second)

listener, err := manager.Listener(nil, nil, nil)

if err == nil {
t.Fatal("Expected error when conn is nil, got nil")
}

if listener != nil {
t.Error("Expected listener to be nil when error occurs")
}

expectedErr := "poolCn cannot be nil"
if err.Error() != expectedErr {
t.Errorf("Expected error message %q, got %q", expectedErr, err.Error())
}
}

// Mock pooler for testing
type mockPooler struct{}

func (m *mockPooler) NewConn(ctx context.Context) (*pool.Conn, error) { return nil, nil }
func (m *mockPooler) CloseConn(*pool.Conn) error { return nil }
func (m *mockPooler) Get(ctx context.Context) (*pool.Conn, error) { return nil, nil }
func (m *mockPooler) Put(ctx context.Context, conn *pool.Conn) {}
func (m *mockPooler) Remove(ctx context.Context, conn *pool.Conn, reason error) {}
func (m *mockPooler) Len() int { return 0 }
func (m *mockPooler) IdleLen() int { return 0 }
func (m *mockPooler) Stats() *pool.Stats { return &pool.Stats{} }
func (m *mockPooler) Size() int { return 10 }
func (m *mockPooler) AddPoolHook(hook pool.PoolHook) {}
func (m *mockPooler) RemovePoolHook(hook pool.PoolHook) {}
func (m *mockPooler) Close() error { return nil }

Loading
Loading