-
Notifications
You must be signed in to change notification settings - Fork 2.5k
fix(pool): wip, pool reauth should not interfere with handoff #3547
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 14 commits
5fe0bfa
d39da69
8a629fb
6c54ab5
90bfdb3
07283ec
1bbf2e6
1428068
e7dc339
77c0c73
011ef96
e03396e
391b6c5
6ad9a67
0c4f8fb
acb55d8
d74671b
0e10cd7
afba8c2
4bc6d33
3020e3a
f886775
72cf74a
c715185
f14095b
e94cc9f
19f4080
4049d5e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
package auth | ||
|
||
import ( | ||
"runtime" | ||
"time" | ||
|
||
auth2 "github.com/redis/go-redis/v9/auth" | ||
"github.com/redis/go-redis/v9/internal/pool" | ||
) | ||
|
||
// ConnReAuthCredentialsListener is a struct that implements the CredentialsListener interface. | ||
// It is used to re-authenticate the credentials when they are updated. | ||
// It holds reference to the connection to re-authenticate and will pass it to the reAuth and onErr callbacks. | ||
// It contains: | ||
// - reAuth: a function that takes the new credentials and returns an error if any. | ||
// - onErr: a function that takes an error and handles it. | ||
// - conn: the connection to re-authenticate. | ||
// - checkUsableTimeout: the timeout to wait for the connection to be usable - default is 1 second. | ||
type ConnReAuthCredentialsListener struct { | ||
// reAuth is called when the credentials are updated. | ||
reAuth func(conn *pool.Conn, credentials auth2.Credentials) error | ||
// onErr is called when an error occurs. | ||
onErr func(conn *pool.Conn, err error) | ||
// conn is the connection to re-authenticate. | ||
conn *pool.Conn | ||
// checkUsableTimeout is the timeout to wait for the connection to be usable | ||
// when the credentials are updated. | ||
// default is 1 second | ||
checkUsableTimeout time.Duration | ||
} | ||
|
||
// OnNext is called when the credentials are updated. | ||
// It calls the reAuth function with the new credentials. | ||
// If the reAuth function returns an error, it calls the onErr function with the error. | ||
func (c *ConnReAuthCredentialsListener) OnNext(credentials auth2.Credentials) { | ||
if c.conn.IsClosed() { | ||
return | ||
} | ||
|
||
if c.reAuth == nil { | ||
return | ||
} | ||
|
||
var err error | ||
|
||
// this hard-coded timeout is not ideal | ||
timeout := time.After(c.checkUsableTimeout) | ||
// wait for the connection to be usable | ||
// this is important because the connection pool may be in the process of reconnecting the connection | ||
// and we don't want to interfere with that process | ||
// but we also don't want to block for too long, so incorporate a timeout | ||
for err == nil && !c.conn.Usable.CompareAndSwap(true, false) { | ||
select { | ||
case <-timeout: | ||
err = pool.ErrConnUnusableTimeout | ||
default: | ||
// small sleep to avoid busy looping | ||
time.Sleep(100 * time.Microsecond) | ||
// yield the thread to allow other goroutines to run | ||
runtime.Gosched() | ||
} | ||
} | ||
if err == nil { | ||
defer c.conn.SetUsable(true) | ||
} | ||
|
||
// This check just verifies that the connection is not in use. | ||
// If the connection is in use, we don't want to interfere with that. | ||
// As soon as the connection is not in use, we mark it as in use. | ||
for err == nil && !c.conn.Used.CompareAndSwap(false, true) { | ||
select { | ||
case <-timeout: | ||
err = pool.ErrConnUnusableTimeout | ||
default: | ||
// small sleep to avoid busy looping | ||
time.Sleep(100 * time.Microsecond) | ||
// yield the thread to allow other goroutines to run | ||
runtime.Gosched() | ||
} | ||
} | ||
|
||
// we timed out waiting for the connection to be usable | ||
// do not try to re-authenticate, instead call the onErr function | ||
// which will handle the error and close the connection if needed | ||
if err != nil { | ||
c.OnError(err) | ||
return | ||
} | ||
|
||
defer c.conn.Used.Store(false) | ||
// we set the usable flag, so restore it back to usable after we're done | ||
if err = c.reAuth(c.conn, credentials); err != nil { | ||
c.OnError(err) | ||
} | ||
} | ||
|
||
// OnError is called when an error occurs. | ||
// It can be called from both the credentials provider and the reAuth function. | ||
func (c *ConnReAuthCredentialsListener) OnError(err error) { | ||
if c.onErr == nil { | ||
return | ||
} | ||
|
||
c.onErr(c.conn, err) | ||
} | ||
|
||
// SetCheckUsableTimeout sets the timeout for the connection to be usable. | ||
func (c *ConnReAuthCredentialsListener) SetCheckUsableTimeout(timeout time.Duration) { | ||
c.checkUsableTimeout = timeout | ||
} | ||
|
||
// NewConnReAuthCredentialsListener creates a new ConnReAuthCredentialsListener. | ||
// Implements the auth.CredentialsListener interface. | ||
func NewConnReAuthCredentialsListener(conn *pool.Conn, reAuth func(conn *pool.Conn, credentials auth2.Credentials) error, onErr func(conn *pool.Conn, err error)) *ConnReAuthCredentialsListener { | ||
return &ConnReAuthCredentialsListener{ | ||
conn: conn, | ||
reAuth: reAuth, | ||
onErr: onErr, | ||
checkUsableTimeout: 1 * time.Second, | ||
|
||
} | ||
} | ||
|
||
// Ensure ConnReAuthCredentialsListener implements the CredentialsListener interface. | ||
var _ auth2.CredentialsListener = (*ConnReAuthCredentialsListener)(nil) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
package auth | ||
|
||
import ( | ||
"sync" | ||
|
||
auth2 "github.com/redis/go-redis/v9/auth" | ||
"github.com/redis/go-redis/v9/internal/pool" | ||
) | ||
|
||
type CredentialsListeners struct { | ||
listeners map[*pool.Conn]auth2.CredentialsListener | ||
lock sync.RWMutex | ||
} | ||
|
||
func NewCredentialsListeners() *CredentialsListeners { | ||
return &CredentialsListeners{ | ||
listeners: make(map[*pool.Conn]auth2.CredentialsListener), | ||
} | ||
} | ||
|
||
func (c *CredentialsListeners) Add(poolCn *pool.Conn, listener auth2.CredentialsListener) { | ||
c.lock.Lock() | ||
defer c.lock.Unlock() | ||
if c.listeners == nil { | ||
c.listeners = make(map[*pool.Conn]auth2.CredentialsListener) | ||
} | ||
c.listeners[poolCn] = listener | ||
} | ||
|
||
func (c *CredentialsListeners) Get(poolCn *pool.Conn) (auth2.CredentialsListener, bool) { | ||
c.lock.RLock() | ||
defer c.lock.RUnlock() | ||
listener, ok := c.listeners[poolCn] | ||
return listener, ok | ||
} | ||
|
||
func (c *CredentialsListeners) Remove(poolCn *pool.Conn) { | ||
c.lock.Lock() | ||
defer c.lock.Unlock() | ||
delete(c.listeners, poolCn) | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,7 +3,6 @@ package pool_test | |
import ( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The import of Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback |
||
"bufio" | ||
"context" | ||
"net" | ||
"unsafe" | ||
ndyakov marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
. "github.com/bsm/ginkgo/v2" | ||
|
@@ -124,20 +123,30 @@ var _ = Describe("Buffer Size Configuration", func() { | |
}) | ||
|
||
// Helper functions to extract buffer sizes using unsafe pointers | ||
// The struct layout must match pool.Conn exactly to avoid checkptr violations. | ||
// checkptr is Go's pointer safety checker, which ensures that unsafe pointer | ||
// conversions are valid. If the struct layouts do not match exactly, this can | ||
// cause runtime panics or incorrect memory access due to invalid pointer dereferencing. | ||
func getWriterBufSizeUnsafe(cn *pool.Conn) int { | ||
// Import required for atomic types | ||
type atomicBool struct{ _ uint32 } | ||
type atomicInt64 struct{ _ int64 } | ||
|
||
ndyakov marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
cnPtr := (*struct { | ||
usedAt int64 | ||
netConn net.Conn | ||
rd *proto.Reader | ||
bw *bufio.Writer | ||
wr *proto.Writer | ||
// ... other fields | ||
id uint64 // First field in pool.Conn | ||
usedAt int64 // Second field (atomic) | ||
netConnAtomic interface{} // atomic.Value (interface{} has same size) | ||
rd *proto.Reader | ||
bw *bufio.Writer | ||
wr *proto.Writer | ||
// We only need fields up to bw, so we can stop here | ||
})(unsafe.Pointer(cn)) | ||
|
||
if cnPtr.bw == nil { | ||
return -1 | ||
} | ||
|
||
// bufio.Writer internal structure | ||
bwPtr := (*struct { | ||
err error | ||
buf []byte | ||
|
@@ -150,18 +159,20 @@ func getWriterBufSizeUnsafe(cn *pool.Conn) int { | |
|
||
func getReaderBufSizeUnsafe(cn *pool.Conn) int { | ||
cnPtr := (*struct { | ||
usedAt int64 | ||
netConn net.Conn | ||
rd *proto.Reader | ||
bw *bufio.Writer | ||
wr *proto.Writer | ||
// ... other fields | ||
id uint64 // First field in pool.Conn | ||
usedAt int64 // Second field (atomic) | ||
netConnAtomic interface{} // atomic.Value (interface{} has same size) | ||
rd *proto.Reader | ||
bw *bufio.Writer | ||
wr *proto.Writer | ||
// We only need fields up to rd, so we can stop here | ||
})(unsafe.Pointer(cn)) | ||
|
||
if cnPtr.rd == nil { | ||
return -1 | ||
} | ||
|
||
// proto.Reader internal structure | ||
rdPtr := (*struct { | ||
rd *bufio.Reader | ||
})(unsafe.Pointer(cnPtr.rd)) | ||
|
@@ -170,6 +181,7 @@ func getReaderBufSizeUnsafe(cn *pool.Conn) int { | |
return -1 | ||
} | ||
|
||
// bufio.Reader internal structure | ||
bufReaderPtr := (*struct { | ||
buf []byte | ||
rd interface{} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The sleep duration
100 * time.Microsecond
is duplicated and hard-coded. Extract to a named constant likecheckUsablePollInterval = 100 * time.Microsecond
for better maintainability and to make the polling strategy more explicit.Copilot uses AI. Check for mistakes.