Skip to content
Draft
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
5fe0bfa
fix(pool): wip, pool reauth should not interfere with handoff
ndyakov Oct 14, 2025
d39da69
fix credListeners map
ndyakov Oct 14, 2025
8a629fb
fix race in tests
ndyakov Oct 14, 2025
6c54ab5
Merge branch 'master' into ndyakov/pool-reauth
ndyakov Oct 14, 2025
90bfdb3
better conn usable timeout
ndyakov Oct 14, 2025
07283ec
add design decision comment
ndyakov Oct 15, 2025
1bbf2e6
few small improvements
ndyakov Oct 15, 2025
1428068
update marked as queued
ndyakov Oct 15, 2025
e7dc339
add Used to clarify the state of the conn
ndyakov Oct 15, 2025
77c0c73
rename test
ndyakov Oct 15, 2025
011ef96
fix(test): fix flaky test
ndyakov Oct 15, 2025
e03396e
lock inside the listeners collection
ndyakov Oct 15, 2025
391b6c5
address pr comments
ndyakov Oct 15, 2025
6ad9a67
Update internal/auth/cred_listeners.go
ndyakov Oct 15, 2025
0c4f8fb
Update internal/pool/buffer_size_test.go
ndyakov Oct 15, 2025
acb55d8
wip refactor entraid
ndyakov Oct 16, 2025
d74671b
fix maintnotif pool hook
ndyakov Oct 17, 2025
0e10cd7
fix mocks
ndyakov Oct 17, 2025
afba8c2
fix nil listener
ndyakov Oct 17, 2025
4bc6d33
sync and async reauth based on conn lifecycle
ndyakov Oct 17, 2025
3020e3a
be able to reject connection OnGet
ndyakov Oct 17, 2025
f886775
pass hooks so the tests can observe reauth
ndyakov Oct 17, 2025
72cf74a
give some time for the background to execute commands
ndyakov Oct 17, 2025
c715185
fix tests
ndyakov Oct 17, 2025
f14095b
only async reauth
ndyakov Oct 17, 2025
e94cc9f
Merge branch 'master' into ndyakov/pool-reauth
ndyakov Oct 17, 2025
19f4080
Update internal/pool/pool.go
ndyakov Oct 17, 2025
4049d5e
Update internal/auth/streaming/pool_hook.go
ndyakov Oct 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 114 additions & 0 deletions auth/conn_reauth_credentials_listener.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
package auth

import (
"runtime"
"time"

"github.com/redis/go-redis/v9/internal/pool"
)

// ConnReAuthCredentialsListener implements the CredentialsListener interface for a
// single pooled connection.
//
// When new credentials arrive (OnNext) it re-authenticates the connection it holds
// by invoking the reAuth callback; failures are reported through the onErr callback.
// Both callbacks receive the held connection as their first argument.
//
// Fields:
//   - reAuth: called with the connection and the new credentials; returns an error on failure.
//   - onErr: called with the connection and the error to handle.
//   - conn: the connection to re-authenticate.
//   - checkUsableTimeout: upper bound on how long OnNext waits to acquire the connection.
//
// Use NewConnReAuthCredentialsListener to construct a value: it sets
// checkUsableTimeout to 1 second. A zero-value struct has a zero timeout.
type ConnReAuthCredentialsListener struct {
// reAuth is called when the credentials are updated.
// OnNext does nothing when it is nil.
reAuth func(conn *pool.Conn, credentials Credentials) error
// onErr is called when an error occurs.
// OnError does nothing when it is nil.
onErr func(conn *pool.Conn, err error)
// conn is the connection to re-authenticate.
conn *pool.Conn
// checkUsableTimeout is the timeout to wait for the connection to become
// available for re-authentication when the credentials are updated.
// NewConnReAuthCredentialsListener initializes it to 1 second;
// SetCheckUsableTimeout overrides it.
checkUsableTimeout time.Duration
}

// OnNext is called when the credentials are updated.
//
// It re-authenticates the held connection with the new credentials:
//  1. It returns silently (onErr is not called) when there is no connection,
//     the connection reports itself closed, or no reAuth callback was provided.
//  2. It waits until it can flip conn.Usable from true to false, so the
//     connection is not handed out while the re-authentication is running and
//     other background work that cleared the flag (e.g. a handoff) is not
//     interfered with.
//  3. It waits until it can flip conn.Used from false to true, so no other
//     command is being processed on the connection at the same time.
//  4. It calls reAuth; a returned error is forwarded to OnError.
//
// Both waits share a single checkUsableTimeout budget and spin, yielding the
// processor with runtime.Gosched between attempts. If the budget is exhausted,
// OnError is called with pool.ErrConnUnusableTimeout and reAuth is not attempted.
//
// Every flag acquired here is released on return, Used first and then Usable.
func (c *ConnReAuthCredentialsListener) OnNext(credentials Credentials) {
	// Nothing to re-authenticate without a live connection and a callback.
	if c.conn == nil || c.conn.IsClosed() || c.reAuth == nil {
		return
	}

	// A single timer bounds both waits below. Stopping it on return avoids
	// keeping the timer pending until it expires when we succeed early.
	timer := time.NewTimer(c.checkUsableTimeout)
	defer timer.Stop()

	// acquire retries try until it succeeds or the shared timer fires.
	// try is always attempted at least once, even with a zero timeout.
	acquire := func(try func() bool) error {
		for !try() {
			select {
			case <-timer.C:
				return pool.ErrConnUnusableTimeout
			default:
				// Let other goroutines run before the next attempt.
				runtime.Gosched()
			}
		}
		return nil
	}

	// Wait for the connection to be usable and claim it. This matters because
	// the pool may be in the middle of reconnecting the connection and we must
	// not interfere with that process, but we also must not block for too long.
	if err := acquire(func() bool { return c.conn.Usable.CompareAndSwap(true, false) }); err != nil {
		// We never claimed the usable flag, so there is nothing to restore.
		// onErr handles the error and closes the connection if needed.
		c.OnError(err)
		return
	}
	// We cleared the usable flag, so restore it once we are done.
	defer c.conn.SetUsable(true)

	// Wait until no command is in flight on the connection and mark it as used.
	if err := acquire(func() bool { return c.conn.Used.CompareAndSwap(false, true) }); err != nil {
		// Timed out: do not try to re-authenticate, report the error instead.
		c.OnError(err)
		return
	}
	// Deferred calls run in reverse order: Used is released before Usable is restored.
	defer c.conn.Used.Store(false)

	if err := c.reAuth(c.conn, credentials); err != nil {
		c.OnError(err)
	}
}

// OnError reports err for the held connection.
// It may be invoked by the credentials provider as well as by OnNext when
// acquiring the connection or re-authenticating fails.
// The error is forwarded, together with the connection, to the onErr callback;
// when no callback was configured the error is dropped.
func (c *ConnReAuthCredentialsListener) OnError(err error) {
	if handler := c.onErr; handler != nil {
		handler(c.conn, err)
	}
}

// SetCheckUsableTimeout sets how long OnNext may wait to acquire the connection
// before giving up and reporting pool.ErrConnUnusableTimeout through OnError.
// It overrides the 1 second value set by NewConnReAuthCredentialsListener.
// The field is written without synchronization.
func (c *ConnReAuthCredentialsListener) SetCheckUsableTimeout(timeout time.Duration) {
c.checkUsableTimeout = timeout
}

// NewConnReAuthCredentialsListener returns a listener bound to conn.
// reAuth is invoked with conn and the new credentials whenever they are updated;
// onErr is invoked with conn and the error whenever something goes wrong.
// The wait for the connection to become available is initialized to 1 second
// and can be changed with SetCheckUsableTimeout.
// The returned value implements the CredentialsListener interface.
func NewConnReAuthCredentialsListener(conn *pool.Conn, reAuth func(conn *pool.Conn, credentials Credentials) error, onErr func(conn *pool.Conn, err error)) *ConnReAuthCredentialsListener {
	listener := new(ConnReAuthCredentialsListener)
	listener.conn = conn
	listener.reAuth = reAuth
	listener.onErr = onErr
	listener.checkUsableTimeout = time.Second
	return listener
}

// Compile-time check that *ConnReAuthCredentialsListener implements the
// CredentialsListener interface.
var _ CredentialsListener = (*ConnReAuthCredentialsListener)(nil)
2 changes: 1 addition & 1 deletion auth/reauth_credentials_listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,4 @@ func NewReAuthCredentialsListener(reAuth func(credentials Credentials) error, on
}

// Ensure ReAuthCredentialsListener implements the CredentialsListener interface.
var _ CredentialsListener = (*ReAuthCredentialsListener)(nil)
var _ CredentialsListener = (*ReAuthCredentialsListener)(nil)
10 changes: 6 additions & 4 deletions error.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,12 @@ func isRedisError(err error) bool {

func isBadConn(err error, allowTimeout bool, addr string) bool {
switch err {
case nil:
return false
case context.Canceled, context.DeadlineExceeded:
return true
case nil:
return false
case context.Canceled, context.DeadlineExceeded:
return true
case pool.ErrConnUnusableTimeout:
return true
}

if isRedisError(err) {
Expand Down
35 changes: 22 additions & 13 deletions internal/pool/buffer_size_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package pool_test
import (
Copy link

Copilot AI Oct 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The import of net package was removed, but net.Conn is still referenced in the comment at line 142 as 'net.Conn'. While this is just a comment, it could cause confusion. Consider either keeping the import or updating the comment to remove the package reference.

Copilot uses AI. Check for mistakes.

"bufio"
"context"
"net"
"unsafe"

. "github.com/bsm/ginkgo/v2"
Expand Down Expand Up @@ -124,20 +123,27 @@ var _ = Describe("Buffer Size Configuration", func() {
})

// Helper functions to extract buffer sizes using unsafe pointers
// The struct layout must match pool.Conn exactly to avoid checkptr violations
func getWriterBufSizeUnsafe(cn *pool.Conn) int {
// Import required for atomic types
type atomicBool struct{ _ uint32 }
type atomicInt64 struct{ _ int64 }

cnPtr := (*struct {
usedAt int64
netConn net.Conn
rd *proto.Reader
bw *bufio.Writer
wr *proto.Writer
// ... other fields
id uint64 // First field in pool.Conn
usedAt int64 // Second field (atomic)
netConnAtomic interface{} // atomic.Value (interface{} has same size)
rd *proto.Reader
bw *bufio.Writer
wr *proto.Writer
// We only need fields up to bw, so we can stop here
})(unsafe.Pointer(cn))

if cnPtr.bw == nil {
return -1
}

// bufio.Writer internal structure
bwPtr := (*struct {
err error
buf []byte
Expand All @@ -150,18 +156,20 @@ func getWriterBufSizeUnsafe(cn *pool.Conn) int {

func getReaderBufSizeUnsafe(cn *pool.Conn) int {
cnPtr := (*struct {
usedAt int64
netConn net.Conn
rd *proto.Reader
bw *bufio.Writer
wr *proto.Writer
// ... other fields
id uint64 // First field in pool.Conn
usedAt int64 // Second field (atomic)
netConnAtomic interface{} // atomic.Value (interface{} has same size)
rd *proto.Reader
bw *bufio.Writer
wr *proto.Writer
// We only need fields up to rd, so we can stop here
})(unsafe.Pointer(cn))

if cnPtr.rd == nil {
return -1
}

// proto.Reader internal structure
rdPtr := (*struct {
rd *bufio.Reader
})(unsafe.Pointer(cnPtr.rd))
Expand All @@ -170,6 +178,7 @@ func getReaderBufSizeUnsafe(cn *pool.Conn) int {
return -1
}

// bufio.Reader internal structure
bufReaderPtr := (*struct {
buf []byte
rd interface{}
Expand Down
72 changes: 55 additions & 17 deletions internal/pool/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ func generateConnID() uint64 {
}

type Conn struct {
// Connection identifier for unique tracking
id uint64 // Unique numeric identifier for this connection

usedAt int64 // atomic

// Lock-free netConn access using atomic.Value
Expand All @@ -54,7 +57,34 @@ type Conn struct {
// Only used for the brief period during SetNetConn and HasBufferedData/PeekReplyTypeSafe
readerMu sync.RWMutex

Inited atomic.Bool
// Design note:
// Why have both Usable and Used?
// _Usable_ marks a connection as safe for use by clients. A connection can still
// be in the pool but not be Usable at the moment (e.g. while a handoff is in progress).
// _Used_ marks a connection as in use when a command is about to be processed on it.
// This happens once the connection is picked from the pool.
//
// If a background operation needs to use the connection, it marks it as not Usable and only uses it when it
// is not in use. That way, the connection won't be used to send multiple commands at the same time and
// potentially corrupt the command stream.

// Usable flag to mark connection as safe for use
// It is false before initialization and after a handoff is marked
// It will be false during other background operations like re-authentication
Usable atomic.Bool

// Used flag to mark connection as used when a command is going to be
// processed on that connection. This is used to prevent a race condition with
// background operations that may execute commands, like re-authentication.
Used atomic.Bool

// Inited flag to mark connection as initialized, this is almost the same as usable
// but it is used to make sure we don't initialize a network connection twice
// On handoff, the network connection is replaced, but the Conn struct is reused
// this flag will be set to false when the network connection is replaced and
// set to true after the new network connection is initialized
Inited atomic.Bool

pooled bool
pubsub bool
closed atomic.Bool
Expand All @@ -75,11 +105,7 @@ type Conn struct {
// Connection initialization function for reconnections
initConnFunc func(context.Context, *Conn) error

// Connection identifier for unique tracking
id uint64 // Unique numeric identifier for this connection

// Handoff state - using atomic operations for lock-free access
usableAtomic atomic.Bool // Connection usability state
handoffRetriesAtomic atomic.Uint32 // Retry counter for handoff attempts

// Atomic handoff state to prevent race conditions
Expand Down Expand Up @@ -116,7 +142,7 @@ func NewConnWithBufferSize(netConn net.Conn, readBufSize, writeBufSize int) *Con
cn.netConnAtomic.Store(&atomicNetConn{conn: netConn})

// Initialize atomic state
cn.usableAtomic.Store(false) // false initially, set to true after initialization
cn.Usable.Store(false) // false initially, set to true after initialization
cn.handoffRetriesAtomic.Store(0) // 0 initially

// Initialize handoff state atomically
Expand Down Expand Up @@ -162,12 +188,12 @@ func (cn *Conn) setNetConn(netConn net.Conn) {

// isUsable returns true if the connection is safe to use (lock-free).
func (cn *Conn) isUsable() bool {
return cn.usableAtomic.Load()
return cn.Usable.Load()
}

// setUsable sets the usable flag atomically (lock-free).
func (cn *Conn) setUsable(usable bool) {
cn.usableAtomic.Store(usable)
cn.Usable.Store(usable)
}

// getHandoffState returns the current handoff state atomically (lock-free).
Expand Down Expand Up @@ -455,9 +481,24 @@ func (cn *Conn) MarkQueuedForHandoff() error {
const maxRetries = 50
const baseDelay = time.Microsecond

connAcquired := false
for attempt := 0; attempt < maxRetries; attempt++ {
currentState := cn.getHandoffState()
// If CAS failed, add exponential backoff to reduce contention
// the delay will be 1, 2, 4... up to 512 microseconds
// Moving this to the top of the loop to avoid "continue" without delay
if attempt > 0 && attempt < maxRetries-1 {
delay := baseDelay * time.Duration(1<<uint(attempt%10)) // Cap exponential growth
Copy link

Copilot AI Oct 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The magic number 10 for capping exponential growth is unexplained. Add a named constant like maxExponentShift = 10 with a comment explaining why this value prevents excessive delays while maintaining backoff effectiveness.

Copilot uses AI. Check for mistakes.

time.Sleep(delay)
}

// first we need to mark the connection as not usable
// to prevent the pool from returning it to the caller
if !connAcquired && !cn.Usable.CompareAndSwap(true, false) {
continue
}
connAcquired = true

currentState := cn.getHandoffState()
// Check if marked for handoff
if !currentState.ShouldHandoff {
return errors.New("connection was not marked for handoff")
Expand All @@ -472,16 +513,12 @@ func (cn *Conn) MarkQueuedForHandoff() error {

// Atomic compare-and-swap to update state
if cn.handoffStateAtomic.CompareAndSwap(currentState, newState) {
cn.setUsable(false)
// queue the handoff for processing
// the connection is now "acquired" (marked as not usable) by the handoff
// and it won't be returned to any other callers until the handoff is complete
return nil
}

// If CAS failed, add exponential backoff to reduce contention
// the delay will be 1, 2, 4... up to 512 microseconds
if attempt < maxRetries-1 {
delay := baseDelay * time.Duration(1<<uint(attempt%10)) // Cap exponential growth
time.Sleep(delay)
}
}

return fmt.Errorf("failed to mark connection as queued for handoff after %d attempts due to high contention", maxRetries)
Expand Down Expand Up @@ -527,7 +564,8 @@ func (cn *Conn) ClearHandoffState() {
// Atomically set clean state
cn.setHandoffState(cleanState)
cn.setHandoffRetries(0)
cn.setUsable(true) // Connection is safe to use again after handoff completes
// Clearing handoff state also means the connection is usable again
cn.setUsable(true)
}

// IncrementAndGetHandoffRetries atomically increments and returns handoff retries (lock-free).
Expand Down
Loading
Loading