diff --git a/.gitignore b/.gitignore index 0d99709e34..5fe0716e29 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,6 @@ coverage.txt **/coverage.txt .vscode tmp/* + +# Hitless upgrade documentation (temporary) +hitless/docs/ diff --git a/adapters.go b/adapters.go new file mode 100644 index 0000000000..c5cb84f296 --- /dev/null +++ b/adapters.go @@ -0,0 +1,144 @@ +package redis + +import ( + "context" + "errors" + "net" + "time" + + "github.com/redis/go-redis/v9/internal/interfaces" + "github.com/redis/go-redis/v9/internal/pool" + "github.com/redis/go-redis/v9/push" +) + +// ErrInvalidCommand is returned when an invalid command is passed to ExecuteCommand. +var ErrInvalidCommand = errors.New("invalid command type") + +// ErrInvalidPool is returned when the pool type is not supported. +var ErrInvalidPool = errors.New("invalid pool type") + +// NewClientAdapter creates a new client adapter for regular Redis clients. +func NewClientAdapter(client *baseClient) interfaces.ClientInterface { + return &clientAdapter{client: client} +} + +// clientAdapter adapts a Redis client to implement interfaces.ClientInterface. +type clientAdapter struct { + client *baseClient +} + +// GetOptions returns the client options. +func (ca *clientAdapter) GetOptions() interfaces.OptionsInterface { + return &optionsAdapter{options: ca.client.opt} +} + +// GetPushProcessor returns the client's push notification processor. +func (ca *clientAdapter) GetPushProcessor() interfaces.NotificationProcessor { + return &pushProcessorAdapter{processor: ca.client.pushProcessor} +} + +// optionsAdapter adapts Redis options to implement interfaces.OptionsInterface. +type optionsAdapter struct { + options *Options +} + +// GetReadTimeout returns the read timeout. +func (oa *optionsAdapter) GetReadTimeout() time.Duration { + return oa.options.ReadTimeout +} + +// GetWriteTimeout returns the write timeout. +func (oa *optionsAdapter) GetWriteTimeout() time.Duration { + return oa.options.WriteTimeout +} + +// GetAddr returns the connection address. +func (oa *optionsAdapter) GetAddr() string { + return oa.options.Addr +} + +// IsTLSEnabled returns true if TLS is enabled. +func (oa *optionsAdapter) IsTLSEnabled() bool { + return oa.options.TLSConfig != nil +} + +// GetProtocol returns the protocol version. +func (oa *optionsAdapter) GetProtocol() int { + return oa.options.Protocol +} + +// GetPoolSize returns the connection pool size. +func (oa *optionsAdapter) GetPoolSize() int { + return oa.options.PoolSize +} + +// NewDialer returns a new dialer function for the connection. +func (oa *optionsAdapter) NewDialer() func(context.Context) (net.Conn, error) { + baseDialer := oa.options.NewDialer() + return func(ctx context.Context) (net.Conn, error) { + // Extract network and address from the options + network := "tcp" + addr := oa.options.Addr + return baseDialer(ctx, network, addr) + } +} + +// connectionAdapter adapts a Redis connection to interfaces.ConnectionWithRelaxedTimeout +type connectionAdapter struct { + conn *pool.Conn +} + +// Close closes the connection. +func (ca *connectionAdapter) Close() error { + return ca.conn.Close() +} + +// IsUsable returns true if the connection is safe to use for new commands. +func (ca *connectionAdapter) IsUsable() bool { + return ca.conn.IsUsable() +} + +// GetPoolConnection returns the underlying pool connection. +func (ca *connectionAdapter) GetPoolConnection() *pool.Conn { + return ca.conn +} + +// SetRelaxedTimeout sets relaxed timeouts for this connection during hitless upgrades. 
+// These timeouts remain active until explicitly cleared. +func (ca *connectionAdapter) SetRelaxedTimeout(readTimeout, writeTimeout time.Duration) { + ca.conn.SetRelaxedTimeout(readTimeout, writeTimeout) +} + +// SetRelaxedTimeoutWithDeadline sets relaxed timeouts with an expiration deadline. +// After the deadline, timeouts automatically revert to normal values. +func (ca *connectionAdapter) SetRelaxedTimeoutWithDeadline(readTimeout, writeTimeout time.Duration, deadline time.Time) { + ca.conn.SetRelaxedTimeoutWithDeadline(readTimeout, writeTimeout, deadline) +} + +// ClearRelaxedTimeout clears relaxed timeouts for this connection. +func (ca *connectionAdapter) ClearRelaxedTimeout() { + ca.conn.ClearRelaxedTimeout() +} + +// pushProcessorAdapter adapts a push.NotificationProcessor to implement interfaces.NotificationProcessor. +type pushProcessorAdapter struct { + processor push.NotificationProcessor +} + +// RegisterHandler registers a handler for a specific push notification name. +func (ppa *pushProcessorAdapter) RegisterHandler(pushNotificationName string, handler interface{}, protected bool) error { + if pushHandler, ok := handler.(push.NotificationHandler); ok { + return ppa.processor.RegisterHandler(pushNotificationName, pushHandler, protected) + } + return errors.New("handler must implement push.NotificationHandler") +} + +// UnregisterHandler removes a handler for a specific push notification name. +func (ppa *pushProcessorAdapter) UnregisterHandler(pushNotificationName string) error { + return ppa.processor.UnregisterHandler(pushNotificationName) +} + +// GetHandler returns the handler for a specific push notification name. +func (ppa *pushProcessorAdapter) GetHandler(pushNotificationName string) interface{} { + return ppa.processor.GetHandler(pushNotificationName) +} diff --git a/async_handoff_integration_test.go b/async_handoff_integration_test.go new file mode 100644 index 0000000000..e8945557c6 --- /dev/null +++ b/async_handoff_integration_test.go @@ -0,0 +1,309 @@ +package redis + +import ( + "context" + "net" + "sync" + "testing" + "time" + + "github.com/redis/go-redis/v9/hitless" + "github.com/redis/go-redis/v9/internal/pool" +) + +// mockNetConn implements net.Conn for testing +type mockNetConn struct { + addr string +} + +func (m *mockNetConn) Read(b []byte) (n int, err error) { return 0, nil } +func (m *mockNetConn) Write(b []byte) (n int, err error) { return len(b), nil } +func (m *mockNetConn) Close() error { return nil } +func (m *mockNetConn) LocalAddr() net.Addr { return &mockAddr{m.addr} } +func (m *mockNetConn) RemoteAddr() net.Addr { return &mockAddr{m.addr} } +func (m *mockNetConn) SetDeadline(t time.Time) error { return nil } +func (m *mockNetConn) SetReadDeadline(t time.Time) error { return nil } +func (m *mockNetConn) SetWriteDeadline(t time.Time) error { return nil } + +type mockAddr struct { + addr string +} + +func (m *mockAddr) Network() string { return "tcp" } +func (m *mockAddr) String() string { return m.addr } + +// TestEventDrivenHandoffIntegration tests the complete event-driven handoff flow +func TestEventDrivenHandoffIntegration(t *testing.T) { + t.Run("EventDrivenHandoffWithPoolSkipping", func(t *testing.T) { + // Create a base dialer for testing + baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { + return &mockNetConn{addr: addr}, nil + } + + // Create processor with event-driven handoff support + processor := hitless.NewRedisConnectionProcessor(3, baseDialer, nil, nil) + defer 
processor.Shutdown(context.Background()) + + // Create a test pool with the processor + testPool := pool.NewConnPool(&pool.Options{ + Dialer: func(ctx context.Context) (net.Conn, error) { + return &mockNetConn{addr: "original:6379"}, nil + }, + ConnectionProcessor: processor, + PoolSize: 5, + PoolTimeout: time.Second, + }) + defer testPool.Close() + + // Set the pool reference in the processor for connection removal on handoff failure + processor.SetPool(testPool) + + ctx := context.Background() + + // Get a connection and mark it for handoff + conn, err := testPool.Get(ctx) + if err != nil { + t.Fatalf("Failed to get connection: %v", err) + } + + // Set initialization function + initConnCalled := false + initConnFunc := func(ctx context.Context, cn *pool.Conn) error { + initConnCalled = true + return nil + } + conn.SetInitConnFunc(initConnFunc) + + // Mark connection for handoff + conn.MarkForHandoff("new-endpoint:6379", 12345) + + // Return connection to pool - this should queue handoff + testPool.Put(ctx, conn) + + // Verify handoff was queued + if !processor.IsHandoffPending(conn) { + t.Error("Handoff should be queued in pending map") + } + + // Try to get the same connection - should be skipped due to pending handoff + conn2, err := testPool.Get(ctx) + if err != nil { + t.Fatalf("Failed to get second connection: %v", err) + } + + // Should get a different connection (the pending one should be skipped) + if conn == conn2 { + t.Error("Should have gotten a different connection while handoff is pending") + } + + // Return the second connection + testPool.Put(ctx, conn2) + + // Wait for handoff to complete + time.Sleep(200 * time.Millisecond) + + // Verify handoff completed (removed from pending map) + if processor.IsHandoffPending(conn) { + t.Error("Handoff should have completed and been removed from pending map") + } + + if !initConnCalled { + t.Error("InitConn should have been called during handoff") + } + + // Now the original connection should be available again + conn3, err := testPool.Get(ctx) + if err != nil { + t.Fatalf("Failed to get third connection: %v", err) + } + + // Could be the original connection (now handed off) or a new one + testPool.Put(ctx, conn3) + }) + + t.Run("ConcurrentHandoffs", func(t *testing.T) { + // Create a base dialer that simulates slow handoffs + baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { + time.Sleep(50 * time.Millisecond) // Simulate network delay + return &mockNetConn{addr: addr}, nil + } + + processor := hitless.NewRedisConnectionProcessor(3, baseDialer, nil, nil) + defer processor.Shutdown(context.Background()) + + testPool := pool.NewConnPool(&pool.Options{ + Dialer: func(ctx context.Context) (net.Conn, error) { + return &mockNetConn{addr: "original:6379"}, nil + }, + ConnectionProcessor: processor, + PoolSize: 10, + PoolTimeout: time.Second, + }) + defer testPool.Close() + + // Set the pool reference in the processor + processor.SetPool(testPool) + + ctx := context.Background() + var wg sync.WaitGroup + + // Start multiple concurrent handoffs + for i := 0; i < 5; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + + // Get connection + conn, err := testPool.Get(ctx) + if err != nil { + t.Errorf("Failed to get connection %d: %v", id, err) + return + } + + // Set initialization function + initConnFunc := func(ctx context.Context, cn *pool.Conn) error { + return nil + } + conn.SetInitConnFunc(initConnFunc) + + // Mark for handoff + conn.MarkForHandoff("new-endpoint:6379", int64(id)) + + // Return to pool 
(starts async handoff) + testPool.Put(ctx, conn) + }(i) + } + + wg.Wait() + + // Wait for all handoffs to complete + time.Sleep(300 * time.Millisecond) + + // Verify pool is still functional + conn, err := testPool.Get(ctx) + if err != nil { + t.Fatalf("Pool should still be functional after concurrent handoffs: %v", err) + } + testPool.Put(ctx, conn) + }) + + t.Run("HandoffFailureRecovery", func(t *testing.T) { + // Create a failing base dialer + failingDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { + return nil, &net.OpError{Op: "dial", Err: &net.DNSError{Name: addr}} + } + + processor := hitless.NewRedisConnectionProcessor(3, failingDialer, nil, nil) + defer processor.Shutdown(context.Background()) + + testPool := pool.NewConnPool(&pool.Options{ + Dialer: func(ctx context.Context) (net.Conn, error) { + return &mockNetConn{addr: "original:6379"}, nil + }, + ConnectionProcessor: processor, + PoolSize: 3, + PoolTimeout: time.Second, + }) + defer testPool.Close() + + // Set the pool reference in the processor + processor.SetPool(testPool) + + ctx := context.Background() + + // Get connection and mark for handoff + conn, err := testPool.Get(ctx) + if err != nil { + t.Fatalf("Failed to get connection: %v", err) + } + + conn.MarkForHandoff("unreachable-endpoint:6379", 12345) + + // Return to pool (starts async handoff that will fail) + testPool.Put(ctx, conn) + + // Wait for handoff to fail + time.Sleep(200 * time.Millisecond) + + // Connection should be removed from pending map after failed handoff + if processor.IsHandoffPending(conn) { + t.Error("Connection should be removed from pending map after failed handoff") + } + + // Pool should still be functional + conn2, err := testPool.Get(ctx) + if err != nil { + t.Fatalf("Pool should still be functional: %v", err) + } + + // In event-driven approach, the original connection remains in pool + // even after failed handoff (it's still a valid connection) + // We might get the same connection or a different one + testPool.Put(ctx, conn2) + }) + + t.Run("GracefulShutdown", func(t *testing.T) { + // Create a slow base dialer + slowDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { + time.Sleep(100 * time.Millisecond) + return &mockNetConn{addr: addr}, nil + } + + processor := hitless.NewRedisConnectionProcessor(3, slowDialer, nil, nil) + + testPool := pool.NewConnPool(&pool.Options{ + Dialer: func(ctx context.Context) (net.Conn, error) { + return &mockNetConn{addr: "original:6379"}, nil + }, + ConnectionProcessor: processor, + PoolSize: 2, + PoolTimeout: time.Second, + }) + defer testPool.Close() + + // Set the pool reference in the processor + processor.SetPool(testPool) + + ctx := context.Background() + + // Start a handoff + conn, err := testPool.Get(ctx) + if err != nil { + t.Fatalf("Failed to get connection: %v", err) + } + + if err := conn.MarkForHandoff("new-endpoint:6379", 12345); err != nil { + t.Fatalf("Failed to mark connection for handoff: %v", err) + } + + // Set a mock initialization function + conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error { + return nil + }) + + testPool.Put(ctx, conn) + + // Verify handoff was queued + if !processor.IsHandoffPending(conn) { + t.Error("Handoff should be queued in pending map") + } + + // Give the handoff a moment to start processing + time.Sleep(50 * time.Millisecond) + + // Shutdown processor gracefully + // Use a longer timeout to account for slow dialer (100ms) plus processing overhead + shutdownCtx, cancel := 
context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + err = processor.Shutdown(shutdownCtx) + if err != nil { + t.Errorf("Graceful shutdown should succeed: %v", err) + } + + // Handoff should have completed (removed from pending map) + if processor.IsHandoffPending(conn) { + t.Error("Handoff should have completed and been removed from pending map after shutdown") + } + }) +} diff --git a/commands.go b/commands.go index c0358001d1..8ee072904d 100644 --- a/commands.go +++ b/commands.go @@ -193,6 +193,7 @@ type Cmdable interface { ClientID(ctx context.Context) *IntCmd ClientUnblock(ctx context.Context, id int64) *IntCmd ClientUnblockWithError(ctx context.Context, id int64) *IntCmd + ClientMaintNotifications(ctx context.Context, enabled bool, endpointType string) *StatusCmd ConfigGet(ctx context.Context, parameter string) *MapStringStringCmd ConfigResetStat(ctx context.Context) *StatusCmd ConfigSet(ctx context.Context, parameter, value string) *StatusCmd @@ -518,6 +519,20 @@ func (c cmdable) ClientInfo(ctx context.Context) *ClientInfoCmd { return cmd } +// ClientMaintNotifications enables or disables maintenance notifications for hitless upgrades. +// When enabled, the client will receive push notifications about Redis maintenance events. +func (c cmdable) ClientMaintNotifications(ctx context.Context, enabled bool, endpointType string) *StatusCmd { + args := []interface{}{"client", "maint_notifications"} + if enabled { + args = append(args, "on", "moving-endpoint-type", endpointType) + } else { + args = append(args, "off") + } + cmd := NewStatusCmd(ctx, args...) + _ = c(ctx, cmd) + return cmd +} + // ------------------------------------------------------------------------------------------------ func (c cmdable) ConfigGet(ctx context.Context, parameter string) *MapStringStringCmd { diff --git a/example/pubsub/go.mod b/example/pubsub/go.mod new file mode 100644 index 0000000000..731a92839d --- /dev/null +++ b/example/pubsub/go.mod @@ -0,0 +1,12 @@ +module github.com/redis/go-redis/example/pubsub + +go 1.18 + +replace github.com/redis/go-redis/v9 => ../.. + +require github.com/redis/go-redis/v9 v9.11.0 + +require ( + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect +) diff --git a/example/pubsub/go.sum b/example/pubsub/go.sum new file mode 100644 index 0000000000..d64ea0303f --- /dev/null +++ b/example/pubsub/go.sum @@ -0,0 +1,6 @@ +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= diff --git a/example/pubsub/main.go b/example/pubsub/main.go new file mode 100644 index 0000000000..bf80bba64d --- /dev/null +++ b/example/pubsub/main.go @@ -0,0 +1,146 @@ +package main + +import ( + "context" + "fmt" + "log" + "sync" + "time" + + "github.com/redis/go-redis/v9" + "github.com/redis/go-redis/v9/hitless" +) + +var ctx = context.Background() + +// This example is not supposed to be run as is. It is just a test to see how pubsub behaves in relation to pool management. 
+// It was used to find regressions in pool management in hitless mode. +// Please don't use it as a reference for how to use pubsub. +func main() { + wg := &sync.WaitGroup{} + rdb := redis.NewClient(&redis.Options{ + Addr: ":6379", + HitlessUpgradeConfig: &redis.HitlessUpgradeConfig{ + Enabled: hitless.MaintNotificationsEnabled, + }, + }) + _ = rdb.FlushDB(ctx).Err() + + go func() { + for { + time.Sleep(2 * time.Second) + fmt.Printf("pool stats: %+v\n", rdb.PoolStats()) + } + }() + err := rdb.Ping(ctx).Err() + if err != nil { + panic(err) + } + if err := rdb.Set(ctx, "publishers", "0", 0).Err(); err != nil { + panic(err) + } + if err := rdb.Set(ctx, "subscribers", "0", 0).Err(); err != nil { + panic(err) + } + if err := rdb.Set(ctx, "published", "0", 0).Err(); err != nil { + panic(err) + } + if err := rdb.Set(ctx, "received", "0", 0).Err(); err != nil { + panic(err) + } + fmt.Println("published", rdb.Get(ctx, "published").Val()) + fmt.Println("received", rdb.Get(ctx, "received").Val()) + subCtx, cancelSubCtx := context.WithCancel(ctx) + pubCtx, cancelPublishers := context.WithCancel(ctx) + for i := 0; i < 10; i++ { + wg.Add(1) + go subscribe(subCtx, rdb, "test", i, wg) + } + time.Sleep(time.Second) + cancelSubCtx() + time.Sleep(time.Second) + subCtx, cancelSubCtx = context.WithCancel(ctx) + for i := 0; i < 10; i++ { + if err := rdb.Incr(ctx, "publishers").Err(); err != nil { + panic(err) + } + wg.Add(1) + go floodThePool(pubCtx, rdb, wg) + } + + for i := 0; i < 500; i++ { + if err := rdb.Incr(ctx, "subscribers").Err(); err != nil { + panic(err) + } + wg.Add(1) + go subscribe(subCtx, rdb, "test2", i, wg) + } + time.Sleep(5 * time.Second) + fmt.Println("canceling publishers") + cancelPublishers() + time.Sleep(10 * time.Second) + fmt.Println("canceling subscribers") + cancelSubCtx() + wg.Wait() + published, err := rdb.Get(ctx, "published").Result() + received, err := rdb.Get(ctx, "received").Result() + publishers, err := rdb.Get(ctx, "publishers").Result() + subscribers, err := rdb.Get(ctx, "subscribers").Result() + fmt.Printf("publishers: %s\n", publishers) + fmt.Printf("published: %s\n", published) + fmt.Printf("subscribers: %s\n", subscribers) + fmt.Printf("received: %s\n", received) + publishedInt, err := rdb.Get(ctx, "published").Int() + subscribersInt, err := rdb.Get(ctx, "subscribers").Int() + fmt.Printf("if drained = published*subscribers: %d\n", publishedInt*subscribersInt) + + time.Sleep(2 * time.Second) +} + +func floodThePool(ctx context.Context, rdb *redis.Client, wg *sync.WaitGroup) { + defer wg.Done() + for { + select { + case <-ctx.Done(): + return + default: + } + err := rdb.Publish(ctx, "test2", "hello").Err() + if err != nil { + // noop + //log.Println("publish error:", err) + } + + err = rdb.Incr(ctx, "published").Err() + if err != nil { + // noop + //log.Println("incr error:", err) + } + time.Sleep(10 * time.Nanosecond) + } +} + +func subscribe(ctx context.Context, rdb *redis.Client, topic string, subscriberId int, wg *sync.WaitGroup) { + defer wg.Done() + rec := rdb.Subscribe(ctx, topic) + recChan := rec.Channel() + for { + select { + case <-ctx.Done(): + rec.Close() + return + default: + select { + case <-ctx.Done(): + rec.Close() + return + case msg := <-recChan: + err := rdb.Incr(ctx, "received").Err() + if err != nil { + log.Println("incr error:", err) + } + _ = msg // Use the message to avoid unused variable warning + } + } + } +} diff --git a/hitless/README.md b/hitless/README.md new file mode 100644 index 0000000000..826dd66cb2 --- /dev/null +++ 
b/hitless/README.md @@ -0,0 +1,561 @@ +# Hitless Upgrades Package + +This package provides hitless upgrade functionality for Redis clients, enabling seamless cluster maintenance operations without dropping connections or failing commands. The system automatically handles Redis cluster topology changes through push notifications and intelligent connection management. + +## πŸš€ Quick Start + +```go +import "github.com/redis/go-redis/v9/hitless" + +client := redis.NewClient(&redis.Options{ + Addr: "localhost:6379", + Protocol: 3, // Required for push notifications + // Enable hitless upgrades with configuration + HitlessUpgradeConfig: &redis.HitlessUpgradeConfig{ + Enabled: hitless.MaintNotificationsEnabled, // Enable maintenance notifications + EndpointType: hitless.EndpointTypeAuto, // Auto-detect endpoint type + RelaxedTimeout: 30 * time.Second, // Extended timeout during migrations + HandoffTimeout: 15 * time.Second, // Max time for connection handoff + PostHandoffRelaxedDuration: 10 * time.Second, // Keep relaxed timeout after handoff + LogLevel: 1, // Warning level logging + }, +}) + +// That's it! Hitless upgrades now work automatically +result, err := client.Get(ctx, "key") // Seamlessly handles cluster changes +``` + +## πŸ“‹ Supported Client Types + +Hitless upgrades are supported by the following client types: + +- βœ… **`redis.Client`** - Standard Redis client +- βœ… **`redis.ClusterClient`** - Redis Cluster client +- βœ… **`redis.SentinelClient`** - Redis Sentinel client +- ❌ **`redis.RingClient`** - Not supported (no hitless upgrade integration) + +All supported clients require **Protocol: 3 (RESP3)** for push notification support. + +## πŸ”„ How Hitless Upgrades Work + +The hitless upgrade system provides seamless Redis cluster maintenance through push notifications and connection-level management: + +### **Architecture Overview** + +``` +β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” +β”‚ Redis Client β”‚ β”‚ Hitless Manager β”‚ β”‚ Connection Pool β”‚ +β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ +β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ +β”‚ β”‚Push Notif β”‚ β”‚ β”‚ β”‚MOVING Op β”‚ β”‚ β”‚ β”‚Connection β”‚ β”‚ +β”‚ β”‚Processor β”‚ │───── β”‚Tracker β”‚ β”‚ β”‚ β”‚Processor β”‚ β”‚ +β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ +β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ +β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ +β”‚ β”‚Command β”‚ β”‚ β”‚ β”‚Notification β”‚ │───── β”‚Handoff β”‚ β”‚ +β”‚ β”‚Execution β”‚ β”‚ β”‚ β”‚Handler β”‚ β”‚ β”‚ β”‚Workers β”‚ β”‚ +β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ +β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ +``` + +### **Push Notification Types** + +The system handles the following Redis push notifications: + +- **`MOVING`** - Connection handoff to new endpoint (per-connection) +- **`MIGRATING`** - Slot migration in progress (applies 
relaxed timeouts) +- **`MIGRATED`** - Slot migration completed (clears relaxed timeouts) +- **`FAILING_OVER`** - Failover in progress (applies relaxed timeouts) +- **`FAILED_OVER`** - Failover completed (clears relaxed timeouts) + +### **Operation Flow** + +#### **1. πŸ—οΈ Initialization** +```go +import "github.com/redis/go-redis/v9/hitless" + +// When hitless upgrades are enabled +client := redis.NewClient(&redis.Options{ + Protocol: 3, // RESP3 required for push notifications + HitlessUpgradeConfig: &redis.HitlessUpgradeConfig{ + Enabled: hitless.MaintNotificationsEnabled, // Enable maintenance notifications + }, +}) + +// Internally, the client: +// 1. Sends CLIENT MAINT_NOTIFICATIONS command during handshake +// 2. Creates HitlessManager with configuration +// 3. Registers push notification handlers for all upgrade events +// 4. Creates ConnectionProcessor for handoff management +// 5. Integrates with existing connection pool +``` + +#### **2. πŸ“‘ Push Notification Handling** +```go +// Redis server sends push notifications during cluster operations: +// Format: ["MOVING", seqNum, timeS, endpoint] +// Format: ["MIGRATING", slot] +// Format: ["MIGRATED", slot] +// Format: ["FAILING_OVER", node] +// Format: ["FAILED_OVER", node] + +// Example MOVING notification flow: +// 1. Redis sends: ["MOVING", "12345", "30", "10.0.0.2:6379"] +// 2. Push processor routes to hitless notification handler +// 3. Handler marks specific connection for handoff +// 4. Background workers process the handoff asynchronously +``` + +#### **3. πŸ”„ Connection Management** +```go +// MOVING notification handling (per-connection): +// β”œβ”€β”€ Parse sequence ID, timeout, and new endpoint +// β”œβ”€β”€ Mark specific connection for handoff +// β”œβ”€β”€ Queue handoff request for background processing +// └── Track operation with composite key (seqID + connID) + +// MIGRATING/FAILING_OVER notification handling: +// β”œβ”€β”€ Apply relaxed timeouts to the receiving connection +// β”œβ”€β”€ Allow commands to continue with extended timeouts +// └── Connection-specific timeout management + +// MIGRATED/FAILED_OVER notification handling: +// β”œβ”€β”€ Clear relaxed timeouts from the receiving connection +// β”œβ”€β”€ Resume normal timeout behavior +// └── Per-connection state cleanup +``` + +#### **4. πŸ”€ Connection Handoff Process** +```go +// Background handoff workers process requests: + +func (processor *RedisConnectionProcessor) processHandoffRequest(request HandoffRequest) { + // 1. Create new connection to target endpoint + newConn, err := processor.dialNewConnection(request.NewEndpoint) + if err != nil { + // Handoff failed, remove connection from pool + processor.pool.Remove(request.Conn) + return err + } + + // 2. Replace connection in pool atomically + // The pool handles connection state transfer internally + err = processor.pool.ReplaceConnection(request.Conn, newConn) + if err != nil { + newConn.Close() + return err + } + + // 3. Apply post-handoff relaxed timeout to new connection + newConn.SetRelaxedTimeoutWithDeadline( + config.RelaxedTimeout, + config.RelaxedTimeout, + time.Now().Add(config.PostHandoffRelaxedDuration), + ) + + // 4. Notify hitless manager of completion + processor.hitlessManager.CompleteOperationWithConnID(seqID, connID) + + // 5. Old connection is closed by pool replacement +} +``` + +#### **5. ⏱️ Timeout Management** +```go +// Timeout hierarchy (highest priority first): + +// 1. 
Context Deadline (if set and shorter) +ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) +result := client.Get(ctx, "key") // Uses 5s even during MIGRATING + +// 2. Relaxed Timeout (during MIGRATING/FAILING_OVER on specific connection) +ctx := context.Background() // No deadline +result := client.Get(ctx, "key") // Uses RelaxedTimeout (30s) if connection has relaxed timeout + +// 3. Normal Client Timeout (default) +ctx := context.Background() // No deadline, normal state +result := client.Get(ctx, "key") // Uses client ReadTimeout (5s) + +// Note: Relaxed timeouts are applied per-connection, not globally +// Only connections that receive MIGRATING/FAILING_OVER notifications get relaxed timeouts +``` + +#### **6. 🎯 Command Execution Flow** +```go +// Every command goes through this flow: + +client.Get(ctx, "key") +// ↓ +// 1. Get connection from pool +conn := pool.Get(ctx) +// ↓ +// 2. Connection processor processes pending notifications +processor.ProcessConnectionOnGet(ctx, conn) +// ↓ +// 3. Execute command (timeout determined by connection state) +// - If connection has relaxed timeout: uses RelaxedTimeout +// - Otherwise: uses normal client timeout +// - Context deadline always takes precedence if shorter +result := conn.ExecuteCommand(ctx, cmd) +// ↓ +// 4. Return connection to pool +// - Check if connection is marked for handoff +// - Queue handoff if needed +processor.ProcessConnectionOnPut(ctx, conn) +// ↓ +// 5. Return result to application +return result +``` + +## πŸ—οΈ Component Architecture + +### **HitlessManager** +- **MOVING operation tracking** with composite keys (seqID + connID) +- **Push notification handling** for all upgrade events +- **Operation deduplication** to handle duplicate notifications +- **Configuration management** with sensible defaults + +### **RedisConnectionProcessor** +- **Connection handoff management** with background workers +- **Dynamic worker scaling** based on load (min/max workers) +- **Queue management** for handoff requests with timeout handling +- **Pool integration** for connection replacement + +### **Push Notification System** +- **Automatic handler registration** for upgrade events +- **RESP3 protocol parsing** of Redis push notifications +- **Per-connection event routing** to appropriate handlers +- **Protected handler registration** (cannot be overwritten) + +### **Connection Pool Integration** +- **Existing pool architecture** with processor integration +- **Connection marking** for handoff operations +- **Atomic connection replacement** during handoffs +- **Per-connection timeout management** + +## βš™οΈ Configuration Options + +```go +import "github.com/redis/go-redis/v9/hitless" + +type HitlessUpgradeConfig struct { + // Core settings - Type-safe enums + Enabled hitless.MaintNotificationsMode // Maintenance notifications mode + EndpointType hitless.EndpointType // Endpoint type for MOVING notifications + + // Timeout settings + RelaxedTimeout time.Duration // Timeout during MIGRATING/FAILING_OVER (default: 30s) + RelaxedTimeoutDuration time.Duration // Legacy alias for RelaxedTimeout (default: 30s) + HandoffTimeout time.Duration // Max time for connection handoff (default: 15s) + PostHandoffRelaxedDuration time.Duration // Keep relaxed timeout after handoff (default: 10s) + + // Worker settings (auto-calculated based on pool size if 0) + MinWorkers int // Minimum handoff workers (default: max(1, poolSize/25)) + MaxWorkers int // Maximum handoff workers (default: max(MinWorkers*4, poolSize/5)) + 
HandoffQueueSize int // Handoff request queue size (default: MaxWorkers*10, capped by poolSize) + + // Advanced settings + ScaleDownDelay time.Duration // Delay before scaling down workers (default: 2s) + LogLevel int // 0=errors, 1=warnings, 2=info, 3=debug (default: 1) +} + +// Maintenance Notifications Mode (type-safe enum) +const ( + MaintNotificationsDisabled hitless.MaintNotificationsMode = "disabled" // Disable maintenance notifications + MaintNotificationsEnabled hitless.MaintNotificationsMode = "enabled" // Force enable, fail on errors + MaintNotificationsAuto hitless.MaintNotificationsMode = "auto" // Auto-enable, disable on errors +) + +// Endpoint Type for MOVING notifications (type-safe enum) +const ( + EndpointTypeAuto hitless.EndpointType = "auto" // Auto-detect based on connection + EndpointTypeInternalIP hitless.EndpointType = "internal-ip" // Request internal IP address + EndpointTypeInternalFQDN hitless.EndpointType = "internal-fqdn" // Request internal FQDN + EndpointTypeExternalIP hitless.EndpointType = "external-ip" // Request external IP address + EndpointTypeExternalFQDN hitless.EndpointType = "external-fqdn" // Request external FQDN + EndpointTypeNone hitless.EndpointType = "none" // Request null endpoint +) +``` + +## 🎯 Usage Examples + +### **Basic Usage (Recommended)** +```go +import "github.com/redis/go-redis/v9/hitless" + +// Minimal configuration - uses sensible defaults +client := redis.NewClient(&redis.Options{ + Addr: "localhost:6379", + Protocol: 3, // Required for push notifications + HitlessUpgradeConfig: &redis.HitlessUpgradeConfig{ + Enabled: hitless.MaintNotificationsEnabled, // Enable maintenance notifications + }, +}) +``` + +### **Custom Configuration** +```go +import "github.com/redis/go-redis/v9/hitless" + +client := redis.NewClient(&redis.Options{ + Addr: "localhost:6379", + Protocol: 3, + HitlessUpgradeConfig: &redis.HitlessUpgradeConfig{ + Enabled: hitless.MaintNotificationsEnabled, // Enable maintenance notifications + EndpointType: hitless.EndpointTypeInternalIP, // Request internal IP addresses + RelaxedTimeout: 45 * time.Second, // Longer timeout for slow operations + HandoffTimeout: 20 * time.Second, // More time for handoffs + PostHandoffRelaxedDuration: 15 * time.Second, // Extended post-handoff period + LogLevel: 2, // Info level logging + }, +}) +``` + +### **Auto Mode (Graceful Fallback)** +```go +import "github.com/redis/go-redis/v9/hitless" + +// Auto mode - enables hitless upgrades if supported, disables on errors +client := redis.NewClient(&redis.Options{ + Addr: "localhost:6379", + Protocol: 3, + HitlessUpgradeConfig: &redis.HitlessUpgradeConfig{ + Enabled: hitless.MaintNotificationsAuto, // Auto-enable, disable on errors + }, +}) +``` + +### **Cluster Client** +```go +import "github.com/redis/go-redis/v9/hitless" + +clusterClient := redis.NewClusterClient(&redis.ClusterOptions{ + Addrs: []string{"localhost:7000", "localhost:7001", "localhost:7002"}, + Protocol: 3, + HitlessUpgradeConfig: &redis.HitlessUpgradeConfig{ + Enabled: hitless.MaintNotificationsEnabled, + // Configuration options same as regular client + }, +}) +``` + +### **Sentinel Client** +```go +import "github.com/redis/go-redis/v9/hitless" + +sentinelClient := redis.NewSentinelClient(&redis.SentinelOptions{ + MasterName: "mymaster", + SentinelAddrs: []string{"localhost:26379"}, + Protocol: 3, + HitlessUpgradeConfig: &redis.HitlessUpgradeConfig{ + Enabled: hitless.MaintNotificationsEnabled, + }, +}) +``` + +## πŸ”§ Automatic Operation + +Once configured, hitless 
upgrades work completely automatically: + +- βœ… **No manual registration** - Push handlers are registered automatically +- βœ… **No state management** - MOVING operations are tracked automatically +- βœ… **No timeout management** - Relaxed timeouts are applied per-connection +- βœ… **No handoff coordination** - Connection handoffs happen transparently +- βœ… **No cleanup required** - Resources and workers are managed automatically + +Your application code remains unchanged - just enable the feature and it works! + +## πŸ“‹ Requirements + +### **Protocol Requirements** +- **RESP3 Protocol Required**: Must set `Protocol: 3` in client options +- **Redis Version**: Redis 6.0+ (for RESP3 and push notification support) +- **Push Notifications**: Server must support push notifications + +### **Client Support** +- βœ… `redis.Client` - Full support +- βœ… `redis.ClusterClient` - Full support +- βœ… `redis.SentinelClient` - Full support +- ❌ `redis.RingClient` - Not supported + +### **Network Requirements** +- **Endpoint Connectivity**: Client must be able to connect to new endpoints provided in MOVING notifications +- **TLS Compatibility**: Auto-detects appropriate endpoint type based on TLS configuration + +## ⚠️ Current Limitations + +### **Implementation Scope** +- **Connection-Level Operations**: Handoffs and timeouts are managed per-connection, not pool-wide +- **Single Pool Architecture**: No dual-pool implementation (contrary to some documentation) +- **MOVING Operations Only**: Only MOVING notifications trigger connection handoffs +- **No Slot Tracking**: MIGRATING/MIGRATED notifications only affect timeouts, not routing + +### **Configuration Constraints** +- **Auto-Calculated Defaults**: Many settings are calculated based on pool size and cannot be overridden +- **Worker Scaling**: Dynamic worker scaling is based on simple heuristics +- **Queue Management**: Handoff queue has timeout-based overflow handling + +### **Error Handling** +- **Handoff Failures**: Failed handoffs result in connection removal from pool +- **Notification Errors**: Invalid notifications are logged but don't stop processing +- **Timeout Handling**: Queue timeouts (5s) may drop handoff requests under extreme load + +## πŸ” Implementation Details + +### **Operation Tracking** +- **Composite Keys**: MOVING operations tracked with `(seqID, connID)` to handle duplicates +- **Deduplication**: Duplicate MOVING notifications for same operation are ignored +- **Connection Marking**: Connections marked for handoff with target endpoint and sequence ID + +### **Worker Management** +- **Dynamic Scaling**: Workers scale between MinWorkers and MaxWorkers based on load +- **Scale-Down Delay**: 2-second delay before checking if workers should be scaled down +- **Graceful Shutdown**: Workers complete current handoffs before shutting down + +### **Timeout Behavior** +- **Per-Connection**: Relaxed timeouts applied only to connections receiving notifications +- **Deadline Management**: Automatic timeout clearing with configurable post-handoff duration +- **Context Priority**: Context deadlines always take precedence over relaxed timeouts + +### **Push Notification Integration** +- **Protected Handlers**: Hitless upgrade handlers cannot be overwritten by application code +- **Automatic Registration**: All upgrade event handlers registered during initialization +- **Error Isolation**: Handler errors don't affect other notification processing + +## πŸ“Š Monitoring and Troubleshooting + +### **Logging Levels** +```go +// Configure logging 
verbosity +HitlessUpgradeConfig: &redis.HitlessUpgradeConfig{ + LogLevel: 2, // 0=errors, 1=warnings, 2=info, 3=debug +} +``` + +- **Level 0 (Errors)**: Only critical errors (handoff failures, configuration errors) +- **Level 1 (Warnings)**: Default level, includes warnings and errors +- **Level 2 (Info)**: Handoff operations, worker scaling, operation tracking +- **Level 3 (Debug)**: Detailed notification processing, connection state changes + +### **Common Issues** + +#### **"RESP3 protocol required" Error** +```go +import "github.com/redis/go-redis/v9/hitless" + +// ❌ Wrong - will log error and disable hitless upgrades +client := redis.NewClient(&redis.Options{ + Protocol: 2, // RESP2 doesn't support push notifications + HitlessUpgradeConfig: &redis.HitlessUpgradeConfig{ + Enabled: hitless.MaintNotificationsEnabled, + }, +}) + +// βœ… Correct - enables hitless upgrades +client := redis.NewClient(&redis.Options{ + Protocol: 3, // RESP3 required + HitlessUpgradeConfig: &redis.HitlessUpgradeConfig{ + Enabled: hitless.MaintNotificationsEnabled, + }, +}) +``` + +#### **Handoff Queue Timeouts** +- **Symptom**: "handoff queue timeout after 5 seconds" errors +- **Cause**: Queue overflow under high load +- **Solution**: Increase `MaxWorkers` or `HandoffQueueSize` in configuration + +#### **Failed Connection Handoffs** +- **Symptom**: Connections removed from pool during MOVING operations +- **Cause**: Network connectivity issues to new endpoints +- **Solution**: Verify network connectivity and endpoint reachability + +### **Performance Considerations** +- **Worker Count**: Balance between responsiveness and resource usage +- **Queue Size**: Size based on expected burst load and worker capacity +- **Timeout Values**: Balance between resilience and responsiveness +- **Pool Size Impact**: Worker defaults scale with pool size automatically + +## πŸš€ Getting Started + +1. **Enable RESP3**: Set `Protocol: 3` in your client options +2. **Import Package**: `import "github.com/redis/go-redis/v9/hitless"` +3. **Enable Feature**: Set `Enabled: hitless.MaintNotificationsEnabled` in `HitlessUpgradeConfig` +4. **Optional Config**: Customize other `HitlessUpgradeConfig` settings if needed +5. **Test**: Verify with Redis cluster maintenance operations + +The system will automatically handle all Redis cluster upgrade notifications without requiring any changes to your application code. 
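+
+For the handoff queue timeouts described in the troubleshooting section above, one option is to size the workers and queue explicitly instead of relying on the auto-calculated defaults. The sketch below is illustrative only; the specific `MaxWorkers` and `HandoffQueueSize` values are assumptions, not recommendations:
+
+```go
+import "github.com/redis/go-redis/v9/hitless"
+
+// Sketch: explicitly sized handoff workers and queue for bursty maintenance windows.
+// The values are illustrative; when left at zero, both are auto-calculated from the pool size.
+client := redis.NewClient(&redis.Options{
+	Addr:     "localhost:6379",
+	Protocol: 3, // RESP3 required for push notifications
+	HitlessUpgradeConfig: &redis.HitlessUpgradeConfig{
+		Enabled:          hitless.MaintNotificationsEnabled,
+		MaxWorkers:       16,  // allow more concurrent handoffs under load
+		HandoffQueueSize: 160, // larger buffer before handoff requests are rejected
+	},
+})
+```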
+
+### **CLIENT MAINT_NOTIFICATIONS Command**
+
+The hitless upgrade system uses the standard go-redis command interface:
+
+```go
+// The CLIENT MAINT_NOTIFICATIONS command is sent automatically during client initialization.
+// You can also send it manually if needed:
+result := client.ClientMaintNotifications(ctx, true, "internal-ip")
+if err := result.Err(); err != nil {
+	log.Printf("Failed to enable maintenance notifications: %v", err)
+}
+
+// Disable maintenance notifications
+result = client.ClientMaintNotifications(ctx, false, "")
+```
+
+This command follows standard go-redis patterns and supports:
+- βœ… **Pipeline compatibility** - Can be used in pipelines and transactions
+- βœ… **Standard error handling** - Uses go-redis error handling and retries
+- βœ… **Monitoring integration** - Appears in standard metrics and logging
+- βœ… **Testing support** - Can be easily mocked and tested
+
+## πŸ”’ Type Safety and Validation
+
+### **Enum-Based Configuration**
+
+The hitless upgrade system uses type-safe enums to prevent configuration errors:
+
+```go
+import "github.com/redis/go-redis/v9/hitless"
+
+// βœ… Type-safe configuration - named constants prevent typos and invalid values
+config := &redis.HitlessUpgradeConfig{
+	Enabled:      hitless.MaintNotificationsEnabled, // Enum prevents typos
+	EndpointType: hitless.EndpointTypeInternalIP,    // Enum prevents invalid values
+}
+
+// ❌ Raw string values like these would fail validation:
+// config.Enabled = "enable"        // Wrong! Not a valid enum value
+// config.EndpointType = "internal" // Wrong! Not a valid enum value
+```
+
+### **Configuration Validation**
+
+All configuration values are validated at runtime:
+
+```go
+// Valid configurations
+hitless.MaintNotificationsDisabled.IsValid() // true
+hitless.MaintNotificationsEnabled.IsValid()  // true
+hitless.MaintNotificationsAuto.IsValid()     // true
+hitless.EndpointTypeAuto.IsValid()           // true
+hitless.EndpointTypeInternalIP.IsValid()     // true
+
+// Invalid configurations would return false
+// Custom validation ensures only supported values are accepted
+```
+
+### **Auto-Detection Features**
+
+```go
+import "github.com/redis/go-redis/v9/hitless"
+
+// Auto-detect endpoint type based on connection settings
+endpointType := hitless.DetectEndpointType("10.0.0.1:6379", false)      // Returns EndpointTypeInternalIP
+endpointType = hitless.DetectEndpointType("redis.example.com:6379", true) // Returns EndpointTypeExternalFQDN
+
+// Use in configuration
+config := &redis.HitlessUpgradeConfig{
+	Enabled:      hitless.MaintNotificationsAuto,
+	EndpointType: hitless.EndpointTypeAuto, // Will auto-detect during initialization
+}
+```
diff --git a/hitless/config.go b/hitless/config.go
new file mode 100644
index 0000000000..06b464a372
--- /dev/null
+++ b/hitless/config.go
@@ -0,0 +1,355 @@
+package hitless
+
+import (
+	"net"
+	"runtime"
+	"time"
+
+	"github.com/redis/go-redis/v9/internal/util"
+)
+
+// MaintNotificationsMode represents the maintenance notifications mode
+type MaintNotificationsMode string
+
+// Constants for maintenance push notifications modes
+const (
+	MaintNotificationsDisabled MaintNotificationsMode = "disabled" // Client doesn't send the CLIENT MAINT_NOTIFICATIONS ON command
+	MaintNotificationsEnabled  MaintNotificationsMode = "enabled"  // Client forcefully sends the command and interrupts the connection on error
+	MaintNotificationsAuto     MaintNotificationsMode = "auto"     // Client tries to send the command and disables the feature on error
+)
+
+// IsValid returns true if the maintenance notifications mode is valid
+func (m MaintNotificationsMode) 
IsValid() bool { + switch m { + case MaintNotificationsDisabled, MaintNotificationsEnabled, MaintNotificationsAuto: + return true + default: + return false + } +} + +// String returns the string representation of the mode +func (m MaintNotificationsMode) String() string { + return string(m) +} + +// EndpointType represents the type of endpoint to request in MOVING notifications +type EndpointType string + +// Constants for endpoint types +const ( + EndpointTypeAuto EndpointType = "auto" // Auto-detect based on connection + EndpointTypeInternalIP EndpointType = "internal-ip" // Internal IP address + EndpointTypeInternalFQDN EndpointType = "internal-fqdn" // Internal FQDN + EndpointTypeExternalIP EndpointType = "external-ip" // External IP address + EndpointTypeExternalFQDN EndpointType = "external-fqdn" // External FQDN + EndpointTypeNone EndpointType = "none" // No endpoint (reconnect with current config) +) + +// IsValid returns true if the endpoint type is valid +func (e EndpointType) IsValid() bool { + switch e { + case EndpointTypeAuto, EndpointTypeInternalIP, EndpointTypeInternalFQDN, + EndpointTypeExternalIP, EndpointTypeExternalFQDN, EndpointTypeNone: + return true + default: + return false + } +} + +// String returns the string representation of the endpoint type +func (e EndpointType) String() string { + return string(e) +} + +// Config provides configuration options for hitless upgrades. +type Config struct { + // Enabled controls how client maintenance notifications are handled. + // Valid values: MaintNotificationsDisabled, MaintNotificationsEnabled, MaintNotificationsAuto + // Default: MaintNotificationsAuto + Enabled MaintNotificationsMode + + // EndpointType specifies the type of endpoint to request in MOVING notifications. + // Valid values: EndpointTypeAuto, EndpointTypeInternalIP, EndpointTypeInternalFQDN, + // EndpointTypeExternalIP, EndpointTypeExternalFQDN, EndpointTypeNone + // Default: EndpointTypeAuto + EndpointType EndpointType + + // RelaxedTimeout is the concrete timeout value to use during + // MIGRATING/FAILING_OVER states to accommodate increased latency. + // This applies to both read and write timeouts. + // Default: 30 seconds + RelaxedTimeout time.Duration + + // HandoffTimeout is the maximum time to wait for connection handoff to complete. + // If handoff takes longer than this, the old connection will be forcibly closed. + // Default: 15 seconds (matches server-side eviction timeout) + HandoffTimeout time.Duration + + // MinWorkers is the minimum number of worker goroutines for processing handoff requests. + // The processor starts with this number of workers and scales down to this level when idle. + // If zero, defaults to max(1, PoolSize/25) to be proportional to connection pool size. + // + // Default: max(1, PoolSize/25) + MinWorkers int + + // MaxWorkers is the maximum number of worker goroutines for processing handoff requests. + // The processor will scale up to this number when under load. + // If zero, defaults to max(MinWorkers*4, PoolSize/5) to handle bursts effectively. + // + // Default: max(MinWorkers*4, PoolSize/5) + MaxWorkers int + + // HandoffQueueSize is the size of the buffered channel used to queue handoff requests. + // If the queue is full, new handoff requests will be rejected. + // + // Default: 10x max workers, but never more than pool size + HandoffQueueSize int + + // PostHandoffRelaxedDuration is how long to keep relaxed timeouts on the new connection + // after a handoff completes. 
This provides additional resilience during cluster transitions. + // Default: 2 * RelaxedTimeout + PostHandoffRelaxedDuration time.Duration + + // ScaleDownDelay is the delay before checking if workers should be scaled down. + // This prevents expensive checks on every handoff completion and avoids rapid scaling cycles. + // Default: 2 seconds + ScaleDownDelay time.Duration + + // LogLevel controls the verbosity of hitless upgrade logging. + // 0 = errors only, 1 = warnings, 2 = info, 3 = debug + // Default: 1 (warnings) + LogLevel int +} + +func (c *Config) IsEnabled() bool { + return c != nil && c.Enabled != MaintNotificationsDisabled +} + +// DefaultConfig returns a Config with sensible defaults. +func DefaultConfig() *Config { + return &Config{ + Enabled: MaintNotificationsAuto, // Enable by default for Redis Cloud + EndpointType: EndpointTypeAuto, // Auto-detect based on connection + RelaxedTimeout: 30 * time.Second, + HandoffTimeout: 15 * time.Second, + MinWorkers: 0, // Auto-calculated based on pool size + MaxWorkers: 0, // Auto-calculated based on pool size + HandoffQueueSize: 0, // Auto-calculated based on max workers + PostHandoffRelaxedDuration: 0, // Auto-calculated based on relaxed timeout + ScaleDownDelay: 2 * time.Second, + LogLevel: 1, + } +} + +// Validate checks if the configuration is valid. +func (c *Config) Validate() error { + if c.RelaxedTimeout <= 0 { + return ErrInvalidRelaxedTimeout + } + if c.HandoffTimeout <= 0 { + return ErrInvalidHandoffTimeout + } + // Validate worker configuration + // Check raw values before defaults are applied + if c.MinWorkers <= 0 { + return ErrInvalidHandoffWorkers + } + if c.MinWorkers > 0 && c.MaxWorkers > 0 && c.MaxWorkers < c.MinWorkers { + return ErrInvalidWorkerRange + } + // HandoffQueueSize validation - allow 0 for auto-calculation + if c.HandoffQueueSize < 0 { + return ErrInvalidHandoffQueueSize + } + if c.PostHandoffRelaxedDuration < 0 { + return ErrInvalidPostHandoffRelaxedDuration + } + if c.LogLevel < 0 || c.LogLevel > 3 { + return ErrInvalidLogLevel + } + + // Validate Enabled (maintenance notifications mode) + if !c.Enabled.IsValid() { + return ErrInvalidMaintNotifications + } + + // Validate EndpointType + if !c.EndpointType.IsValid() { + return ErrInvalidEndpointType + } + + return nil +} + +// ApplyDefaults applies default values to any zero-value fields in the configuration. +// This ensures that partially configured structs get sensible defaults for missing fields. +func (c *Config) ApplyDefaults() *Config { + return c.ApplyDefaultsWithPoolSize(0) +} + +// ApplyDefaultsWithPoolSize applies default values to any zero-value fields in the configuration, +// using the provided pool size to calculate worker defaults. +// This ensures that partially configured structs get sensible defaults for missing fields. 
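+// Worker counts are derived from poolSize when MinWorkers/MaxWorkers are zero, and the
+// handoff queue size defaults to 10x MaxWorkers, capped at poolSize.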
+func (c *Config) ApplyDefaultsWithPoolSize(poolSize int) *Config { + if c == nil { + return DefaultConfig().ApplyDefaultsWithPoolSize(poolSize) + } + + defaults := DefaultConfig() + result := &Config{} + + // Apply defaults for enum fields (empty/zero means not set) + if c.Enabled == "" { + result.Enabled = defaults.Enabled + } else { + result.Enabled = c.Enabled + } + + if c.EndpointType == "" { + result.EndpointType = defaults.EndpointType + } else { + result.EndpointType = c.EndpointType + } + + // Apply defaults for duration fields (zero means not set) + if c.RelaxedTimeout <= 0 { + result.RelaxedTimeout = defaults.RelaxedTimeout + } else { + result.RelaxedTimeout = c.RelaxedTimeout + } + + if c.HandoffTimeout <= 0 { + result.HandoffTimeout = defaults.HandoffTimeout + } else { + result.HandoffTimeout = c.HandoffTimeout + } + + // Apply defaults for integer fields (zero means not set) + if c.HandoffQueueSize <= 0 { + result.HandoffQueueSize = defaults.HandoffQueueSize + } else { + result.HandoffQueueSize = c.HandoffQueueSize + } + + // Copy worker configuration + result.MinWorkers = c.MinWorkers + result.MaxWorkers = c.MaxWorkers + + // Apply worker defaults based on pool size + result.applyWorkerDefaults(poolSize) + + // Apply queue size defaults based on max workers, capped by pool size + if c.HandoffQueueSize <= 0 { + // Queue size: 10x max workers, but never more than pool size + workerBasedSize := result.MaxWorkers * 10 + result.HandoffQueueSize = util.Min(workerBasedSize, poolSize) + } else { + result.HandoffQueueSize = c.HandoffQueueSize + } + + if c.PostHandoffRelaxedDuration <= 0 { + result.PostHandoffRelaxedDuration = result.RelaxedTimeout * 2 + } else { + result.PostHandoffRelaxedDuration = c.PostHandoffRelaxedDuration + } + + if c.ScaleDownDelay <= 0 { + result.ScaleDownDelay = defaults.ScaleDownDelay + } else { + result.ScaleDownDelay = c.ScaleDownDelay + } + + // LogLevel: 0 is a valid value (errors only), so we need to check if it was explicitly set + // We'll use the provided value as-is, since 0 is valid + result.LogLevel = c.LogLevel + + return result +} + +// Clone creates a deep copy of the configuration. +func (c *Config) Clone() *Config { + if c == nil { + return DefaultConfig() + } + + return &Config{ + Enabled: c.Enabled, + EndpointType: c.EndpointType, + RelaxedTimeout: c.RelaxedTimeout, + HandoffTimeout: c.HandoffTimeout, + MinWorkers: c.MinWorkers, + MaxWorkers: c.MaxWorkers, + HandoffQueueSize: c.HandoffQueueSize, + PostHandoffRelaxedDuration: c.PostHandoffRelaxedDuration, + ScaleDownDelay: c.ScaleDownDelay, + LogLevel: c.LogLevel, + } +} + +// applyWorkerDefaults calculates and applies worker defaults based on pool size +func (c *Config) applyWorkerDefaults(poolSize int) { + // Calculate defaults based on pool size + if poolSize <= 0 { + poolSize = 10 * runtime.GOMAXPROCS(0) + } + + // MinWorkers: max(1, poolSize/25) - conservative baseline + if c.MinWorkers == 0 { + c.MinWorkers = util.Max(1, poolSize/25) + } + + // MaxWorkers: max(MinWorkers*4, poolSize/5) - handle bursts effectively + if c.MaxWorkers == 0 { + c.MaxWorkers = util.Max(c.MinWorkers*4, poolSize/5) + } + + // Ensure MaxWorkers >= MinWorkers + if c.MaxWorkers < c.MinWorkers { + c.MaxWorkers = c.MinWorkers + } +} + +// DetectEndpointType automatically detects the appropriate endpoint type +// based on the connection address and TLS configuration. 
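+// For example, a private address such as 10.0.0.1:6379 without TLS maps to EndpointTypeInternalIP,
+// while a public hostname with TLS maps to EndpointTypeExternalFQDN.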
+func DetectEndpointType(addr string, tlsEnabled bool) EndpointType { + // Parse the address to determine if it's an IP or hostname + isPrivate := isPrivateIP(addr) + + var endpointType EndpointType + + if tlsEnabled { + // TLS requires FQDN for certificate validation + if isPrivate { + endpointType = EndpointTypeInternalFQDN + } else { + endpointType = EndpointTypeExternalFQDN + } + } else { + // No TLS, can use IP addresses + if isPrivate { + endpointType = EndpointTypeInternalIP + } else { + endpointType = EndpointTypeExternalIP + } + } + + return endpointType +} + +// isPrivateIP checks if the given address is in a private IP range. +func isPrivateIP(addr string) bool { + // Extract host from "host:port" format + host, _, err := net.SplitHostPort(addr) + if err != nil { + host = addr // Assume no port + } + + ip := net.ParseIP(host) + if ip == nil { + return false // Not an IP address (likely hostname) + } + + // Check for private/loopback ranges + return ip.IsPrivate() || ip.IsLoopback() || ip.IsLinkLocalUnicast() +} diff --git a/hitless/config_test.go b/hitless/config_test.go new file mode 100644 index 0000000000..f07385ac64 --- /dev/null +++ b/hitless/config_test.go @@ -0,0 +1,332 @@ +package hitless + +import ( + "context" + "net" + "testing" + "time" + + "github.com/redis/go-redis/v9/internal/util" +) + +func TestConfig(t *testing.T) { + t.Run("DefaultConfig", func(t *testing.T) { + config := DefaultConfig() + + // MinWorkers and MaxWorkers should be 0 in default config (auto-calculated) + if config.MinWorkers != 0 { + t.Errorf("Expected MinWorkers to be 0 (auto-calculated), got %d", config.MinWorkers) + } + if config.MaxWorkers != 0 { + t.Errorf("Expected MaxWorkers to be 0 (auto-calculated), got %d", config.MaxWorkers) + } + + // HandoffQueueSize should be 0 in default config (auto-calculated) + if config.HandoffQueueSize != 0 { + t.Errorf("Expected HandoffQueueSize to be 0 (auto-calculated), got %d", config.HandoffQueueSize) + } + + if config.RelaxedTimeout != 30*time.Second { + t.Errorf("Expected RelaxedTimeout to be 30s, got %v", config.RelaxedTimeout) + } + + if config.HandoffTimeout != 15*time.Second { + t.Errorf("Expected HandoffTimeout to be 15s, got %v", config.HandoffTimeout) + } + + if config.PostHandoffRelaxedDuration != 10*time.Second { + t.Errorf("Expected PostHandoffRelaxedDuration to be 10s, got %v", config.PostHandoffRelaxedDuration) + } + }) + + t.Run("ConfigValidation", func(t *testing.T) { + // Valid config with applied defaults + config := DefaultConfig().ApplyDefaults() + if err := config.Validate(); err != nil { + t.Errorf("Default config with applied defaults should be valid: %v", err) + } + + // Invalid worker configuration (MinWorkers is 0) + config = &Config{ + RelaxedTimeout: 30 * time.Second, + HandoffTimeout: 15 * time.Second, + HandoffQueueSize: 100, + PostHandoffRelaxedDuration: 10 * time.Second, + LogLevel: 1, + // MinWorkers is 0, should be invalid + } + if err := config.Validate(); err != ErrInvalidHandoffWorkers { + t.Errorf("Expected ErrInvalidHandoffWorkers, got %v", err) + } + + // Invalid worker range (MaxWorkers < MinWorkers) + config = DefaultConfig() + config.MinWorkers = 5 + config.MaxWorkers = 2 + if err := config.Validate(); err != ErrInvalidWorkerRange { + t.Errorf("Expected ErrInvalidWorkerRange, got %v", err) + } + + // Invalid HandoffQueueSize + config = DefaultConfig().ApplyDefaults() + config.HandoffQueueSize = -1 + if err := config.Validate(); err != ErrInvalidHandoffQueueSize { + t.Errorf("Expected ErrInvalidHandoffQueueSize, 
got %v", err) + } + + // Invalid PostHandoffRelaxedDuration + config = DefaultConfig().ApplyDefaults() + config.PostHandoffRelaxedDuration = -1 * time.Second + if err := config.Validate(); err != ErrInvalidPostHandoffRelaxedDuration { + t.Errorf("Expected ErrInvalidPostHandoffRelaxedDuration, got %v", err) + } + }) + + t.Run("ConfigClone", func(t *testing.T) { + original := DefaultConfig() + original.MinWorkers = 5 + original.MaxWorkers = 20 + original.HandoffQueueSize = 200 + + cloned := original.Clone() + + if cloned.MinWorkers != 5 { + t.Errorf("Expected cloned MinWorkers to be 5, got %d", cloned.MinWorkers) + } + + if cloned.MaxWorkers != 20 { + t.Errorf("Expected cloned MaxWorkers to be 20, got %d", cloned.MaxWorkers) + } + + if cloned.HandoffQueueSize != 200 { + t.Errorf("Expected cloned HandoffQueueSize to be 200, got %d", cloned.HandoffQueueSize) + } + + // Modify original to ensure clone is independent + original.MinWorkers = 2 + if cloned.MinWorkers != 5 { + t.Error("Clone should be independent of original") + } + }) +} + +func TestApplyDefaults(t *testing.T) { + t.Run("NilConfig", func(t *testing.T) { + var config *Config + result := config.ApplyDefaultsWithPoolSize(100) // Use explicit pool size for testing + + // With nil config, should get default config with auto-calculated workers + if result.MinWorkers <= 0 { + t.Errorf("Expected MinWorkers to be > 0 after applying defaults, got %d", result.MinWorkers) + } + if result.MaxWorkers <= 0 { + t.Errorf("Expected MaxWorkers to be > 0 after applying defaults, got %d", result.MaxWorkers) + } + if result.MaxWorkers < result.MinWorkers { + t.Errorf("Expected MaxWorkers (%d) >= MinWorkers (%d)", result.MaxWorkers, result.MinWorkers) + } + + // HandoffQueueSize should be auto-calculated (10 * MaxWorkers, capped by pool size) + workerBasedSize := result.MaxWorkers * 10 + poolSize := 100 // Default pool size used in ApplyDefaults + expectedQueueSize := util.Min(workerBasedSize, poolSize) + if result.HandoffQueueSize != expectedQueueSize { + t.Errorf("Expected HandoffQueueSize to be %d (util.Min(10*MaxWorkers=%d, poolSize=%d)), got %d", + expectedQueueSize, workerBasedSize, poolSize, result.HandoffQueueSize) + } + }) + + t.Run("PartialConfig", func(t *testing.T) { + config := &Config{ + MinWorkers: 3, // Set this field explicitly + MaxWorkers: 12, // Set this field explicitly + // Leave other fields as zero values + } + + result := config.ApplyDefaultsWithPoolSize(100) // Use explicit pool size for testing + + // Should keep the explicitly set values + if result.MinWorkers != 3 { + t.Errorf("Expected MinWorkers to be 3 (explicitly set), got %d", result.MinWorkers) + } + if result.MaxWorkers != 12 { + t.Errorf("Expected MaxWorkers to be 12 (explicitly set), got %d", result.MaxWorkers) + } + + // Should apply default for unset fields (auto-calculated queue size, capped by pool size) + workerBasedSize := result.MaxWorkers * 10 + poolSize := 100 // Default pool size used in ApplyDefaults + expectedQueueSize := util.Min(workerBasedSize, poolSize) + if result.HandoffQueueSize != expectedQueueSize { + t.Errorf("Expected HandoffQueueSize to be %d (util.Min(10*MaxWorkers=%d, poolSize=%d)), got %d", + expectedQueueSize, workerBasedSize, poolSize, result.HandoffQueueSize) + } + + if result.RelaxedTimeout != 30*time.Second { + t.Errorf("Expected RelaxedTimeout to be 30s (default), got %v", result.RelaxedTimeout) + } + + if result.HandoffTimeout != 15*time.Second { + t.Errorf("Expected HandoffTimeout to be 15s (default), got %v", 
result.HandoffTimeout) + } + }) + + t.Run("ZeroValues", func(t *testing.T) { + config := &Config{ + MinWorkers: 0, // Zero value should get auto-calculated defaults + MaxWorkers: 0, // Zero value should get auto-calculated defaults + HandoffQueueSize: 0, // Zero value should get default + RelaxedTimeout: 0, // Zero value should get default + LogLevel: 0, // Zero is valid for LogLevel (errors only) + } + + result := config.ApplyDefaultsWithPoolSize(100) // Use explicit pool size for testing + + // Zero values should get auto-calculated defaults + if result.MinWorkers <= 0 { + t.Errorf("Expected MinWorkers to be > 0 (auto-calculated), got %d", result.MinWorkers) + } + if result.MaxWorkers <= 0 { + t.Errorf("Expected MaxWorkers to be > 0 (auto-calculated), got %d", result.MaxWorkers) + } + + // HandoffQueueSize should be auto-calculated (10 * MaxWorkers, capped by pool size) + workerBasedSize := result.MaxWorkers * 10 + poolSize := 100 // Default pool size used in ApplyDefaults + expectedQueueSize := util.Min(workerBasedSize, poolSize) + if result.HandoffQueueSize != expectedQueueSize { + t.Errorf("Expected HandoffQueueSize to be %d (util.Min(10*MaxWorkers=%d, poolSize=%d)), got %d", + expectedQueueSize, workerBasedSize, poolSize, result.HandoffQueueSize) + } + + if result.RelaxedTimeout != 30*time.Second { + t.Errorf("Expected RelaxedTimeout to be 30s (default), got %v", result.RelaxedTimeout) + } + + // LogLevel 0 should be preserved (it's a valid value) + if result.LogLevel != 0 { + t.Errorf("Expected LogLevel to be 0 (preserved), got %d", result.LogLevel) + } + }) +} + +func TestProcessorWithConfig(t *testing.T) { + t.Run("ProcessorUsesConfigValues", func(t *testing.T) { + config := &Config{ + MinWorkers: 2, + MaxWorkers: 5, + HandoffQueueSize: 50, + RelaxedTimeout: 10 * time.Second, + HandoffTimeout: 5 * time.Second, + } + + baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { + return &mockNetConn{addr: addr}, nil + } + + processor := NewRedisConnectionProcessor(3, baseDialer, config, nil) + defer processor.Shutdown(context.Background()) + + // The processor should be created successfully with custom config + if processor == nil { + t.Error("Processor should be created with custom config") + } + }) + + t.Run("ProcessorWithPartialConfig", func(t *testing.T) { + config := &Config{ + MinWorkers: 3, // Only set worker fields + MaxWorkers: 7, + // Other fields will get defaults + } + + baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { + return &mockNetConn{addr: addr}, nil + } + + processor := NewRedisConnectionProcessor(3, baseDialer, config, nil) + defer processor.Shutdown(context.Background()) + + // Should work with partial config (defaults applied) + if processor == nil { + t.Error("Processor should be created with partial config") + } + }) + + t.Run("ProcessorWithNilConfig", func(t *testing.T) { + baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { + return &mockNetConn{addr: addr}, nil + } + + processor := NewRedisConnectionProcessor(3, baseDialer, nil, nil) + defer processor.Shutdown(context.Background()) + + // Should use default config when nil is passed + if processor == nil { + t.Error("Processor should be created with nil config (using defaults)") + } + }) +} + +func TestIntegrationWithApplyDefaults(t *testing.T) { + t.Run("ProcessorWithPartialConfigAppliesDefaults", func(t *testing.T) { + // Create a partial config with only some fields set + partialConfig := &Config{ + MinWorkers: 3, // 
Custom value + MaxWorkers: 7, // Custom value + LogLevel: 2, // Custom value + // Other fields left as zero values - should get defaults + } + + baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { + return &mockNetConn{addr: addr}, nil + } + + // Create processor - should apply defaults to missing fields + processor := NewRedisConnectionProcessor(3, baseDialer, partialConfig, nil) + defer processor.Shutdown(context.Background()) + + // Processor should be created successfully + if processor == nil { + t.Error("Processor should be created with partial config") + } + + // Test that the ApplyDefaults method worked correctly by creating the same config + // and applying defaults manually + expectedConfig := partialConfig.ApplyDefaultsWithPoolSize(100) // Use explicit pool size for testing + + // Should preserve custom values + if expectedConfig.MinWorkers != 3 { + t.Errorf("Expected MinWorkers to be 3, got %d", expectedConfig.MinWorkers) + } + + if expectedConfig.MaxWorkers != 7 { + t.Errorf("Expected MaxWorkers to be 7, got %d", expectedConfig.MaxWorkers) + } + + if expectedConfig.LogLevel != 2 { + t.Errorf("Expected LogLevel to be 2, got %d", expectedConfig.LogLevel) + } + + // Should apply defaults for missing fields (auto-calculated queue size, capped by pool size) + workerBasedSize := expectedConfig.MaxWorkers * 10 + poolSize := 100 // Default pool size used in ApplyDefaults + expectedQueueSize := util.Min(workerBasedSize, poolSize) + if expectedConfig.HandoffQueueSize != expectedQueueSize { + t.Errorf("Expected HandoffQueueSize to be %d (util.Min(10*MaxWorkers=%d, poolSize=%d)), got %d", + expectedQueueSize, workerBasedSize, poolSize, expectedConfig.HandoffQueueSize) + } + + if expectedConfig.RelaxedTimeout != 30*time.Second { + t.Errorf("Expected RelaxedTimeout to be 30s (default), got %v", expectedConfig.RelaxedTimeout) + } + + if expectedConfig.HandoffTimeout != 15*time.Second { + t.Errorf("Expected HandoffTimeout to be 15s (default), got %v", expectedConfig.HandoffTimeout) + } + + if expectedConfig.PostHandoffRelaxedDuration != 10*time.Second { + t.Errorf("Expected PostHandoffRelaxedDuration to be 10s (default), got %v", expectedConfig.PostHandoffRelaxedDuration) + } + }) +} diff --git a/hitless/errors.go b/hitless/errors.go new file mode 100644 index 0000000000..2de2b358cd --- /dev/null +++ b/hitless/errors.go @@ -0,0 +1,68 @@ +package hitless + +import ( + "errors" + "fmt" +) + +// Configuration errors +var ( + ErrInvalidRelaxedTimeout = errors.New("hitless: relaxed timeout must be greater than 0") + ErrInvalidHandoffTimeout = errors.New("hitless: handoff timeout must be greater than 0") + ErrInvalidHandoffWorkers = errors.New("hitless: MinWorkers must be greater than 0") + ErrInvalidWorkerRange = errors.New("hitless: MaxWorkers must be greater than or equal to MinWorkers") + ErrInvalidHandoffQueueSize = errors.New("hitless: handoff queue size must be greater than 0") + ErrInvalidPostHandoffRelaxedDuration = errors.New("hitless: post-handoff relaxed duration must be greater than or equal to 0") + ErrInvalidLogLevel = errors.New("hitless: log level must be between 0 and 3") + ErrInvalidEndpointType = errors.New("hitless: invalid endpoint type") + ErrInvalidMaintNotifications = errors.New("hitless: invalid maintenance notifications setting (must be 'disabled', 'enabled', or 'auto')") + ErrMaxHandoffRetriesReached = errors.New("hitless: max handoff retries reached") +) + +// Integration errors +var ( + ErrInvalidClient = errors.New("hitless: invalid 
client type") +) + +// Handoff errors +var ( + ErrHandoffInProgress = errors.New("hitless: handoff already in progress") + ErrNoHandoffInProgress = errors.New("hitless: no handoff in progress") + ErrConnectionFailed = errors.New("hitless: failed to establish new connection") +) + +// Dead error variables removed - unused in simplified architecture + +// Notification errors +var ( + ErrInvalidNotification = errors.New("hitless: invalid notification format") +) + +// Dead error variables removed - unused in simplified architecture + +// HandoffError represents an error that occurred during connection handoff. +type HandoffError struct { + Operation string + Endpoint string + Cause error +} + +func (e *HandoffError) Error() string { + if e.Cause != nil { + return fmt.Sprintf("hitless: handoff %s failed for endpoint %s: %v", e.Operation, e.Endpoint, e.Cause) + } + return fmt.Sprintf("hitless: handoff %s failed for endpoint %s", e.Operation, e.Endpoint) +} + +func (e *HandoffError) Unwrap() error { + return e.Cause +} + +// NewHandoffError creates a new HandoffError. +func NewHandoffError(operation, endpoint string, cause error) *HandoffError { + return &HandoffError{ + Operation: operation, + Endpoint: endpoint, + Cause: cause, + } +} diff --git a/hitless/example_config_usage.go b/hitless/example_config_usage.go new file mode 100644 index 0000000000..091c5d0e42 --- /dev/null +++ b/hitless/example_config_usage.go @@ -0,0 +1,75 @@ +package hitless + +import ( + "time" +) + +// ExampleCustomConfig shows how to create a custom hitless configuration +func ExampleCustomConfig() *Config { + return &Config{ + Enabled: MaintNotificationsEnabled, + EndpointType: EndpointTypeInternalIP, + RelaxedTimeout: 45 * time.Second, + RelaxedTimeoutDuration: 45 * time.Second, + HandoffTimeout: 20 * time.Second, + MinWorkers: 5, // Minimum workers for baseline processing + MaxWorkers: 20, // Maximum workers for high-throughput scenarios + HandoffQueueSize: 200, // Larger queue for burst handling + LogLevel: 2, // Info level logging + } +} + +// ExampleLowResourceConfig shows a configuration for resource-constrained environments +func ExampleLowResourceConfig() *Config { + return &Config{ + Enabled: MaintNotificationsEnabled, + EndpointType: EndpointTypeInternalIP, + RelaxedTimeout: 20 * time.Second, + RelaxedTimeoutDuration: 20 * time.Second, + HandoffTimeout: 10 * time.Second, + MinWorkers: 1, // Minimum workers to save resources + MaxWorkers: 3, // Low maximum for resource-constrained environments + HandoffQueueSize: 25, // Smaller queue + LogLevel: 1, // Warning level logging only + } +} + +// ExampleHighThroughputConfig shows a configuration for high-throughput scenarios +func ExampleHighThroughputConfig() *Config { + return &Config{ + Enabled: MaintNotificationsEnabled, + EndpointType: EndpointTypeInternalIP, + RelaxedTimeout: 60 * time.Second, + RelaxedTimeoutDuration: 60 * time.Second, + HandoffTimeout: 30 * time.Second, + MinWorkers: 10, // High baseline for consistent performance + MaxWorkers: 30, // Many workers for parallel processing + HandoffQueueSize: 500, // Large queue for burst handling + LogLevel: 3, // Debug level logging for monitoring + } +} + +// ExamplePartialConfig shows how partial configuration works with automatic defaults +func ExamplePartialConfig() *Config { + // Only specify the fields you want to customize + // Other fields will automatically get default values when ApplyDefaults() is called + return &Config{ + Enabled: MaintNotificationsEnabled, + MinWorkers: 3, // Custom minimum 
worker count + MaxWorkers: 15, // Custom maximum worker count + LogLevel: 2, // Info level logging + // HandoffQueueSize will get default value (100) + // RelaxedTimeout will get default value (30s) + // HandoffTimeout will get default value (15s) + // etc. + } +} + +// ExampleMinimalConfig shows the most minimal configuration +func ExampleMinimalConfig() *Config { + // Just enable hitless upgrades, everything else gets defaults + return &Config{ + Enabled: MaintNotificationsEnabled, + // All other fields will get default values automatically + } +} diff --git a/hitless/hitless_manager.go b/hitless/hitless_manager.go new file mode 100644 index 0000000000..b22114cb98 --- /dev/null +++ b/hitless/hitless_manager.go @@ -0,0 +1,212 @@ +package hitless + +import ( + "context" + "fmt" + "net" + "sync" + "time" + + "github.com/redis/go-redis/v9/internal" + "github.com/redis/go-redis/v9/internal/interfaces" +) + +// MovingOperationKey provides a unique key for tracking MOVING operations +// that combines sequence ID with connection identifier to handle duplicate +// sequence IDs across multiple connections to the same node. +type MovingOperationKey struct { + SeqID int64 // Sequence ID from MOVING notification + ConnID uint64 // Unique connection identifier +} + +// String returns a string representation of the key for debugging +func (k MovingOperationKey) String() string { + return fmt.Sprintf("seq:%d-conn:%d", k.SeqID, k.ConnID) +} + +// HitlessManager provides a simplified hitless upgrade functionality. +type HitlessManager struct { + mu sync.RWMutex + + client interfaces.ClientInterface + config *Config + options interfaces.OptionsInterface + + // MOVING operation tracking with composite keys + activeMovingOps map[MovingOperationKey]*MovingOperation + + closed bool +} + +// MovingOperation tracks an active MOVING operation. +type MovingOperation struct { + SeqID int64 + NewEndpoint string + StartTime time.Time + Deadline time.Time +} + +// NewHitlessManager creates a new simplified hitless manager. +func NewHitlessManager(client interfaces.ClientInterface, config *Config) (*HitlessManager, error) { + if client == nil { + return nil, ErrInvalidClient + } + + hm := &HitlessManager{ + client: client, + options: client.GetOptions(), + config: config.Clone(), + activeMovingOps: make(map[MovingOperationKey]*MovingOperation), + } + + // Set up push notification handling + if err := hm.setupPushNotifications(); err != nil { + return nil, err + } + + return hm, nil +} + +// setupPushNotifications sets up push notification handling by registering with the client's processor. 
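To make the composite-key idea concrete, here is a minimal sketch (the endpoint, sequence ID, and connection IDs are illustrative, not taken from this change) showing that the same MOVING sequence ID received on two different connections to a node produces two distinct tracking entries rather than a collision:

package main

import (
	"fmt"
	"time"

	"github.com/redis/go-redis/v9/hitless"
)

func main() {
	now := time.Now()
	ops := map[hitless.MovingOperationKey]*hitless.MovingOperation{
		{SeqID: 42, ConnID: 1}: {SeqID: 42, NewEndpoint: "10.0.0.5:6379", StartTime: now, Deadline: now.Add(15 * time.Second)},
		{SeqID: 42, ConnID: 2}: {SeqID: 42, NewEndpoint: "10.0.0.5:6379", StartTime: now, Deadline: now.Add(15 * time.Second)},
	}
	for key := range ops {
		fmt.Println(key.String()) // "seq:42-conn:1" and "seq:42-conn:2" - two entries despite the shared SeqID
	}
}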
+func (hm *HitlessManager) setupPushNotifications() error { + processor := hm.client.GetPushProcessor() + if processor == nil { + return ErrInvalidClient // Client doesn't support push notifications + } + + // Create our notification handler + handler := &NotificationHandler{manager: hm} + + // Register handlers for all hitless upgrade notifications with the client's processor + if err := processor.RegisterHandler("MOVING", handler, true); err != nil { + return err + } + if err := processor.RegisterHandler("MIGRATING", handler, true); err != nil { + return err + } + if err := processor.RegisterHandler("MIGRATED", handler, true); err != nil { + return err + } + if err := processor.RegisterHandler("FAILING_OVER", handler, true); err != nil { + return err + } + if err := processor.RegisterHandler("FAILED_OVER", handler, true); err != nil { + return err + } + + return nil +} + +// StartMovingOperationWithConnID starts a new MOVING operation with a specific connection ID. +func (hm *HitlessManager) StartMovingOperationWithConnID(ctx context.Context, newEndpoint string, deadline time.Time, seqID int64, connID uint64) error { + hm.mu.Lock() + defer hm.mu.Unlock() + + // Create composite key + key := MovingOperationKey{ + SeqID: seqID, + ConnID: connID, + } + + // Check for duplicate operation + if _, exists := hm.activeMovingOps[key]; exists { + // Duplicate MOVING notification, ignore + internal.Logger.Printf(ctx, "Duplicate MOVING operation ignored: %s", key.String()) + return nil + } + + // Create MOVING operation record + movingOp := &MovingOperation{ + SeqID: seqID, + NewEndpoint: newEndpoint, + StartTime: time.Now(), + Deadline: deadline, + } + hm.activeMovingOps[key] = movingOp + + return nil +} + +// CompleteOperationWithConnID completes a MOVING operation with a specific connection ID. +func (hm *HitlessManager) CompleteOperationWithConnID(seqID int64, connID uint64) { + hm.mu.Lock() + defer hm.mu.Unlock() + + // Create composite key + key := MovingOperationKey{ + SeqID: seqID, + ConnID: connID, + } + + // Remove from active operations + delete(hm.activeMovingOps, key) +} + +// GetActiveMovingOperations returns active operations with composite keys. +func (hm *HitlessManager) GetActiveMovingOperations() map[MovingOperationKey]*MovingOperation { + hm.mu.RLock() + defer hm.mu.RUnlock() + + result := make(map[MovingOperationKey]*MovingOperation) + for key, op := range hm.activeMovingOps { + result[key] = &MovingOperation{ + SeqID: op.SeqID, + NewEndpoint: op.NewEndpoint, + StartTime: op.StartTime, + Deadline: op.Deadline, + } + } + return result +} + +// IsHandoffInProgress returns true if any handoff is in progress. +func (hm *HitlessManager) IsHandoffInProgress() bool { + hm.mu.RLock() + defer hm.mu.RUnlock() + return len(hm.activeMovingOps) > 0 +} + +// Close closes the hitless manager. +func (hm *HitlessManager) Close() error { + hm.mu.Lock() + defer hm.mu.Unlock() + + if hm.closed { + return nil + } + + hm.closed = true + return nil +} + +// GetState returns current state +func (hm *HitlessManager) GetState() State { + hm.mu.RLock() + defer hm.mu.RUnlock() + + if len(hm.activeMovingOps) > 0 { + return StateMoving + } + return StateIdle +} + +// GetConfig returns the hitless manager configuration. +func (hm *HitlessManager) GetConfig() *Config { + hm.mu.RLock() + defer hm.mu.RUnlock() + return hm.config.Clone() +} + +// CreateConnectionProcessor creates a connection processor with this manager already set. +// Returns the processor as the shared interface type. 
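A minimal sketch of how the tracking API above is intended to be used around a single handoff, assuming hm was obtained from NewHitlessManager; the fixed 15-second deadline stands in for the value that would be derived from the notification's timeS field, and the package name is hypothetical:

package hitlessexample // hypothetical package name, for illustration only

import (
	"context"
	"time"

	"github.com/redis/go-redis/v9/hitless"
)

// trackMove brackets a handoff with the manager's tracking calls.
func trackMove(ctx context.Context, hm *hitless.HitlessManager, endpoint string, seqID int64, connID uint64) error {
	deadline := time.Now().Add(15 * time.Second) // in practice derived from the MOVING notification
	if err := hm.StartMovingOperationWithConnID(ctx, endpoint, deadline, seqID, connID); err != nil {
		return err
	}
	// While the handoff is in flight, IsHandoffInProgress reports true and
	// GetActiveMovingOperations contains the seq/conn composite key.
	defer hm.CompleteOperationWithConnID(seqID, connID)

	// ... perform the actual connection handoff here ...
	return nil
}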
+func (hm *HitlessManager) CreateConnectionProcessor(protocol int, baseDialer func(context.Context, string, string) (net.Conn, error)) *RedisConnectionProcessor { + // Get pool size from client options for better worker defaults + poolSize := 0 + if hm.options != nil { + poolSize = hm.options.GetPoolSize() + } + + processor := NewRedisConnectionProcessorWithPoolSize(protocol, baseDialer, hm.config, hm, poolSize) + + return processor +} diff --git a/hitless/notification_handler.go b/hitless/notification_handler.go new file mode 100644 index 0000000000..2bc874c8d6 --- /dev/null +++ b/hitless/notification_handler.go @@ -0,0 +1,217 @@ +package hitless + +import ( + "context" + "fmt" + "strconv" + "time" + + "github.com/redis/go-redis/v9/internal/interfaces" + "github.com/redis/go-redis/v9/internal/pool" + "github.com/redis/go-redis/v9/push" +) + +// NotificationHandler handles push notifications for the simplified manager. +type NotificationHandler struct { + manager *HitlessManager +} + +// HandlePushNotification processes push notifications. +func (snh *NotificationHandler) HandlePushNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error { + if len(notification) == 0 { + return ErrInvalidNotification + } + + notificationType, ok := notification[0].(string) + if !ok { + return ErrInvalidNotification + } + + switch notificationType { + case "MOVING": + return snh.handleMoving(ctx, handlerCtx, notification) + case "MIGRATING": + return snh.handleMigrating(ctx, handlerCtx, notification) + case "MIGRATED": + return snh.handleMigrated(ctx, handlerCtx, notification) + case "FAILING_OVER": + return snh.handleFailingOver(ctx, handlerCtx, notification) + case "FAILED_OVER": + return snh.handleFailedOver(ctx, handlerCtx, notification) + default: + // Ignore other notification types (e.g., pub/sub messages) + return nil + } +} + +// handleMoving processes MOVING notifications. 
+// ["MOVING", seqNum, timeS, endpoint] - per-connection handoff +func (snh *NotificationHandler) handleMoving(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error { + if len(notification) < 3 { + return ErrInvalidNotification + } + seqIDStr, ok := notification[1].(string) + if !ok { + return ErrInvalidNotification + } + + seqID, err := strconv.ParseInt(seqIDStr, 10, 64) + if err != nil { + return ErrInvalidNotification + } + + // Extract timeS + timeSStr, ok := notification[2].(string) + if !ok { + return ErrInvalidNotification + } + + timeS, err := strconv.ParseInt(timeSStr, 10, 64) + if err != nil { + return ErrInvalidNotification + } + + newEndpoint := "" + ok = false + if len(notification) > 3 { + // Extract new endpoint + newEndpoint, ok = notification[3].(string) + if !ok { + return ErrInvalidNotification + } + } + + // Get the connection that received this notification + conn := handlerCtx.Conn + if conn == nil { + return ErrInvalidNotification + } + + // Type assert to get the underlying pool connection + var poolConn *pool.Conn + if connAdapter, ok := conn.(interface{ GetPoolConn() *pool.Conn }); ok { + poolConn = connAdapter.GetPoolConn() + } else if pc, ok := conn.(*pool.Conn); ok { + poolConn = pc + } else { + return ErrInvalidNotification + } + + // TODO(hitless): if newEndpoint is empty, we should schedule a handoff to the current endpoint in timeS/2 seconds + + // Mark the connection for handoff + if err := poolConn.MarkForHandoff(newEndpoint, seqID); err != nil { + // Connection is already marked for handoff, which is acceptable + // This can happen if multiple MOVING notifications are received for the same connection + return nil + } + + // Optionally track in hitless manager for monitoring/debugging + if snh.manager != nil { + connID := poolConn.GetID() + deadline := time.Now().Add(time.Duration(timeS) * time.Second) + + // Track the operation (ignore errors since this is optional) + _ = snh.manager.StartMovingOperationWithConnID(context.Background(), newEndpoint, deadline, seqID, connID) + } else { + return fmt.Errorf("hitless: manager not initialized") + } + + return nil +} + +// handleMigrating processes MIGRATING notifications. +func (snh *NotificationHandler) handleMigrating(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error { + // MIGRATING notifications indicate that a connection is about to be migrated + // Apply relaxed timeouts to the specific connection that received this notification + if len(notification) < 2 { + return ErrInvalidNotification + } + + // Get the connection from handler context and type assert to connectionAdapter + if handlerCtx.Conn == nil { + return ErrInvalidNotification + } + + // Type assert to connectionAdapter which implements ConnectionWithRelaxedTimeout + connAdapter, ok := handlerCtx.Conn.(interfaces.ConnectionWithRelaxedTimeout) + if !ok { + return ErrInvalidNotification + } + + // Apply relaxed timeout to this specific connection + connAdapter.SetRelaxedTimeout(snh.manager.config.RelaxedTimeout, snh.manager.config.RelaxedTimeout) + return nil +} + +// handleMigrated processes MIGRATED notifications. 
+func (snh *NotificationHandler) handleMigrated(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error { + // MIGRATED notifications indicate that a connection migration has completed + // Restore normal timeouts for the specific connection that received this notification + if len(notification) < 2 { + return ErrInvalidNotification + } + + // Get the connection from handler context and type assert to connectionAdapter + if handlerCtx.Conn == nil { + return ErrInvalidNotification + } + + // Type assert to connectionAdapter which implements ConnectionWithRelaxedTimeout + connAdapter, ok := handlerCtx.Conn.(interfaces.ConnectionWithRelaxedTimeout) + if !ok { + return ErrInvalidNotification + } + + // Clear relaxed timeout for this specific connection + connAdapter.ClearRelaxedTimeout() + return nil +} + +// handleFailingOver processes FAILING_OVER notifications. +func (snh *NotificationHandler) handleFailingOver(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error { + // FAILING_OVER notifications indicate that a connection is about to failover + // Apply relaxed timeouts to the specific connection that received this notification + if len(notification) < 2 { + return ErrInvalidNotification + } + + // Get the connection from handler context and type assert to connectionAdapter + if handlerCtx.Conn == nil { + return ErrInvalidNotification + } + + // Type assert to connectionAdapter which implements ConnectionWithRelaxedTimeout + connAdapter, ok := handlerCtx.Conn.(interfaces.ConnectionWithRelaxedTimeout) + if !ok { + return ErrInvalidNotification + } + + // Apply relaxed timeout to this specific connection + connAdapter.SetRelaxedTimeout(snh.manager.config.RelaxedTimeout, snh.manager.config.RelaxedTimeout) + return nil +} + +// handleFailedOver processes FAILED_OVER notifications. 
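handleFailedOver below completes the failover pair, which mirrors the migration pair above, so the per-connection effect of all four notifications can be summarised in one self-contained sketch; the local relaxable interface is a stand-in for interfaces.ConnectionWithRelaxedTimeout (which lives in an internal package), and the timeout value corresponds to Config.RelaxedTimeout:

package hitlessexample // hypothetical package name, for illustration only

import "time"

// relaxable captures the two methods the handlers rely on.
type relaxable interface {
	SetRelaxedTimeout(readTimeout, writeTimeout time.Duration)
	ClearRelaxedTimeout()
}

// applyNotification mirrors the symmetric handler behaviour: "start" events
// relax the receiving connection's timeouts, "done" events restore them.
func applyNotification(conn relaxable, name string, relaxed time.Duration) {
	switch name {
	case "MIGRATING", "FAILING_OVER":
		conn.SetRelaxedTimeout(relaxed, relaxed)
	case "MIGRATED", "FAILED_OVER":
		conn.ClearRelaxedTimeout()
	}
}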
+func (snh *NotificationHandler) handleFailedOver(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error { + // FAILED_OVER notifications indicate that a connection failover has completed + // Restore normal timeouts for the specific connection that received this notification + if len(notification) < 2 { + return ErrInvalidNotification + } + + // Get the connection from handler context and type assert to connectionAdapter + if handlerCtx.Conn == nil { + return ErrInvalidNotification + } + + // Type assert to connectionAdapter which implements ConnectionWithRelaxedTimeout + connAdapter, ok := handlerCtx.Conn.(interfaces.ConnectionWithRelaxedTimeout) + if !ok { + return ErrInvalidNotification + } + + // Clear relaxed timeout for this specific connection + connAdapter.ClearRelaxedTimeout() + return nil +} diff --git a/hitless/redis_connection_processor.go b/hitless/redis_connection_processor.go new file mode 100644 index 0000000000..b515805835 --- /dev/null +++ b/hitless/redis_connection_processor.go @@ -0,0 +1,624 @@ +package hitless + +import ( + "context" + "errors" + "fmt" + "net" + "sync" + "time" + + "github.com/redis/go-redis/v9/internal" + "github.com/redis/go-redis/v9/internal/pool" + "github.com/redis/go-redis/v9/internal/proto" +) + +// HitlessManagerInterface defines the interface for completing handoff operations +type HitlessManagerInterface interface { + CompleteOperationWithConnID(seqID int64, connID uint64) +} + +// HandoffRequest represents a request to handoff a connection to a new endpoint +type HandoffRequest struct { + Conn *pool.Conn + ConnID uint64 // Unique connection identifier + Endpoint string + SeqID int64 + Result chan HandoffResult + StopWorkerRequest bool + Pool pool.Pooler // Pool to remove connection from on failure +} + +// HandoffResult represents the result of a handoff operation +type HandoffResult struct { + Conn *pool.Conn + Err error +} + +// RedisConnectionProcessor implements interfaces.ConnectionProcessor for Redis-specific connection handling +// with hitless upgrade support. 
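A minimal wiring sketch for the processor defined below, assuming hm is a *hitless.HitlessManager built as shown earlier; the dialer and the 5-second timeouts are illustrative, and the SetPool call is left as a comment because pool.Pooler comes from an internal package:

package hitlessexample // hypothetical package name, for illustration only

import (
	"context"
	"net"
	"time"

	"github.com/redis/go-redis/v9/hitless"
)

func wireProcessor(hm *hitless.HitlessManager) error {
	dialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
		return net.DialTimeout(network, addr, 5*time.Second)
	}

	// RESP3 so push notifications (MOVING, MIGRATING, ...) can be received.
	processor := hm.CreateConnectionProcessor(3, dialer)
	// processor.SetPool(pooler) // pooler is a pool.Pooler; lets failed handoffs remove bad connections

	// ... run the client; queued handoff requests are drained by the worker goroutines ...

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	return processor.Shutdown(ctx)
}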
+type RedisConnectionProcessor struct { + // Protocol version (2 = RESP2, 3 = RESP3 with push notifications) + protocol int + + // Base dialer for creating connections to new endpoints during handoffs + baseDialer func(context.Context, string, string) (net.Conn, error) + + // Event-driven handoff support + handoffQueue chan HandoffRequest // Queue for handoff requests + shutdown chan struct{} // Shutdown signal + shutdownOnce sync.Once // Ensure clean shutdown + workerWg sync.WaitGroup // Track worker goroutines + + // Dynamic worker scaling + minWorkers int + maxWorkers int + currentWorkers int + scalingMu sync.Mutex + scaleLevel int // 0=min, 1=max + + // Scale down optimization + scaleDownTimer *time.Timer + scaleDownMu sync.Mutex + lastCompletionTime time.Time + scaleDownDelay time.Duration + + // Simple state tracking + pending sync.Map // map[uint64]int64 (connID -> seqID) + + // Configuration for the processor + config *Config + + // Hitless manager for operation completion tracking + hitlessManager HitlessManagerInterface + + // Pool interface for removing connections on handoff failure + pool pool.Pooler +} + +// NewRedisConnectionProcessor creates a new Redis connection processor +func NewRedisConnectionProcessor(protocol int, baseDialer func(context.Context, string, string) (net.Conn, error), config *Config, hitlessManager HitlessManagerInterface) *RedisConnectionProcessor { + return NewRedisConnectionProcessorWithPoolSize(protocol, baseDialer, config, hitlessManager, 0) +} + +// NewRedisConnectionProcessorWithPoolSize creates a new Redis connection processor with pool size for better worker defaults +func NewRedisConnectionProcessorWithPoolSize(protocol int, baseDialer func(context.Context, string, string) (net.Conn, error), config *Config, hitlessManager HitlessManagerInterface, poolSize int) *RedisConnectionProcessor { + // Apply defaults to any missing configuration fields, using pool size for worker calculations + config = config.ApplyDefaultsWithPoolSize(poolSize) + + rcp := &RedisConnectionProcessor{ + // Protocol version (2 = RESP2, 3 = RESP3 with push notifications) + protocol: protocol, + // baseDialer is used to create connections to new endpoints during handoffs + baseDialer: baseDialer, + // Note: CLIENT MAINT_NOTIFICATIONS is handled during client initialization + // handoffQueue is a buffered channel for queuing handoff requests + handoffQueue: make(chan HandoffRequest, config.HandoffQueueSize), + // shutdown is a channel for signaling shutdown + shutdown: make(chan struct{}), + minWorkers: config.MinWorkers, + maxWorkers: config.MaxWorkers, + // Start with minimum workers + currentWorkers: config.MinWorkers, + scaleLevel: 0, // Start at minimum + config: config, + // Hitless manager for operation completion tracking + hitlessManager: hitlessManager, + scaleDownDelay: config.ScaleDownDelay, + } + + // Start worker goroutines at minimum level + rcp.startWorkers(rcp.minWorkers) + + return rcp +} + +// SetPool sets the pool interface for removing connections on handoff failure +func (rcp *RedisConnectionProcessor) SetPool(pooler pool.Pooler) { + rcp.pool = pooler +} + +// GetCurrentWorkers returns the current number of workers (for testing) +func (rcp *RedisConnectionProcessor) GetCurrentWorkers() int { + rcp.scalingMu.Lock() + defer rcp.scalingMu.Unlock() + return rcp.currentWorkers +} + +// GetScaleLevel returns the current scale level (for testing) +func (rcp *RedisConnectionProcessor) GetScaleLevel() int { + rcp.scalingMu.Lock() + defer rcp.scalingMu.Unlock() + 
return rcp.scaleLevel +} + +// log logs a message if the log level is appropriate +func (rcp *RedisConnectionProcessor) log(level int, message string) { + if rcp.config.LogLevel >= level { + internal.Logger.Printf(context.Background(), message) + } +} + +// IsHandoffPending returns true if the given connection has a pending handoff +func (rcp *RedisConnectionProcessor) IsHandoffPending(conn *pool.Conn) bool { + _, pending := rcp.pending.Load(conn) + return pending +} + +// ProcessConnectionOnGet is called when a connection is retrieved from the pool +func (rcp *RedisConnectionProcessor) ProcessConnectionOnGet(ctx context.Context, conn interface{}) error { + cn, ok := conn.(*pool.Conn) + if !ok { + return fmt.Errorf("hitless: expected *pool.Conn, got %T", conn) + } + + // NOTE: There are two conditions to make sure we don't return a connection that should be handed off or is + // in a handoff state at the moment. + + // Check if connection is usable (not in a handoff state) + if !cn.IsUsable() { + return ErrConnectionMarkedForHandoff + } + + // Check if connection is marked for handoff, which means it will be queued for handoff on put. + if cn.ShouldHandoff() { + return ErrConnectionMarkedForHandoff + } + + // Note: CLIENT MAINT_NOTIFICATIONS command is sent during client initialization + // in redis.go, so no need to send it here per connection + + return nil +} + +// ProcessConnectionOnPut is called when a connection is returned to the pool +func (rcp *RedisConnectionProcessor) ProcessConnectionOnPut(ctx context.Context, conn interface{}) (shouldPool bool, shouldRemove bool, err error) { + cn, ok := conn.(*pool.Conn) + if !ok { + return false, true, fmt.Errorf("hitless: expected *pool.Conn, got %T", conn) + } + + if cn.HasBufferedData() { + // Check for buffered data that might be push notifications + // Check if this might be push notification data + if rcp.protocol == 3 { + // For RESP3, peek at the reply type to check if it's a push notification + if replyType, err := cn.PeekReplyTypeSafe(); err != nil || replyType != proto.RespPush { + // Not a push notification or error peeking, remove connection + internal.Logger.Printf(ctx, "Conn has unread data (not push notification), removing it") + return false, true, nil + } + // It's a push notification, allow pooling (client will handle it) + } else { + // For RESP2, any buffered data is unexpected + internal.Logger.Printf(ctx, "Conn has unread data, removing it") + return false, true, nil + } + } + + // first check if we should handoff for faster rejection + if cn.ShouldHandoff() { + // check pending handoff to not queue the same connection twice + _, hasPendingHandoff := rcp.pending.Load(cn) + if !hasPendingHandoff { + // Check for empty endpoint first (synchronous check) + if cn.GetHandoffEndpoint() == "" { + cn.ClearHandoffState() + } else { + if err := rcp.queueHandoff(cn); err != nil { + // Failed to queue handoff, remove the connection + internal.Logger.Printf(ctx, "Failed to queue handoff: %v", err) + return false, true, nil // Don't pool, remove connection, no error to caller + } + cn.MarkQueuedForHandoff() + return true, false, nil + } + } + } + // Default: pool the connection + return true, false, nil +} + +// startWorkers starts the worker goroutines for processing handoff requests +func (rcp *RedisConnectionProcessor) startWorkers(count int) { + for i := 0; i < count; i++ { + rcp.workerWg.Add(1) + go rcp.handoffWorker() + } +} + +// scaleUpWorkers scales up workers when queue is full (single step: min β†’ max) +func (rcp 
*RedisConnectionProcessor) scaleUpWorkers() { + rcp.scalingMu.Lock() + defer rcp.scalingMu.Unlock() + + if rcp.scaleLevel >= 1 { + return // Already at maximum scale + } + + previousWorkers := rcp.currentWorkers + targetWorkers := rcp.maxWorkers + + // Ensure we don't go below current workers + if targetWorkers <= rcp.currentWorkers { + return + } + + additionalWorkers := targetWorkers - rcp.currentWorkers + rcp.startWorkers(additionalWorkers) + rcp.currentWorkers = targetWorkers + rcp.scaleLevel = 1 + + if rcp.config != nil && rcp.config.LogLevel >= 2 { // Info level + internal.Logger.Printf(context.Background(), + "hitless: scaled up workers from %d to %d (max level) due to queue pressure", + previousWorkers, rcp.currentWorkers) + } +} + +// scaleDownWorkers returns to minimum worker count when queue is empty +func (rcp *RedisConnectionProcessor) scaleDownWorkers() { + rcp.scalingMu.Lock() + defer rcp.scalingMu.Unlock() + + if rcp.scaleLevel == 0 { + return // Already at minimum scale + } + + // Send stop worker requests to excess workers + excessWorkers := rcp.currentWorkers - rcp.minWorkers + previousWorkers := rcp.currentWorkers + + for i := 0; i < excessWorkers; i++ { + stopRequest := HandoffRequest{ + StopWorkerRequest: true, + } + + // Try to send stop request without blocking + select { + case rcp.handoffQueue <- stopRequest: + // Stop request sent successfully + default: + // Queue is full, worker will naturally exit when queue empties + break + } + } + + rcp.currentWorkers = rcp.minWorkers + rcp.scaleLevel = 0 + + if rcp.config != nil && rcp.config.LogLevel >= 2 { // Info level + internal.Logger.Printf(context.Background(), + "hitless: scaling down workers from %d to %d (sent %d stop requests)", + previousWorkers, rcp.minWorkers, excessWorkers) + } +} + +// queueHandoffWithTimeout attempts to queue a handoff request with timeout and scaling +func (rcp *RedisConnectionProcessor) queueHandoffWithTimeout(request HandoffRequest, cn *pool.Conn) { + // First attempt - try immediate queuing + select { + case rcp.handoffQueue <- request: + return + case <-rcp.shutdown: + rcp.pending.Delete(cn) + return + default: + // Queue is full - log and attempt scaling + if rcp.config != nil && rcp.config.LogLevel >= 1 { // Warning level + internal.Logger.Printf(context.Background(), + "hitless: handoff queue is full (%d/%d), attempting timeout queuing and scaling workers", + len(rcp.handoffQueue), cap(rcp.handoffQueue)) + } + + // Scale up workers to handle the load + rcp.scaleUpWorkers() + } + + // TODO: reimplement? extract as config? 
+ // Second attempt - try queuing with timeout of 2 seconds + timeout := time.NewTimer(2 * time.Second) + defer timeout.Stop() + + select { + case rcp.handoffQueue <- request: + // Queued successfully after timeout + if rcp.config != nil && rcp.config.LogLevel >= 2 { // Info level + internal.Logger.Printf(context.Background(), + "hitless: handoff queued successfully after scaling workers") + } + return + case <-timeout.C: + // Timeout expired - drop the connection + err := errors.New("handoff queue timeout after 5 seconds") + rcp.pending.Delete(cn) + if rcp.config != nil && rcp.config.LogLevel >= 0 { // Error level + internal.Logger.Printf(context.Background(), err.Error()) + } + return + case <-rcp.shutdown: + rcp.pending.Delete(cn) + return + } +} + +// scheduleScaleDownCheck schedules a scale down check after a delay +// This is called after completing a handoff request to avoid expensive immediate checks +func (rcp *RedisConnectionProcessor) scheduleScaleDownCheck() { + rcp.scaleDownMu.Lock() + defer rcp.scaleDownMu.Unlock() + + // Update last completion time + rcp.lastCompletionTime = time.Now() + + // If timer already exists, reset it + if rcp.scaleDownTimer != nil { + rcp.scaleDownTimer.Reset(rcp.scaleDownDelay) + return + } + + // Create new timer + rcp.scaleDownTimer = time.AfterFunc(rcp.scaleDownDelay, func() { + rcp.performScaleDownCheck() + }) +} + +// performScaleDownCheck performs the actual scale down check +// This runs in a background goroutine after the delay +func (rcp *RedisConnectionProcessor) performScaleDownCheck() { + rcp.scaleDownMu.Lock() + defer rcp.scaleDownMu.Unlock() + + // Clear the timer since it has fired + rcp.scaleDownTimer = nil + + // Check if we should scale down + if rcp.shouldScaleDown() { + rcp.scaleDownWorkers() + } +} + +// shouldScaleDown checks if conditions are met for scaling down +// This is the expensive check that we want to minimize +func (rcp *RedisConnectionProcessor) shouldScaleDown() bool { + // Quick check: if we're already at minimum scale, no need to scale down + if rcp.scaleLevel == 0 { + return false + } + + // Quick check: if queue is not empty, don't scale down + if len(rcp.handoffQueue) > 0 { + return false + } + + // Expensive check: count pending handoffs + pendingCount := 0 + rcp.pending.Range(func(key, value interface{}) bool { + pendingCount++ + return pendingCount < 5 // Early exit if we find several pending + }) + + // Only scale down if no pending handoffs + return pendingCount == 0 +} + +// handoffWorker processes handoff requests from the queue +func (rcp *RedisConnectionProcessor) handoffWorker() { + defer rcp.workerWg.Done() + + for { + select { + case request := <-rcp.handoffQueue: + // Check if this is a stop worker request + if request.StopWorkerRequest { + if rcp.config != nil && rcp.config.LogLevel >= 2 { // Info level + internal.Logger.Printf(context.Background(), + "hitless: worker received stop request, exiting") + } + return // Exit this worker + } + + rcp.processHandoffRequest(request) + case <-rcp.shutdown: + return + } + } +} + +// processHandoffRequest processes a single handoff request +func (rcp *RedisConnectionProcessor) processHandoffRequest(request HandoffRequest) { + // Safety check: ignore stop worker requests (should be handled in worker) + if request.StopWorkerRequest { + return + } + + // Remove from pending map + defer rcp.pending.Delete(request.Conn) + + // Perform the handoff + err := rcp.performConnectionHandoffWithPool(context.Background(), request.Conn, request.Pool) + + // If handoff 
failed, restore the handoff state for potential retry + if err != nil { + request.Conn.RestoreHandoffState() + internal.Logger.Printf(context.Background(), "Handoff failed for connection WILL RETRY: %v", err) + } + + // Schedule a scale down check after completing this handoff request + // This avoids expensive immediate checks and prevents rapid scaling cycles + rcp.scheduleScaleDownCheck() +} + +// queueHandoff queues a handoff request for processing +// if err is returned, connection will be removed from pool +func (rcp *RedisConnectionProcessor) queueHandoff(cn *pool.Conn) error { + // Create handoff request + request := HandoffRequest{ + Conn: cn, + ConnID: cn.GetID(), + Endpoint: cn.GetHandoffEndpoint(), + SeqID: cn.GetMovingSeqID(), + Pool: rcp.pool, // Include pool for connection removal on failure + } + + // Store in pending map + rcp.pending.Store(request.ConnID, request.SeqID) + + go rcp.queueHandoffWithTimeout(request, cn) + return nil +} + +// performConnectionHandoffWithPool performs the actual connection handoff with pool for connection removal on failure +// if err is returned, connection will be removed from pool +func (rcp *RedisConnectionProcessor) performConnectionHandoffWithPool(ctx context.Context, cn *pool.Conn, pooler pool.Pooler) error { + // Clear handoff state after successful handoff + seqID := cn.GetMovingSeqID() + connID := cn.GetID() + + // Notify hitless manager of completion if available + if rcp.hitlessManager != nil { + defer rcp.hitlessManager.CompleteOperationWithConnID(seqID, connID) + } + + newEndpoint := cn.GetHandoffEndpoint() + if newEndpoint == "" { + // TODO(hitless): maybe auto? + // Handle by performing the handoff to the current endpoint in N seconds, + // Where N is the time in the moving notification... + // Won't work for now! 
+ cn.ClearHandoffState() + return nil + } + + retries := cn.IncrementAndGetHandoffRetries(1) + if retries > 3 { + if rcp.config != nil && rcp.config.LogLevel >= 1 { // Warning level + internal.Logger.Printf(ctx, + "hitless: reached max retries (%d) for handoff of connection %d to %s", + retries, cn.GetID(), cn.GetHandoffEndpoint()) + } + err := ErrMaxHandoffRetriesReached + if pooler != nil { + pooler.Remove(ctx, cn, err) + } else { + cn.Close() + internal.Logger.Printf(ctx, + "hitless: no pool provided for connection %d, cannot remove due to handoff initialization failure: %v", + cn.GetID(), err) + } + return err + } + + // Create endpoint-specific dialer + endpointDialer := rcp.createEndpointDialer(newEndpoint) + + // Create new connection to the new endpoint + newNetConn, err := endpointDialer(ctx) + if err != nil { + // TODO(hitless): requeue the handoff request + // This is the only case where we should retry the handoff request + return err + } + + // Get the old connection + oldConn := cn.GetNetConn() + + // Replace the connection and execute initialization + err = cn.SetNetConnWithInitConn(ctx, newNetConn) + if err != nil { + // Remove the connection from the pool since it's in a bad state + if pooler != nil { + // Use pool.Pooler interface directly - no adapter needed + pooler.Remove(ctx, cn, err) + if rcp.config != nil && rcp.config.LogLevel >= 1 { // Warning level + internal.Logger.Printf(ctx, + "hitless: removed connection %d from pool due to handoff initialization failure: %v", + cn.GetID(), err) + } + } else { + cn.Close() + internal.Logger.Printf(ctx, + "hitless: no pool provided for connection %d, cannot remove due to handoff initialization failure: %v", + cn.GetID(), err) + } + + // Keep the handoff state for retry + return err + } + // Note: CLIENT MAINT_NOTIFICATIONS is sent during client initialization, not per connection + defer func() { + if oldConn != nil { + oldConn.Close() + } + }() + + cn.ClearHandoffState() + + // Apply relaxed timeout to the new connection for the configured post-handoff duration + // This gives the new connection more time to handle operations during cluster transition + if rcp.config != nil && rcp.config.PostHandoffRelaxedDuration > 0 { + relaxedTimeout := rcp.config.RelaxedTimeout + postHandoffDuration := rcp.config.PostHandoffRelaxedDuration + + // Set relaxed timeout with deadline - no background goroutine needed + deadline := time.Now().Add(postHandoffDuration) + cn.SetRelaxedTimeoutWithDeadline(relaxedTimeout, relaxedTimeout, deadline) + + if rcp.config.LogLevel >= 2 { // Info level + internal.Logger.Printf(context.Background(), + "hitless: applied post-handoff relaxed timeout (%v) until %v for connection %d", + relaxedTimeout, deadline.Format("15:04:05.000"), connID) + } + } + + return nil +} + +// createEndpointDialer creates a dialer function that connects to a specific endpoint +func (rcp *RedisConnectionProcessor) createEndpointDialer(endpoint string) func(context.Context) (net.Conn, error) { + return func(ctx context.Context) (net.Conn, error) { + // Parse endpoint to extract host and port + host, port, err := net.SplitHostPort(endpoint) + if err != nil { + // If no port specified, assume default Redis port + host = endpoint + port = "6379" + } + + // Use the base dialer to connect to the new endpoint + return rcp.baseDialer(ctx, "tcp", net.JoinHostPort(host, port)) + } +} + +// Shutdown gracefully shuts down the processor, waiting for workers to complete +func (rcp *RedisConnectionProcessor) Shutdown(ctx context.Context) error { + 
rcp.shutdownOnce.Do(func() { + close(rcp.shutdown) + + // Clean up scale down timer + rcp.scaleDownMu.Lock() + if rcp.scaleDownTimer != nil { + rcp.scaleDownTimer.Stop() + rcp.scaleDownTimer = nil + } + rcp.scaleDownMu.Unlock() + }) + + // Wait for workers to complete + done := make(chan struct{}) + go func() { + rcp.workerWg.Wait() + close(done) + }() + + select { + case <-done: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +// ErrConnectionMarkedForHandoff is returned when a connection is marked for handoff +// and should not be used until the handoff is complete +var ErrConnectionMarkedForHandoff = errors.New("connection marked for handoff") diff --git a/hitless/redis_connection_processor_test.go b/hitless/redis_connection_processor_test.go new file mode 100644 index 0000000000..f922381383 --- /dev/null +++ b/hitless/redis_connection_processor_test.go @@ -0,0 +1,879 @@ +package hitless + +import ( + "context" + "errors" + "fmt" + "net" + "sync" + "testing" + "time" + + "github.com/redis/go-redis/v9/internal/pool" +) + +// mockNetConn implements net.Conn for testing +type mockNetConn struct { + addr string + shouldFailInit bool +} + +func (m *mockNetConn) Read(b []byte) (n int, err error) { return 0, nil } +func (m *mockNetConn) Write(b []byte) (n int, err error) { return len(b), nil } +func (m *mockNetConn) Close() error { return nil } +func (m *mockNetConn) LocalAddr() net.Addr { return &mockAddr{m.addr} } +func (m *mockNetConn) RemoteAddr() net.Addr { return &mockAddr{m.addr} } +func (m *mockNetConn) SetDeadline(t time.Time) error { return nil } +func (m *mockNetConn) SetReadDeadline(t time.Time) error { return nil } +func (m *mockNetConn) SetWriteDeadline(t time.Time) error { return nil } + +type mockAddr struct { + addr string +} + +func (m *mockAddr) Network() string { return "tcp" } +func (m *mockAddr) String() string { return m.addr } + +// createMockPoolConnection creates a mock pool connection for testing +func createMockPoolConnection() *pool.Conn { + mockNetConn := &mockNetConn{addr: "test:6379"} + conn := pool.NewConn(mockNetConn) + conn.SetUsable(true) // Make connection usable for testing + return conn +} + +// mockPool implements pool.Pooler for testing +type mockPool struct { + removedConnections map[uint64]bool + mu sync.Mutex +} + +func (mp *mockPool) NewConn(ctx context.Context) (*pool.Conn, error) { + return nil, errors.New("not implemented") +} + +func (mp *mockPool) CloseConn(conn *pool.Conn) error { + return nil +} + +func (mp *mockPool) Get(ctx context.Context) (*pool.Conn, error) { + return nil, errors.New("not implemented") +} + +func (mp *mockPool) Put(ctx context.Context, conn *pool.Conn) { + // Not implemented for testing +} + +func (mp *mockPool) Remove(ctx context.Context, conn *pool.Conn, reason error) { + mp.mu.Lock() + defer mp.mu.Unlock() + + // Use pool.Conn directly - no adapter needed + mp.removedConnections[conn.GetID()] = true +} + +// WasRemoved safely checks if a connection was removed from the pool +func (mp *mockPool) WasRemoved(connID uint64) bool { + mp.mu.Lock() + defer mp.mu.Unlock() + return mp.removedConnections[connID] +} + +func (mp *mockPool) Len() int { + return 0 +} + +func (mp *mockPool) IdleLen() int { + return 0 +} + +func (mp *mockPool) Stats() *pool.Stats { + return &pool.Stats{} +} + +func (mp *mockPool) Close() error { + return nil +} + +// TestRedisConnectionProcessor tests the Redis connection processor functionality +func TestRedisConnectionProcessor(t *testing.T) { + // Create a base dialer for testing + 
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { + return &mockNetConn{addr: addr}, nil + } + + t.Run("SuccessfulEventDrivenHandoff", func(t *testing.T) { + config := &Config{ + Enabled: MaintNotificationsAuto, + EndpointType: EndpointTypeAuto, + MinWorkers: 1, + MaxWorkers: 2, + HandoffQueueSize: 10, // Explicit queue size to avoid 0-size queue + LogLevel: 2, + } + processor := NewRedisConnectionProcessor(3, baseDialer, config, nil) + defer processor.Shutdown(context.Background()) + + conn := createMockPoolConnection() + if err := conn.MarkForHandoff("new-endpoint:6379", 12345); err != nil { + t.Fatalf("Failed to mark connection for handoff: %v", err) + } + + // Set a mock initialization function with synchronization + initConnCalled := make(chan bool, 1) + initConnFunc := func(ctx context.Context, cn *pool.Conn) error { + select { + case initConnCalled <- true: + default: + } + return nil + } + conn.SetInitConnFunc(initConnFunc) + + ctx := context.Background() + shouldPool, shouldRemove, err := processor.ProcessConnectionOnPut(ctx, conn) + if err != nil { + t.Errorf("ProcessConnectionOnPut should not error: %v", err) + } + + // Should pool the connection immediately (handoff queued) + if !shouldPool { + t.Error("Connection should be pooled immediately with event-driven handoff") + } + if shouldRemove { + t.Error("Connection should not be removed when queuing handoff") + } + + // Connection should be in pending map + if _, pending := processor.pending.Load(conn); !pending { + t.Error("Connection should be in pending handoffs map") + } + + // Wait for initialization to be called (indicates handoff started) + select { + case <-initConnCalled: + // Good, initialization was called + case <-time.After(1 * time.Second): + t.Fatal("Timeout waiting for initialization function to be called") + } + + // Wait for handoff to complete with proper timeout and polling + timeout := time.After(2 * time.Second) + ticker := time.NewTicker(10 * time.Millisecond) + defer ticker.Stop() + + handoffCompleted := false + for !handoffCompleted { + select { + case <-timeout: + t.Fatal("Timeout waiting for handoff to complete") + case <-ticker.C: + if _, pending := processor.pending.Load(conn); !pending { + handoffCompleted = true + } + } + } + + // Verify handoff completed (removed from pending map) + if _, pending := processor.pending.Load(conn); pending { + t.Error("Connection should be removed from pending map after handoff") + } + + // Verify connection is usable again + if !conn.IsUsable() { + t.Error("Connection should be usable after successful handoff") + } + + // Verify handoff state is cleared + if conn.ShouldHandoff() { + t.Error("Connection should not be marked for handoff after completion") + } + }) + + t.Run("HandoffNotNeeded", func(t *testing.T) { + processor := NewRedisConnectionProcessor(3, baseDialer, nil, nil) + conn := createMockPoolConnection() + // Don't mark for handoff + + ctx := context.Background() + shouldPool, shouldRemove, err := processor.ProcessConnectionOnPut(ctx, conn) + if err != nil { + t.Errorf("ProcessConnectionOnPut should not error when handoff not needed: %v", err) + } + + // Should pool the connection normally + if !shouldPool { + t.Error("Connection should be pooled when no handoff needed") + } + if shouldRemove { + t.Error("Connection should not be removed when no handoff needed") + } + }) + + t.Run("EmptyEndpoint", func(t *testing.T) { + processor := NewRedisConnectionProcessor(3, baseDialer, nil, nil) + conn := createMockPoolConnection() + 
if err := conn.MarkForHandoff("", 12345); err != nil { // Empty endpoint + t.Fatalf("Failed to mark connection for handoff: %v", err) + } + + ctx := context.Background() + shouldPool, shouldRemove, err := processor.ProcessConnectionOnPut(ctx, conn) + if err != nil { + t.Errorf("ProcessConnectionOnPut should not error with empty endpoint: %v", err) + } + + // Should pool the connection (empty endpoint clears state) + if !shouldPool { + t.Error("Connection should be pooled after clearing empty endpoint") + } + if shouldRemove { + t.Error("Connection should not be removed after clearing empty endpoint") + } + + // State should be cleared + if conn.ShouldHandoff() { + t.Error("Connection should not be marked for handoff after clearing empty endpoint") + } + }) + + t.Run("EventDrivenHandoffDialerError", func(t *testing.T) { + // Create a failing base dialer + failingDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { + return nil, errors.New("dial failed") + } + + config := &Config{ + Enabled: MaintNotificationsAuto, + EndpointType: EndpointTypeAuto, + MinWorkers: 1, + MaxWorkers: 2, + HandoffQueueSize: 10, // Explicit queue size to avoid 0-size queue + LogLevel: 2, + } + processor := NewRedisConnectionProcessor(3, failingDialer, config, nil) + defer processor.Shutdown(context.Background()) + + conn := createMockPoolConnection() + if err := conn.MarkForHandoff("new-endpoint:6379", 12345); err != nil { + t.Fatalf("Failed to mark connection for handoff: %v", err) + } + + ctx := context.Background() + shouldPool, shouldRemove, err := processor.ProcessConnectionOnPut(ctx, conn) + if err != nil { + t.Errorf("ProcessConnectionOnPut should not return error to caller: %v", err) + } + + // Should pool the connection initially (handoff queued) + if !shouldPool { + t.Error("Connection should be pooled initially with event-driven handoff") + } + if shouldRemove { + t.Error("Connection should not be removed when queuing handoff") + } + + // Wait for handoff to complete and fail with proper timeout and polling + timeout := time.After(2 * time.Second) + ticker := time.NewTicker(10 * time.Millisecond) + defer ticker.Stop() + + handoffCompleted := false + for !handoffCompleted { + select { + case <-timeout: + t.Fatal("Timeout waiting for failed handoff to complete") + case <-ticker.C: + if _, pending := processor.pending.Load(conn); !pending { + handoffCompleted = true + } + } + } + + // Connection should be removed from pending map after failed handoff + if _, pending := processor.pending.Load(conn); pending { + t.Error("Connection should be removed from pending map after failed handoff") + } + + // Handoff state should still be set (since handoff failed) + if !conn.ShouldHandoff() { + t.Error("Connection should still be marked for handoff after failed handoff") + } + }) + + t.Run("BufferedDataRESP2", func(t *testing.T) { + processor := NewRedisConnectionProcessor(2, baseDialer, nil, nil) + conn := createMockPoolConnection() + + // For this test, we'll just verify the logic works for connections without buffered data + // The actual buffered data detection is handled by the pool's connection health check + // which is outside the scope of the Redis connection processor + + ctx := context.Background() + shouldPool, shouldRemove, err := processor.ProcessConnectionOnPut(ctx, conn) + if err != nil { + t.Errorf("ProcessConnectionOnPut should not error: %v", err) + } + + // Should pool the connection normally (no buffered data in mock) + if !shouldPool { + t.Error("Connection should be 
pooled when no buffered data") + } + if shouldRemove { + t.Error("Connection should not be removed when no buffered data") + } + }) + + t.Run("ProcessConnectionOnGet", func(t *testing.T) { + processor := NewRedisConnectionProcessor(3, baseDialer, nil, nil) + conn := createMockPoolConnection() + + ctx := context.Background() + err := processor.ProcessConnectionOnGet(ctx, conn) + if err != nil { + t.Errorf("ProcessConnectionOnGet should not error for normal connection: %v", err) + } + }) + + t.Run("ProcessConnectionOnGetWithPendingHandoff", func(t *testing.T) { + config := &Config{ + Enabled: MaintNotificationsAuto, + EndpointType: EndpointTypeAuto, + MinWorkers: 1, + MaxWorkers: 2, + HandoffQueueSize: 10, // Explicit queue size to avoid 0-size queue + LogLevel: 2, + } + processor := NewRedisConnectionProcessor(3, baseDialer, config, nil) + defer processor.Shutdown(context.Background()) + + conn := createMockPoolConnection() + + // Simulate a pending handoff by marking for handoff and queuing + conn.MarkForHandoff("new-endpoint:6379", 12345) + resultChan := make(chan HandoffResult, 1) + processor.pending.Store(conn, resultChan) + conn.MarkQueuedForHandoff() // Mark as queued (sets usable=false) + + ctx := context.Background() + err := processor.ProcessConnectionOnGet(ctx, conn) + if err != ErrConnectionMarkedForHandoff { + t.Errorf("Expected ErrConnectionMarkedForHandoff, got %v", err) + } + + // Clean up + processor.pending.Delete(conn) + }) + + t.Run("EventDrivenStateManagement", func(t *testing.T) { + processor := NewRedisConnectionProcessor(3, baseDialer, nil, nil) + defer processor.Shutdown(context.Background()) + + conn := createMockPoolConnection() + + // Test initial state - no pending handoffs + if _, pending := processor.pending.Load(conn); pending { + t.Error("New connection should not have pending handoffs") + } + + // Test adding to pending map + conn.MarkForHandoff("new-endpoint:6379", 12345) + resultChan := make(chan HandoffResult, 1) + processor.pending.Store(conn, resultChan) + conn.MarkQueuedForHandoff() // Mark as queued (sets usable=false) + + if _, pending := processor.pending.Load(conn); !pending { + t.Error("Connection should be in pending map") + } + + // Test ProcessConnectionOnGet with pending handoff + ctx := context.Background() + err := processor.ProcessConnectionOnGet(ctx, conn) + if err != ErrConnectionMarkedForHandoff { + t.Error("Should return ErrConnectionMarkedForHandoff for pending connection") + } + + // Test removing from pending map and clearing handoff state + processor.pending.Delete(conn) + if _, pending := processor.pending.Load(conn); pending { + t.Error("Connection should be removed from pending map") + } + + // Clear handoff state to simulate completed handoff + conn.ClearHandoffState() + conn.SetUsable(true) // Make connection usable again + + // Test ProcessConnectionOnGet without pending handoff + err = processor.ProcessConnectionOnGet(ctx, conn) + if err != nil { + t.Errorf("Should not return error for non-pending connection: %v", err) + } + }) + + t.Run("EventDrivenQueueOptimization", func(t *testing.T) { + // Create processor with small queue to test optimization features + config := &Config{ + MinWorkers: 1, + MaxWorkers: 3, + HandoffQueueSize: 2, // Small queue to trigger optimizations + LogLevel: 3, // Debug level to see optimization logs + } + + baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { + // Add small delay to simulate network latency + time.Sleep(10 * time.Millisecond) + return &mockNetConn{addr: 
addr}, nil + } + + processor := NewRedisConnectionProcessor(3, baseDialer, config, nil) + defer processor.Shutdown(context.Background()) + + // Create multiple connections that need handoff to fill the queue + connections := make([]*pool.Conn, 5) + for i := 0; i < 5; i++ { + connections[i] = createMockPoolConnection() + if err := connections[i].MarkForHandoff("new-endpoint:6379", int64(i)); err != nil { + t.Fatalf("Failed to mark connection %d for handoff: %v", i, err) + } + // Set a mock initialization function + connections[i].SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error { + return nil + }) + } + + ctx := context.Background() + successCount := 0 + + // Process connections - should trigger scaling and timeout logic + for _, conn := range connections { + shouldPool, shouldRemove, err := processor.ProcessConnectionOnPut(ctx, conn) + if err != nil { + t.Logf("ProcessConnectionOnPut returned error (expected with timeout): %v", err) + } + + if shouldPool && !shouldRemove { + successCount++ + } + } + + // With timeout and scaling, most handoffs should eventually succeed + if successCount == 0 { + t.Error("Should have queued some handoffs with timeout and scaling") + } + + t.Logf("Successfully queued %d handoffs with optimization features", successCount) + + // Give time for workers to process and scaling to occur + time.Sleep(100 * time.Millisecond) + }) + + t.Run("WorkerScalingBehavior", func(t *testing.T) { + // Create processor with small queue to test scaling behavior + config := &Config{ + MinWorkers: 1, + MaxWorkers: 4, + HandoffQueueSize: 1, // Very small queue to force scaling + LogLevel: 2, // Info level to see scaling logs + } + + processor := NewRedisConnectionProcessor(3, baseDialer, config, nil) + defer processor.Shutdown(context.Background()) + + // Verify initial worker count and scaling level + if processor.currentWorkers != 1 { + t.Errorf("Expected 1 initial worker, got %d", processor.currentWorkers) + } + if processor.scaleLevel != 0 { + t.Errorf("Processor should be at scale level 0 initially, got %d", processor.scaleLevel) + } + if processor.minWorkers != 1 { + t.Errorf("Expected minWorkers=1, got %d", processor.minWorkers) + } + if processor.maxWorkers != 4 { + t.Errorf("Expected maxWorkers=4, got %d", processor.maxWorkers) + } + + // The scaling behavior is tested in other tests (ScaleDownDelayBehavior) + // This test just verifies the basic configuration is correct + t.Logf("Worker scaling configuration verified - Min: %d, Max: %d, Current: %d", + processor.minWorkers, processor.maxWorkers, processor.currentWorkers) + }) + + t.Run("PassiveTimeoutRestoration", func(t *testing.T) { + // Create processor with fast post-handoff duration for testing + config := &Config{ + MinWorkers: 1, + MaxWorkers: 2, + HandoffQueueSize: 10, + PostHandoffRelaxedDuration: 100 * time.Millisecond, // Fast expiration for testing + RelaxedTimeout: 5 * time.Second, + LogLevel: 2, + } + + processor := NewRedisConnectionProcessor(3, baseDialer, config, nil) + defer processor.Shutdown(context.Background()) + + ctx := context.Background() + + // Create a connection and trigger handoff + conn := createMockPoolConnection() + if err := conn.MarkForHandoff("new-endpoint:6379", 1); err != nil { + t.Fatalf("Failed to mark connection for handoff: %v", err) + } + + // Set a mock initialization function + conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error { + return nil + }) + + // Process the connection to trigger handoff + shouldPool, shouldRemove, err := 
processor.ProcessConnectionOnPut(ctx, conn) + if err != nil { + t.Errorf("Handoff should succeed: %v", err) + } + if !shouldPool || shouldRemove { + t.Error("Connection should be pooled after handoff") + } + + // Wait for handoff to complete with proper timeout and polling + timeout := time.After(1 * time.Second) + ticker := time.NewTicker(5 * time.Millisecond) + defer ticker.Stop() + + handoffCompleted := false + for !handoffCompleted { + select { + case <-timeout: + t.Fatal("Timeout waiting for handoff to complete") + case <-ticker.C: + if _, pending := processor.pending.Load(conn); !pending { + handoffCompleted = true + } + } + } + + // Verify relaxed timeout is set with deadline + if !conn.HasRelaxedTimeout() { + t.Error("Connection should have relaxed timeout after handoff") + } + + // Test that timeout is still active before deadline + // We'll use HasRelaxedTimeout which internally checks the deadline + if !conn.HasRelaxedTimeout() { + t.Error("Connection should still have active relaxed timeout before deadline") + } + + // Wait for deadline to pass + time.Sleep(150 * time.Millisecond) // 100ms deadline + buffer + + // Test that timeout is automatically restored after deadline + // HasRelaxedTimeout should return false after deadline passes + if conn.HasRelaxedTimeout() { + t.Error("Connection should not have active relaxed timeout after deadline") + } + + // Additional verification: calling HasRelaxedTimeout again should still return false + // and should have cleared the internal timeout values + if conn.HasRelaxedTimeout() { + t.Error("Connection should not have relaxed timeout after deadline (second check)") + } + + t.Logf("Passive timeout restoration test completed successfully") + }) + + t.Run("UsableFlagBehavior", func(t *testing.T) { + config := &Config{ + MinWorkers: 1, + MaxWorkers: 2, + HandoffQueueSize: 10, + LogLevel: 2, + } + + processor := NewRedisConnectionProcessor(3, baseDialer, config, nil) + defer processor.Shutdown(context.Background()) + + ctx := context.Background() + + // Create a new connection without setting it usable + mockNetConn := &mockNetConn{addr: "test:6379"} + conn := pool.NewConn(mockNetConn) + + // Initially, connection should not be usable (not initialized) + if conn.IsUsable() { + t.Error("New connection should not be usable before initialization") + } + + // Simulate initialization by setting usable to true + conn.SetUsable(true) + if !conn.IsUsable() { + t.Error("Connection should be usable after initialization") + } + + // ProcessConnectionOnGet should succeed for usable connection + err := processor.ProcessConnectionOnGet(ctx, conn) + if err != nil { + t.Errorf("ProcessConnectionOnGet should succeed for usable connection: %v", err) + } + + // Mark connection for handoff + if err := conn.MarkForHandoff("new-endpoint:6379", 1); err != nil { + t.Fatalf("Failed to mark connection for handoff: %v", err) + } + + // Set a mock initialization function + conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error { + return nil + }) + + // Connection should still be usable until queued, but marked for handoff + if !conn.IsUsable() { + t.Error("Connection should still be usable after being marked for handoff (until queued)") + } + if !conn.ShouldHandoff() { + t.Error("Connection should be marked for handoff") + } + + // ProcessConnectionOnGet should fail for connection marked for handoff + err = processor.ProcessConnectionOnGet(ctx, conn) + if err == nil { + t.Error("ProcessConnectionOnGet should fail for connection marked for handoff") + } + 
if err != ErrConnectionMarkedForHandoff { + t.Errorf("Expected ErrConnectionMarkedForHandoff, got %v", err) + } + + // Process the connection to trigger handoff + shouldPool, shouldRemove, err := processor.ProcessConnectionOnPut(ctx, conn) + if err != nil { + t.Errorf("ProcessConnectionOnPut should succeed: %v", err) + } + if !shouldPool || shouldRemove { + t.Error("Connection should be pooled after handoff") + } + + // Wait for handoff to complete + time.Sleep(50 * time.Millisecond) + + // After handoff completion, connection should be usable again + if !conn.IsUsable() { + t.Error("Connection should be usable after handoff completion") + } + + // ProcessConnectionOnGet should succeed again + err = processor.ProcessConnectionOnGet(ctx, conn) + if err != nil { + t.Errorf("ProcessConnectionOnGet should succeed after handoff completion: %v", err) + } + + t.Logf("Usable flag behavior test completed successfully") + }) + + t.Run("StaticQueueBehavior", func(t *testing.T) { + config := &Config{ + MinWorkers: 1, + MaxWorkers: 3, + HandoffQueueSize: 50, // Explicit static queue size + LogLevel: 2, + } + + processor := NewRedisConnectionProcessorWithPoolSize(3, baseDialer, config, nil, 100) // Pool size: 100 + defer processor.Shutdown(context.Background()) + + // Verify queue capacity matches configured size + queueCapacity := cap(processor.handoffQueue) + if queueCapacity != 50 { + t.Errorf("Expected queue capacity 50, got %d", queueCapacity) + } + + // Test that queue size is static regardless of pool size + // (No dynamic resizing should occur) + + ctx := context.Background() + + // Fill part of the queue + for i := 0; i < 10; i++ { + conn := createMockPoolConnection() + if err := conn.MarkForHandoff("new-endpoint:6379", int64(i+1)); err != nil { + t.Fatalf("Failed to mark connection %d for handoff: %v", i, err) + } + // Set a mock initialization function + conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error { + return nil + }) + + shouldPool, shouldRemove, err := processor.ProcessConnectionOnPut(ctx, conn) + if err != nil { + t.Errorf("Failed to queue handoff %d: %v", i, err) + } + if !shouldPool || shouldRemove { + t.Errorf("Connection %d should be pooled after handoff", i) + } + } + + // Verify queue capacity remains static (the main purpose of this test) + finalCapacity := cap(processor.handoffQueue) + if finalCapacity != 50 { + t.Errorf("Queue capacity should remain static at 50, got %d", finalCapacity) + } + + // Note: We don't check queue size here because workers process items quickly + // The important thing is that the capacity remains static regardless of pool size + currentQueueSize := len(processor.handoffQueue) + t.Logf("Static queue test completed - Capacity: %d, Current size: %d", + finalCapacity, currentQueueSize) + }) + + t.Run("ConnectionRemovalOnHandoffFailure", func(t *testing.T) { + // Create a failing dialer that will cause handoff initialization to fail + failingDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { + // Return a connection that will fail during initialization + return &mockNetConn{addr: addr, shouldFailInit: true}, nil + } + + config := &Config{ + MinWorkers: 1, + MaxWorkers: 2, + HandoffQueueSize: 10, + LogLevel: 2, + } + + processor := NewRedisConnectionProcessor(3, failingDialer, config, nil) + defer processor.Shutdown(context.Background()) + + // Create a mock pool that tracks removals + mockPool := &mockPool{removedConnections: make(map[uint64]bool)} + processor.SetPool(mockPool) + + ctx := context.Background() 
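+	// What follows (a summary of this subtest): mark the connection for handoff,
+	// install an init function that always fails, then verify through the tracking
+	// mockPool that the processor removes the connection once the handoff attempt fails.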
+ + // Create a connection and mark it for handoff + conn := createMockPoolConnection() + if err := conn.MarkForHandoff("new-endpoint:6379", 1); err != nil { + t.Fatalf("Failed to mark connection for handoff: %v", err) + } + + // Set a failing initialization function + conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error { + return fmt.Errorf("initialization failed") + }) + + // Process the connection - handoff should fail and connection should be removed + shouldPool, shouldRemove, err := processor.ProcessConnectionOnPut(ctx, conn) + if err != nil { + t.Errorf("ProcessConnectionOnPut should not error: %v", err) + } + if !shouldPool || shouldRemove { + t.Error("Connection should be pooled after failed handoff attempt") + } + + // Wait for handoff to be attempted and fail + time.Sleep(100 * time.Millisecond) + + // Verify that the connection was removed from the pool + if !mockPool.WasRemoved(conn.GetID()) { + t.Errorf("Connection %d should have been removed from pool after handoff failure", conn.GetID()) + } + + t.Logf("Connection removal on handoff failure test completed successfully") + }) + + t.Run("PostHandoffRelaxedTimeout", func(t *testing.T) { + // Create config with short post-handoff duration for testing + config := &Config{ + MinWorkers: 1, + MaxWorkers: 2, + HandoffQueueSize: 10, + RelaxedTimeout: 5 * time.Second, + PostHandoffRelaxedDuration: 100 * time.Millisecond, // Short for testing + } + + baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { + return &mockNetConn{addr: addr}, nil + } + + processor := NewRedisConnectionProcessor(3, baseDialer, config, nil) + defer processor.Shutdown(context.Background()) + + conn := createMockPoolConnection() + if err := conn.MarkForHandoff("new-endpoint:6379", 12345); err != nil { + t.Fatalf("Failed to mark connection for handoff: %v", err) + } + + // Set a mock initialization function + conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error { + return nil + }) + + ctx := context.Background() + shouldPool, shouldRemove, err := processor.ProcessConnectionOnPut(ctx, conn) + + if err != nil { + t.Fatalf("ProcessConnectionOnPut failed: %v", err) + } + + if !shouldPool { + t.Error("Connection should be pooled after successful handoff") + } + + if shouldRemove { + t.Error("Connection should not be removed after successful handoff") + } + + // Wait for the handoff to complete (it happens asynchronously) + timeout := time.After(1 * time.Second) + ticker := time.NewTicker(5 * time.Millisecond) + defer ticker.Stop() + + handoffCompleted := false + for !handoffCompleted { + select { + case <-timeout: + t.Fatal("Timeout waiting for handoff to complete") + case <-ticker.C: + if _, pending := processor.pending.Load(conn); !pending { + handoffCompleted = true + } + } + } + + // Verify that relaxed timeout was applied to the new connection + if !conn.HasRelaxedTimeout() { + t.Error("New connection should have relaxed timeout applied after handoff") + } + + // Wait for the post-handoff duration to expire + time.Sleep(150 * time.Millisecond) // Slightly longer than PostHandoffRelaxedDuration + + // Verify that relaxed timeout was automatically cleared + if conn.HasRelaxedTimeout() { + t.Error("Relaxed timeout should be automatically cleared after post-handoff duration") + } + }) + + t.Run("MarkForHandoff returns error when already marked", func(t *testing.T) { + conn := createMockPoolConnection() + + // First mark should succeed + if err := conn.MarkForHandoff("new-endpoint:6379", 1); err != nil { + 
t.Fatalf("First MarkForHandoff should succeed: %v", err) + } + + // Second mark should fail + if err := conn.MarkForHandoff("another-endpoint:6379", 2); err == nil { + t.Fatal("Second MarkForHandoff should return error") + } else if err.Error() != "connection is already marked for handoff" { + t.Fatalf("Expected specific error message, got: %v", err) + } + + // Verify original handoff data is preserved + if !conn.ShouldHandoff() { + t.Fatal("Connection should still be marked for handoff") + } + if conn.GetHandoffEndpoint() != "new-endpoint:6379" { + t.Fatalf("Expected original endpoint, got: %s", conn.GetHandoffEndpoint()) + } + if conn.GetMovingSeqID() != 1 { + t.Fatalf("Expected original sequence ID, got: %d", conn.GetMovingSeqID()) + } + }) +} diff --git a/hitless/state.go b/hitless/state.go new file mode 100644 index 0000000000..109d939fc0 --- /dev/null +++ b/hitless/state.go @@ -0,0 +1,24 @@ +package hitless + +// State represents the current state of a hitless upgrade operation. +type State int + +const ( + // StateIdle indicates no upgrade is in progress + StateIdle State = iota + + // StateHandoff indicates a connection handoff is in progress + StateMoving +) + +// String returns a string representation of the state. +func (s State) String() string { + switch s { + case StateIdle: + return "idle" + case StateMoving: + return "moving" + default: + return "unknown" + } +} diff --git a/hset_benchmark_test.go b/hset_benchmark_test.go deleted file mode 100644 index df16343555..0000000000 --- a/hset_benchmark_test.go +++ /dev/null @@ -1,245 +0,0 @@ -package redis_test - -import ( - "context" - "fmt" - "testing" - "time" - - "github.com/redis/go-redis/v9" -) - -// HSET Benchmark Tests -// -// This file contains benchmark tests for Redis HSET operations with different scales: -// 1, 10, 100, 1000, 10000, 100000 operations -// -// Prerequisites: -// - Redis server running on localhost:6379 -// - No authentication required -// -// Usage: -// go test -bench=BenchmarkHSET -v ./hset_benchmark_test.go -// go test -bench=BenchmarkHSETPipelined -v ./hset_benchmark_test.go -// go test -bench=. -v ./hset_benchmark_test.go # Run all benchmarks -// -// Example output: -// BenchmarkHSET/HSET_1_operations-8 5000 250000 ns/op 1000000.00 ops/sec -// BenchmarkHSET/HSET_100_operations-8 100 10000000 ns/op 100000.00 ops/sec -// -// The benchmarks test three different approaches: -// 1. Individual HSET commands (BenchmarkHSET) -// 2. 
Pipelined HSET commands (BenchmarkHSETPipelined) - -// BenchmarkHSET benchmarks HSET operations with different scales -func BenchmarkHSET(b *testing.B) { - ctx := context.Background() - - // Setup Redis client - rdb := redis.NewClient(&redis.Options{ - Addr: "localhost:6379", - DB: 0, - }) - defer rdb.Close() - - // Test connection - if err := rdb.Ping(ctx).Err(); err != nil { - b.Skipf("Redis server not available: %v", err) - } - - // Clean up before and after tests - defer func() { - rdb.FlushDB(ctx) - }() - - scales := []int{1, 10, 100, 1000, 10000, 100000} - - for _, scale := range scales { - b.Run(fmt.Sprintf("HSET_%d_operations", scale), func(b *testing.B) { - benchmarkHSETOperations(b, rdb, ctx, scale) - }) - } -} - -// benchmarkHSETOperations performs the actual HSET benchmark for a given scale -func benchmarkHSETOperations(b *testing.B, rdb *redis.Client, ctx context.Context, operations int) { - hashKey := fmt.Sprintf("benchmark_hash_%d", operations) - - b.ResetTimer() - b.StartTimer() - totalTimes := []time.Duration{} - - for i := 0; i < b.N; i++ { - b.StopTimer() - // Clean up the hash before each iteration - rdb.Del(ctx, hashKey) - b.StartTimer() - - startTime := time.Now() - // Perform the specified number of HSET operations - for j := 0; j < operations; j++ { - field := fmt.Sprintf("field_%d", j) - value := fmt.Sprintf("value_%d", j) - - err := rdb.HSet(ctx, hashKey, field, value).Err() - if err != nil { - b.Fatalf("HSET operation failed: %v", err) - } - } - totalTimes = append(totalTimes, time.Now().Sub(startTime)) - } - - // Stop the timer to calculate metrics - b.StopTimer() - - // Report operations per second - opsPerSec := float64(operations*b.N) / b.Elapsed().Seconds() - b.ReportMetric(opsPerSec, "ops/sec") - - // Report average time per operation - avgTimePerOp := b.Elapsed().Nanoseconds() / int64(operations*b.N) - b.ReportMetric(float64(avgTimePerOp), "ns/op") - // report average time in milliseconds from totalTimes - avgTimePerOpMs := totalTimes[0].Milliseconds() / int64(len(totalTimes)) - b.ReportMetric(float64(avgTimePerOpMs), "ms") -} - -// BenchmarkHSETPipelined benchmarks HSET operations using pipelining for better performance -func BenchmarkHSETPipelined(b *testing.B) { - ctx := context.Background() - - // Setup Redis client - rdb := redis.NewClient(&redis.Options{ - Addr: "localhost:6379", - DB: 0, - }) - defer rdb.Close() - - // Test connection - if err := rdb.Ping(ctx).Err(); err != nil { - b.Skipf("Redis server not available: %v", err) - } - - // Clean up before and after tests - defer func() { - rdb.FlushDB(ctx) - }() - - scales := []int{1, 10, 100, 1000, 10000, 100000} - - for _, scale := range scales { - b.Run(fmt.Sprintf("HSET_Pipelined_%d_operations", scale), func(b *testing.B) { - benchmarkHSETPipelined(b, rdb, ctx, scale) - }) - } -} - -// benchmarkHSETPipelined performs HSET benchmark using pipelining -func benchmarkHSETPipelined(b *testing.B, rdb *redis.Client, ctx context.Context, operations int) { - hashKey := fmt.Sprintf("benchmark_hash_pipelined_%d", operations) - - b.ResetTimer() - b.StartTimer() - totalTimes := []time.Duration{} - - for i := 0; i < b.N; i++ { - b.StopTimer() - // Clean up the hash before each iteration - rdb.Del(ctx, hashKey) - b.StartTimer() - - startTime := time.Now() - // Use pipelining for better performance - pipe := rdb.Pipeline() - - // Add all HSET operations to the pipeline - for j := 0; j < operations; j++ { - field := fmt.Sprintf("field_%d", j) - value := fmt.Sprintf("value_%d", j) - pipe.HSet(ctx, hashKey, field, 
value) - } - - // Execute all operations at once - _, err := pipe.Exec(ctx) - if err != nil { - b.Fatalf("Pipeline execution failed: %v", err) - } - totalTimes = append(totalTimes, time.Now().Sub(startTime)) - } - - b.StopTimer() - - // Report operations per second - opsPerSec := float64(operations*b.N) / b.Elapsed().Seconds() - b.ReportMetric(opsPerSec, "ops/sec") - - // Report average time per operation - avgTimePerOp := b.Elapsed().Nanoseconds() / int64(operations*b.N) - b.ReportMetric(float64(avgTimePerOp), "ns/op") - // report average time in milliseconds from totalTimes - avgTimePerOpMs := totalTimes[0].Milliseconds() / int64(len(totalTimes)) - b.ReportMetric(float64(avgTimePerOpMs), "ms") -} - -// add same tests but with RESP2 -func BenchmarkHSET_RESP2(b *testing.B) { - ctx := context.Background() - - // Setup Redis client - rdb := redis.NewClient(&redis.Options{ - Addr: "localhost:6379", - Password: "", // no password docs - DB: 0, // use default DB - Protocol: 2, - }) - defer rdb.Close() - - // Test connection - if err := rdb.Ping(ctx).Err(); err != nil { - b.Skipf("Redis server not available: %v", err) - } - - // Clean up before and after tests - defer func() { - rdb.FlushDB(ctx) - }() - - scales := []int{1, 10, 100, 1000, 10000, 100000} - - for _, scale := range scales { - b.Run(fmt.Sprintf("HSET_RESP2_%d_operations", scale), func(b *testing.B) { - benchmarkHSETOperations(b, rdb, ctx, scale) - }) - } -} - -func BenchmarkHSETPipelined_RESP2(b *testing.B) { - ctx := context.Background() - - // Setup Redis client - rdb := redis.NewClient(&redis.Options{ - Addr: "localhost:6379", - Password: "", // no password docs - DB: 0, // use default DB - Protocol: 2, - }) - defer rdb.Close() - - // Test connection - if err := rdb.Ping(ctx).Err(); err != nil { - b.Skipf("Redis server not available: %v", err) - } - - // Clean up before and after tests - defer func() { - rdb.FlushDB(ctx) - }() - - scales := []int{1, 10, 100, 1000, 10000, 100000} - - for _, scale := range scales { - b.Run(fmt.Sprintf("HSET_Pipelined_RESP2_%d_operations", scale), func(b *testing.B) { - benchmarkHSETPipelined(b, rdb, ctx, scale) - }) - } -} diff --git a/internal/interfaces/interfaces.go b/internal/interfaces/interfaces.go new file mode 100644 index 0000000000..e78a07948f --- /dev/null +++ b/internal/interfaces/interfaces.go @@ -0,0 +1,79 @@ +// Package interfaces provides shared interfaces used by both the main redis package +// and the hitless upgrade package to avoid circular dependencies. +package interfaces + +import ( + "context" + "net" + "time" +) + +// Forward declaration to avoid circular imports +type NotificationProcessor interface { + RegisterHandler(pushNotificationName string, handler interface{}, protected bool) error + UnregisterHandler(pushNotificationName string) error + GetHandler(pushNotificationName string) interface{} +} + +// ClientInterface defines the interface that clients must implement for hitless upgrades. +type ClientInterface interface { + // GetOptions returns the client options. + GetOptions() OptionsInterface + + // GetPushProcessor returns the client's push notification processor. + GetPushProcessor() NotificationProcessor +} + +// OptionsInterface defines the interface for client options. +type OptionsInterface interface { + // GetReadTimeout returns the read timeout. + GetReadTimeout() time.Duration + + // GetWriteTimeout returns the write timeout. + GetWriteTimeout() time.Duration + + // GetAddr returns the connection address. 
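+	// The address is expected to be in "host:port" form (for example "localhost:6379").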
+ GetAddr() string + + // IsTLSEnabled returns true if TLS is enabled. + IsTLSEnabled() bool + + // GetProtocol returns the protocol version. + GetProtocol() int + + // GetPoolSize returns the connection pool size. + GetPoolSize() int + + // NewDialer returns a new dialer function for the connection. + NewDialer() func(context.Context) (net.Conn, error) +} + +// ConnectionWithRelaxedTimeout defines the interface for connections that support relaxed timeout adjustment. +// This is used by the hitless upgrade system for per-connection timeout management. +type ConnectionWithRelaxedTimeout interface { + // SetRelaxedTimeout sets relaxed timeouts for this connection during hitless upgrades. + // These timeouts remain active until explicitly cleared. + SetRelaxedTimeout(readTimeout, writeTimeout time.Duration) + + // SetRelaxedTimeoutWithDeadline sets relaxed timeouts with an expiration deadline. + // After the deadline, timeouts automatically revert to normal values. + SetRelaxedTimeoutWithDeadline(readTimeout, writeTimeout time.Duration, deadline time.Time) + + // ClearRelaxedTimeout clears relaxed timeouts for this connection. + ClearRelaxedTimeout() +} + +// ConnectionProcessor defines the interface for processing connections in the pool. +// This allows different implementations (e.g., hitless upgrade processors) to be plugged in. +type ConnectionProcessor interface { + // ProcessConnectionOnGet is called when a connection is retrieved from the pool. + // It can modify the connection or return an error to prevent its use. + ProcessConnectionOnGet(ctx context.Context, conn interface{}) error + + // ProcessConnectionOnPut is called when a connection is returned to the pool. + // It returns whether the connection should be pooled and whether it should be removed. + ProcessConnectionOnPut(ctx context.Context, conn interface{}) (shouldPool bool, shouldRemove bool, err error) + + // Shutdown gracefully shuts down the processor. 
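+	// Implementations are expected to honor ctx for cancellation and deadlines
+	// while draining any in-flight work.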
+ Shutdown(ctx context.Context) error +} diff --git a/internal/pool/conn.go b/internal/pool/conn.go index edef9e6743..d3df1bbe2b 100644 --- a/internal/pool/conn.go +++ b/internal/pool/conn.go @@ -3,7 +3,10 @@ package pool import ( "bufio" "context" + "errors" + "fmt" "net" + "sync" "sync/atomic" "time" @@ -12,17 +15,63 @@ import ( var noDeadline = time.Time{} +// Global atomic counter for connection IDs +var connIDCounter uint64 + +// atomicNetConn is a wrapper to ensure consistent typing in atomic.Value +type atomicNetConn struct { + conn net.Conn +} + +// generateConnID generates a fast unique identifier for a connection with zero allocations +func generateConnID() uint64 { + return atomic.AddUint64(&connIDCounter, 1) +} + type Conn struct { - usedAt int64 // atomic - netConn net.Conn + usedAt int64 // atomic + + // Lock-free netConn access using atomic.Value + // Contains *atomicNetConn wrapper, accessed atomically for better performance + netConnAtomic atomic.Value // stores *atomicNetConn rd *proto.Reader bw *bufio.Writer wr *proto.Writer + // Lightweight mutex to protect reader operations during handoff + // Only used for the brief period during SetNetConn and HasBufferedData/PeekReplyTypeSafe + readerMu sync.RWMutex + Inited bool pooled bool createdAt time.Time + expiresAt time.Time + + // Hitless upgrade support: relaxed timeouts during migrations/failovers + // Using atomic operations for lock-free access to avoid mutex contention + relaxedReadTimeoutNs atomic.Int64 // time.Duration as nanoseconds + relaxedWriteTimeoutNs atomic.Int64 // time.Duration as nanoseconds + relaxedDeadlineNs atomic.Int64 // time.Time as nanoseconds since epoch + + // Counter to track multiple relaxed timeout setters if we have nested calls + // will be decremented when ClearRelaxedTimeout is called or deadline is reached + // if counter reaches 0, we clear the relaxed timeouts + relaxedCounter atomic.Int32 + + // Connection initialization function for reconnections + initConnFunc func(context.Context, *Conn) error + + // Connection identifier for unique tracking across handoffs + id uint64 // Unique numeric identifier for this connection + + // Handoff state - using atomic operations for lock-free access + usableAtomic atomic.Bool // Connection usability state + shouldHandoffAtomic atomic.Bool // Whether connection should be handed off + movingSeqIDAtomic atomic.Int64 // Sequence ID from MOVING notification + handoffRetriesAtomic atomic.Uint32 // Retry counter for handoff attempts + // newEndpointAtomic needs special handling as it's a string + newEndpointAtomic atomic.Value // stores string onClose func() error } @@ -33,8 +82,8 @@ func NewConn(netConn net.Conn) *Conn { func NewConnWithBufferSize(netConn net.Conn, readBufSize, writeBufSize int) *Conn { cn := &Conn{ - netConn: netConn, createdAt: time.Now(), + id: generateConnID(), // Generate unique ID for this connection } // Use specified buffer sizes, or fall back to 0.5MiB defaults if 0 @@ -50,6 +99,16 @@ func NewConnWithBufferSize(netConn net.Conn, readBufSize, writeBufSize int) *Con cn.bw = bufio.NewWriterSize(netConn, proto.DefaultBufferSize) } + // Store netConn atomically for lock-free access using wrapper + cn.netConnAtomic.Store(&atomicNetConn{conn: netConn}) + + // Initialize atomic handoff state + cn.usableAtomic.Store(false) // false initially, set to true after initialization + cn.shouldHandoffAtomic.Store(false) // false initially + cn.movingSeqIDAtomic.Store(0) // 0 initially + cn.handoffRetriesAtomic.Store(0) // 0 initially + 
cn.newEndpointAtomic.Store("") // empty string initially + cn.wr = proto.NewWriter(cn.bw) cn.SetUsedAt(time.Now()) return cn @@ -64,23 +123,381 @@ func (cn *Conn) SetUsedAt(tm time.Time) { atomic.StoreInt64(&cn.usedAt, tm.Unix()) } +// getNetConn returns the current network connection using atomic load (lock-free). +// This is the fast path for accessing netConn without mutex overhead. +func (cn *Conn) getNetConn() net.Conn { + if v := cn.netConnAtomic.Load(); v != nil { + if wrapper, ok := v.(*atomicNetConn); ok { + return wrapper.conn + } + } + return nil +} + +// setNetConn stores the network connection atomically (lock-free). +// This is used for the fast path of connection replacement. +func (cn *Conn) setNetConn(netConn net.Conn) { + cn.netConnAtomic.Store(&atomicNetConn{conn: netConn}) +} + +// Lock-free helper methods for handoff state management + +// isUsable returns true if the connection is safe to use (lock-free). +func (cn *Conn) isUsable() bool { + return cn.usableAtomic.Load() +} + +// setUsable sets the usable flag atomically (lock-free). +func (cn *Conn) setUsable(usable bool) { + cn.usableAtomic.Store(usable) +} + +// shouldHandoff returns true if connection needs handoff (lock-free). +func (cn *Conn) shouldHandoff() bool { + return cn.shouldHandoffAtomic.Load() +} + +// setShouldHandoff sets the handoff flag atomically (lock-free). +func (cn *Conn) setShouldHandoff(should bool) { + cn.shouldHandoffAtomic.Store(should) +} + +// getMovingSeqID returns the sequence ID atomically (lock-free). +func (cn *Conn) getMovingSeqID() int64 { + return cn.movingSeqIDAtomic.Load() +} + +// setMovingSeqID sets the sequence ID atomically (lock-free). +func (cn *Conn) setMovingSeqID(seqID int64) { + cn.movingSeqIDAtomic.Store(seqID) +} + +// getNewEndpoint returns the new endpoint atomically (lock-free). +func (cn *Conn) getNewEndpoint() string { + if endpoint := cn.newEndpointAtomic.Load(); endpoint != nil { + return endpoint.(string) + } + return "" +} + +// setNewEndpoint sets the new endpoint atomically (lock-free). +func (cn *Conn) setNewEndpoint(endpoint string) { + cn.newEndpointAtomic.Store(endpoint) +} + +// setHandoffRetries sets the retry count atomically (lock-free). +func (cn *Conn) setHandoffRetries(retries int) { + cn.handoffRetriesAtomic.Store(uint32(retries)) +} + +// incrementHandoffRetries atomically increments and returns the new retry count (lock-free). +func (cn *Conn) incrementHandoffRetries(delta int) int { + return int(cn.handoffRetriesAtomic.Add(uint32(delta))) +} + +// IsUsable returns true if the connection is safe to use for new commands (lock-free). +func (cn *Conn) IsUsable() bool { + return cn.isUsable() +} + +// SetUsable sets the usable flag for the connection (lock-free). +func (cn *Conn) SetUsable(usable bool) { + cn.setUsable(usable) +} + +// SetRelaxedTimeout sets relaxed timeouts for this connection during hitless upgrades. +// These timeouts will be used for all subsequent commands until the deadline expires. +// Uses atomic operations for lock-free access. +func (cn *Conn) SetRelaxedTimeout(readTimeout, writeTimeout time.Duration) { + cn.relaxedCounter.Add(1) + cn.relaxedReadTimeoutNs.Store(int64(readTimeout)) + cn.relaxedWriteTimeoutNs.Store(int64(writeTimeout)) +} + +// SetRelaxedTimeoutWithDeadline sets relaxed timeouts with an expiration deadline. +// After the deadline, timeouts automatically revert to normal values. +// Uses atomic operations for lock-free access. 
+func (cn *Conn) SetRelaxedTimeoutWithDeadline(readTimeout, writeTimeout time.Duration, deadline time.Time) { + cn.relaxedCounter.Add(1) + cn.relaxedReadTimeoutNs.Store(int64(readTimeout)) + cn.relaxedWriteTimeoutNs.Store(int64(writeTimeout)) + cn.relaxedDeadlineNs.Store(deadline.UnixNano()) +} + +// ClearRelaxedTimeout removes relaxed timeouts, returning to normal timeout behavior. +// Uses atomic operations for lock-free access. +func (cn *Conn) ClearRelaxedTimeout() { + // Atomically decrement counter and check if we should clear + newCount := cn.relaxedCounter.Add(-1) + if newCount <= 0 { + // Use compare-and-swap to ensure only one goroutine clears + if cn.relaxedCounter.CompareAndSwap(newCount, 0) { + cn.clearRelaxedTimeout() + } + } +} + +func (cn *Conn) clearRelaxedTimeout() { + cn.relaxedReadTimeoutNs.Store(0) + cn.relaxedWriteTimeoutNs.Store(0) + cn.relaxedDeadlineNs.Store(0) + cn.relaxedCounter.Store(0) +} + +// HasRelaxedTimeout returns true if relaxed timeouts are currently active on this connection. +// This checks both the timeout values and the deadline (if set). +// Uses atomic operations for lock-free access. +func (cn *Conn) HasRelaxedTimeout() bool { + // Fast path: no relaxed timeouts are set + if cn.relaxedCounter.Load() <= 0 { + return false + } + + readTimeoutNs := cn.relaxedReadTimeoutNs.Load() + writeTimeoutNs := cn.relaxedWriteTimeoutNs.Load() + + // If no relaxed timeouts are set, return false + if readTimeoutNs <= 0 && writeTimeoutNs <= 0 { + return false + } + + deadlineNs := cn.relaxedDeadlineNs.Load() + // If no deadline is set, relaxed timeouts are active + if deadlineNs == 0 { + return true + } + + // If deadline is set, check if it's still in the future + return time.Now().UnixNano() < deadlineNs +} + +// getEffectiveReadTimeout returns the timeout to use for read operations. +// If relaxed timeout is set and not expired, it takes precedence over the provided timeout. +// This method automatically clears expired relaxed timeouts using atomic operations. +func (cn *Conn) getEffectiveReadTimeout(normalTimeout time.Duration) time.Duration { + readTimeoutNs := cn.relaxedReadTimeoutNs.Load() + + // Fast path: no relaxed timeout set + if readTimeoutNs <= 0 { + return normalTimeout + } + + deadlineNs := cn.relaxedDeadlineNs.Load() + // If no deadline is set, use relaxed timeout + if deadlineNs == 0 { + return time.Duration(readTimeoutNs) + } + + nowNs := time.Now().UnixNano() + // Check if deadline has passed + if nowNs < deadlineNs { + // Deadline is in the future, use relaxed timeout + return time.Duration(readTimeoutNs) + } else { + // Deadline has passed, clear relaxed timeouts atomically and use normal timeout + cn.relaxedCounter.Add(-1) + if cn.relaxedCounter.Load() <= 0 { + cn.clearRelaxedTimeout() + } + return normalTimeout + } +} + +// getEffectiveWriteTimeout returns the timeout to use for write operations. +// If relaxed timeout is set and not expired, it takes precedence over the provided timeout. +// This method automatically clears expired relaxed timeouts using atomic operations. 
+func (cn *Conn) getEffectiveWriteTimeout(normalTimeout time.Duration) time.Duration { + writeTimeoutNs := cn.relaxedWriteTimeoutNs.Load() + + // Fast path: no relaxed timeout set + if writeTimeoutNs <= 0 { + return normalTimeout + } + + deadlineNs := cn.relaxedDeadlineNs.Load() + // If no deadline is set, use relaxed timeout + if deadlineNs == 0 { + return time.Duration(writeTimeoutNs) + } + + nowNs := time.Now().UnixNano() + // Check if deadline has passed + if nowNs < deadlineNs { + // Deadline is in the future, use relaxed timeout + return time.Duration(writeTimeoutNs) + } else { + // Deadline has passed, clear relaxed timeouts atomically and use normal timeout + cn.relaxedCounter.Add(-1) + if cn.relaxedCounter.Load() <= 0 { + cn.clearRelaxedTimeout() + } + return normalTimeout + } +} + func (cn *Conn) SetOnClose(fn func() error) { cn.onClose = fn } +// SetInitConnFunc sets the connection initialization function to be called on reconnections. +func (cn *Conn) SetInitConnFunc(fn func(context.Context, *Conn) error) { + cn.initConnFunc = fn +} + +// ExecuteInitConn runs the stored connection initialization function if available. +func (cn *Conn) ExecuteInitConn(ctx context.Context) error { + if cn.initConnFunc != nil { + if err := cn.initConnFunc(ctx, cn); err != nil { + return err + } + cn.Inited = true + cn.setUsable(true) // Use atomic operation + return nil + } + return fmt.Errorf("redis: no initConnFunc set for connection %d", cn.GetID()) +} + func (cn *Conn) SetNetConn(netConn net.Conn) { - cn.netConn = netConn + // Store the new connection atomically first (lock-free) + cn.setNetConn(netConn) + // Clear relaxed timeouts when connection is replaced + cn.clearRelaxedTimeout() + + // Protect reader reset operations to avoid data races + // Use write lock since we're modifying the reader state + cn.readerMu.Lock() cn.rd.Reset(netConn) + cn.readerMu.Unlock() + cn.bw.Reset(netConn) } +// GetNetConn safely returns the current network connection using atomic load (lock-free). +// This method is used by the pool for health checks and provides better performance. +func (cn *Conn) GetNetConn() net.Conn { + return cn.getNetConn() +} + +// SetNetConnWithInitConn replaces the underlying connection and executes the initialization. +func (cn *Conn) SetNetConnWithInitConn(ctx context.Context, netConn net.Conn) error { + // New connection is not initialized yet + cn.Inited = false + // Replace the underlying connection + cn.SetNetConn(netConn) + return cn.ExecuteInitConn(ctx) +} + +// MarkForHandoff marks the connection for handoff due to MOVING notification (lock-free). +// Returns an error if the connection is already marked for handoff. +func (cn *Conn) MarkForHandoff(newEndpoint string, seqID int64) error { + // Use single atomic CAS operation for state transition + if !cn.shouldHandoffAtomic.CompareAndSwap(false, true) { + return errors.New("connection is already marked for handoff") + } + + cn.setNewEndpoint(newEndpoint) + cn.setMovingSeqID(seqID) + return nil +} + +func (cn *Conn) MarkQueuedForHandoff() error { + // Use single atomic CAS operation for state transition + if !cn.shouldHandoffAtomic.CompareAndSwap(true, false) { + return errors.New("connection was not marked for handoff") + } + cn.setUsable(false) + return nil +} + +// RestoreHandoffState restores the handoff state after a failed handoff (lock-free). 
+func (cn *Conn) RestoreHandoffState() { + // Restore shouldHandoff flag for retry + cn.shouldHandoffAtomic.Store(true) + // Keep usable=false to prevent the connection from being used until handoff succeeds + cn.setUsable(false) +} + +// ShouldHandoff returns true if the connection needs to be handed off (lock-free). +func (cn *Conn) ShouldHandoff() bool { + return cn.shouldHandoff() +} + +// GetHandoffEndpoint returns the new endpoint for handoff (lock-free). +func (cn *Conn) GetHandoffEndpoint() string { + return cn.getNewEndpoint() +} + +// GetMovingSeqID returns the sequence ID from the MOVING notification (lock-free). +func (cn *Conn) GetMovingSeqID() int64 { + return cn.getMovingSeqID() +} + +// GetID returns the unique identifier for this connection. +func (cn *Conn) GetID() uint64 { + return cn.id +} + +// ClearHandoffState clears the handoff state after successful handoff (lock-free). +func (cn *Conn) ClearHandoffState() { + // clear handoff state + cn.setShouldHandoff(false) + cn.setNewEndpoint("") + cn.setMovingSeqID(0) + cn.setHandoffRetries(0) + cn.setUsable(true) // Connection is safe to use again after handoff completes +} + +// IncrementAndGetHandoffRetries atomically increments and returns handoff retries (lock-free). +func (cn *Conn) IncrementAndGetHandoffRetries(n int) int { + return cn.incrementHandoffRetries(n) +} + +// Rd returns the connection's reader for protocol-specific processing +func (cn *Conn) Rd() *proto.Reader { + return cn.rd +} + +// Reader returns the connection's proto reader for processing notifications +// Note: This method should be used carefully as it returns the raw reader. +// For thread-safe operations, use HasBufferedData() and PeekReplyTypeSafe(). +func (cn *Conn) Reader() *proto.Reader { + return cn.rd +} + +// HasBufferedData safely checks if the connection has buffered data. +// This method is used to avoid data races when checking for push notifications. +func (cn *Conn) HasBufferedData() bool { + // Use read lock for concurrent access to reader state + cn.readerMu.RLock() + defer cn.readerMu.RUnlock() + return cn.rd.Buffered() > 0 +} + +// PeekReplyTypeSafe safely peeks at the reply type. +// This method is used to avoid data races when checking for push notifications. 
+func (cn *Conn) PeekReplyTypeSafe() (byte, error) { + // Use read lock for concurrent access to reader state + cn.readerMu.RLock() + defer cn.readerMu.RUnlock() + + if cn.rd.Buffered() <= 0 { + return 0, fmt.Errorf("redis: can't peek reply type, no data available") + } + return cn.rd.PeekReplyType() +} + func (cn *Conn) Write(b []byte) (int, error) { - return cn.netConn.Write(b) + // Lock-free netConn access for better performance + if netConn := cn.getNetConn(); netConn != nil { + return netConn.Write(b) + } + return 0, net.ErrClosed } func (cn *Conn) RemoteAddr() net.Addr { - if cn.netConn != nil { - return cn.netConn.RemoteAddr() + // Lock-free netConn access for better performance + if netConn := cn.getNetConn(); netConn != nil { + return netConn.RemoteAddr() } return nil } @@ -89,7 +506,16 @@ func (cn *Conn) WithReader( ctx context.Context, timeout time.Duration, fn func(rd *proto.Reader) error, ) error { if timeout >= 0 { - if err := cn.netConn.SetReadDeadline(cn.deadline(ctx, timeout)); err != nil { + // Use relaxed timeout if set, otherwise use provided timeout + effectiveTimeout := cn.getEffectiveReadTimeout(timeout) + + // Get the connection directly from atomic storage + netConn := cn.getNetConn() + if netConn == nil { + return fmt.Errorf("redis: connection not available") + } + + if err := netConn.SetReadDeadline(cn.deadline(ctx, effectiveTimeout)); err != nil { return err } } @@ -100,13 +526,26 @@ func (cn *Conn) WithWriter( ctx context.Context, timeout time.Duration, fn func(wr *proto.Writer) error, ) error { if timeout >= 0 { - if err := cn.netConn.SetWriteDeadline(cn.deadline(ctx, timeout)); err != nil { - return err + // Use relaxed timeout if set, otherwise use provided timeout + effectiveTimeout := cn.getEffectiveWriteTimeout(timeout) + + // Always set write deadline, even if getNetConn() returns nil + // This prevents write operations from hanging indefinitely + if netConn := cn.getNetConn(); netConn != nil { + if err := netConn.SetWriteDeadline(cn.deadline(ctx, effectiveTimeout)); err != nil { + return err + } + } else { + // If getNetConn() returns nil, we still need to respect the timeout + // Return an error to prevent indefinite blocking + return fmt.Errorf("redis: connection not available for write operation") } } if cn.bw.Buffered() > 0 { - cn.bw.Reset(cn.netConn) + if netConn := cn.getNetConn(); netConn != nil { + cn.bw.Reset(netConn) + } } if err := fn(cn.wr); err != nil { @@ -121,14 +560,23 @@ func (cn *Conn) Close() error { // ignore error _ = cn.onClose() } - return cn.netConn.Close() + + // Lock-free netConn access for better performance + if netConn := cn.getNetConn(); netConn != nil { + return netConn.Close() + } + return nil } // MaybeHasData tries to peek at the next byte in the socket without consuming it // This is used to check if there are push notifications available // Important: This will work on Linux, but not on Windows func (cn *Conn) MaybeHasData() bool { - return maybeHasData(cn.netConn) + // Lock-free netConn access for better performance + if netConn := cn.getNetConn(); netConn != nil { + return maybeHasData(netConn) + } + return false } func (cn *Conn) deadline(ctx context.Context, timeout time.Duration) time.Time { diff --git a/internal/pool/export_test.go b/internal/pool/export_test.go index 40e387c9a0..20456b8100 100644 --- a/internal/pool/export_test.go +++ b/internal/pool/export_test.go @@ -10,7 +10,7 @@ func (cn *Conn) SetCreatedAt(tm time.Time) { } func (cn *Conn) NetConn() net.Conn { - return cn.netConn + return cn.getNetConn() } 
func (p *ConnPool) CheckMinIdleConns() { diff --git a/internal/pool/pool.go b/internal/pool/pool.go index fa0306c3b9..30dfff0fd7 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -3,13 +3,14 @@ package pool import ( "context" "errors" + "log" "net" "sync" "sync/atomic" "time" "github.com/redis/go-redis/v9/internal" - "github.com/redis/go-redis/v9/internal/proto" + "github.com/redis/go-redis/v9/internal/interfaces" ) var ( @@ -22,6 +23,10 @@ var ( // ErrPoolTimeout timed out waiting to get a connection from the connection pool. ErrPoolTimeout = errors.New("redis: connection pool timeout") + + minTime = time.Unix(-2208988800, 0) // Jan 1, 1900 + maxTime = minTime.Add(1<<63 - 1) + noExpiration = maxTime ) var timers = sync.Pool{ @@ -38,11 +43,13 @@ type Stats struct { Misses uint32 // number of times free connection was NOT found in the pool Timeouts uint32 // number of times a wait timeout occurred WaitCount uint32 // number of times a connection was waited + Unusable uint32 // number of times a connection was found to be unusable WaitDurationNs int64 // total time spent for waiting a connection in nanoseconds - TotalConns uint32 // number of total connections in the pool - IdleConns uint32 // number of idle connections in the pool - StaleConns uint32 // number of stale connections removed from the pool + TotalConns uint32 // number of total connections in the pool + IdleConns uint32 // number of idle connections in the pool + StaleConns uint32 // number of stale connections removed from the pool + PubSubStats PubSubStats } type Pooler interface { @@ -61,7 +68,9 @@ type Pooler interface { } type Options struct { - Dialer func(context.Context) (net.Conn, error) + Dialer func(context.Context) (net.Conn, error) + ReadBufferSize int + WriteBufferSize int PoolFIFO bool PoolSize int @@ -73,13 +82,9 @@ type Options struct { ConnMaxIdleTime time.Duration ConnMaxLifetime time.Duration - - // Protocol version for optimization (3 = RESP3 with push notifications, 2 = RESP2 without) - Protocol int - - ReadBufferSize int - WriteBufferSize int - + // ConnectionProcessor handles protocol-specific connection processing + // If nil, connections are processed with default behavior + ConnectionProcessor interfaces.ConnectionProcessor } type lastDialErrorWrap struct { @@ -98,7 +103,7 @@ type ConnPool struct { conns []*Conn idleConns []*Conn - poolSize int + poolSize atomic.Int32 idleConnsLen int stats Stats @@ -118,9 +123,13 @@ func NewConnPool(opt *Options) *ConnPool { idleConns: make([]*Conn, 0, opt.PoolSize), } - p.connsMu.Lock() - p.checkMinIdleConns() - p.connsMu.Unlock() + // Only create MinIdleConns if explicitly requested (> 0) + // This avoids creating connections during pool initialization for tests + if opt.MinIdleConns > 0 { + p.connsMu.Lock() + p.checkMinIdleConns() + p.connsMu.Unlock() + } return p } @@ -129,17 +138,19 @@ func (p *ConnPool) checkMinIdleConns() { if p.cfg.MinIdleConns == 0 { return } - for p.poolSize < p.cfg.PoolSize && p.idleConnsLen < p.cfg.MinIdleConns { + + // Only create idle connections if we haven't reached the total pool size limit + // MinIdleConns should be a subset of PoolSize, not additional connections + for p.poolSize.Load() < int32(p.cfg.PoolSize) && p.idleConnsLen < p.cfg.MinIdleConns { select { case p.queue <- struct{}{}: - p.poolSize++ + p.poolSize.Add(1) p.idleConnsLen++ - go func() { defer func() { if err := recover(); err != nil { p.connsMu.Lock() - p.poolSize-- + p.poolSize.Add(-1) p.idleConnsLen-- p.connsMu.Unlock() @@ -151,11 +162,10 @@ func 
(p *ConnPool) checkMinIdleConns() { err := p.addIdleConn() if err != nil && err != ErrClosed { p.connsMu.Lock() - p.poolSize-- + p.poolSize.Add(-1) p.idleConnsLen-- p.connsMu.Unlock() } - p.freeTurn() }() default: @@ -172,6 +182,9 @@ func (p *ConnPool) addIdleConn() error { if err != nil { return err } + // Mark connection as usable after successful creation + // This is essential for normal pool operations + cn.SetUsable(true) p.connsMu.Lock() defer p.connsMu.Unlock() @@ -187,6 +200,10 @@ func (p *ConnPool) addIdleConn() error { return nil } +// NewConn creates a new connection and returns it to the user. +// This will still obey MaxActiveConns but will not include it in the pool and won't increase the pool size. +// +// NOTE: If you directly get a connection from the pool, it won't be pooled and won't support hitless upgrades. func (p *ConnPool) NewConn(ctx context.Context) (*Conn, error) { return p.newConn(ctx, false) } @@ -197,7 +214,7 @@ func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) { } p.connsMu.Lock() - if p.cfg.MaxActiveConns > 0 && p.poolSize >= p.cfg.MaxActiveConns { + if p.cfg.MaxActiveConns > 0 && p.poolSize.Load() >= int32(p.cfg.MaxActiveConns) { p.connsMu.Unlock() return nil, ErrPoolExhausted } @@ -207,11 +224,14 @@ func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) { if err != nil { return nil, err } + // Mark connection as usable after successful creation + // This is essential for normal pool operations + cn.SetUsable(true) p.connsMu.Lock() defer p.connsMu.Unlock() - if p.cfg.MaxActiveConns > 0 && p.poolSize >= p.cfg.MaxActiveConns { + if p.cfg.MaxActiveConns > 0 && p.poolSize.Load() > int32(p.cfg.MaxActiveConns) { _ = cn.Close() return nil, ErrPoolExhausted } @@ -219,10 +239,11 @@ func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) { p.conns = append(p.conns, cn) if pooled { // If pool is full remove the cn on next Put. - if p.poolSize >= p.cfg.PoolSize { + currentPoolSize := p.poolSize.Load() + if currentPoolSize >= int32(p.cfg.PoolSize) { cn.pooled = false } else { - p.poolSize++ + p.poolSize.Add(1) } } @@ -249,6 +270,12 @@ func (p *ConnPool) dialConn(ctx context.Context, pooled bool) (*Conn, error) { cn := NewConnWithBufferSize(netConn, p.cfg.ReadBufferSize, p.cfg.WriteBufferSize) cn.pooled = pooled + if p.cfg.ConnMaxLifetime > 0 { + cn.expiresAt = time.Now().Add(p.cfg.ConnMaxLifetime) + } else { + cn.expiresAt = noExpiration + } + return cn, nil } @@ -289,6 +316,11 @@ func (p *ConnPool) getLastDialError() error { // Get returns existed connection from the pool or creates a new one. func (p *ConnPool) Get(ctx context.Context) (*Conn, error) { + return p.getConn(ctx) +} + +// getConn returns a connection from the pool. 
+func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) { if p.closed() { return nil, ErrClosed } @@ -297,7 +329,14 @@ func (p *ConnPool) Get(ctx context.Context) (*Conn, error) { return nil, err } + tries := 0 + now := time.Now() for { + if tries > 10 { + log.Printf("redis: connection pool: failed to get a connection after %d tries", tries) + break + } + tries++ p.connsMu.Lock() cn, err := p.popIdle() p.connsMu.Unlock() @@ -311,11 +350,21 @@ func (p *ConnPool) Get(ctx context.Context) (*Conn, error) { break } - if !p.isHealthyConn(cn) { + if !p.isHealthyConn(cn, now) { _ = p.CloseConn(cn) continue } + // Process connection using the connection processor if available + // Fast path: check processor existence once and cache the result + if processor := p.cfg.ConnectionProcessor; processor != nil { + if err := processor.ProcessConnectionOnGet(ctx, cn); err != nil { + // Failed to process connection, discard it + _ = p.CloseConn(cn) + continue + } + } + atomic.AddUint32(&p.stats.Hits, 1) return cn, nil } @@ -356,7 +405,7 @@ func (p *ConnPool) waitTurn(ctx context.Context) error { } return ctx.Err() case p.queue <- struct{}{}: - p.waitDurationNs.Add(time.Since(start).Nanoseconds()) + p.waitDurationNs.Add(time.Now().UnixNano() - start.UnixNano()) atomic.AddUint32(&p.stats.WaitCount, 1) if !timer.Stop() { <-timer.C @@ -382,43 +431,79 @@ func (p *ConnPool) popIdle() (*Conn, error) { } var cn *Conn - if p.cfg.PoolFIFO { - cn = p.idleConns[0] - copy(p.idleConns, p.idleConns[1:]) - p.idleConns = p.idleConns[:n-1] - } else { - idx := n - 1 - cn = p.idleConns[idx] - p.idleConns = p.idleConns[:idx] + attempts := 0 + maxAttempts := len(p.idleConns) + 1 // Prevent infinite loop + + for attempts < maxAttempts { + if len(p.idleConns) == 0 { + return nil, nil + } + + if p.cfg.PoolFIFO { + cn = p.idleConns[0] + copy(p.idleConns, p.idleConns[1:]) + p.idleConns = p.idleConns[:len(p.idleConns)-1] + } else { + idx := len(p.idleConns) - 1 + cn = p.idleConns[idx] + p.idleConns = p.idleConns[:idx] + } + attempts++ + + if cn.IsUsable() { + p.idleConnsLen-- + break + } + + // Connection is not usable, put it back in the pool + if p.cfg.PoolFIFO { + // FIFO: put at end (will be picked up last since we pop from front) + p.idleConns = append(p.idleConns, cn) + } else { + // LIFO: put at beginning (will be picked up last since we pop from end) + // currently isUsable is only set for hitless upgrades, so this is a no-op + // but we may need it in the future, so leaving it here for now. + p.idleConns = append([]*Conn{cn}, p.idleConns...) 
+ } } - p.idleConnsLen-- + + // If we exhausted all attempts without finding a usable connection, return nil + if attempts >= maxAttempts { + return nil, nil + } + p.checkMinIdleConns() return cn, nil } func (p *ConnPool) Put(ctx context.Context, cn *Conn) { + // Process connection using the connection processor if available + shouldPool := true shouldRemove := false - if cn.rd.Buffered() > 0 { - // Check if this might be push notification data - if p.cfg.Protocol == 3 { - // we know that there is something in the buffer, so peek at the next reply type without - // the potential to block and check if it's a push notification - if replyType, err := cn.rd.PeekReplyType(); err != nil || replyType != proto.RespPush { - shouldRemove = true - } - } else { - // not a push notification since protocol 2 doesn't support them - shouldRemove = true - } + var err error - if shouldRemove { - // For non-RESP3 or data that is not a push notification, buffered data is unexpected - internal.Logger.Printf(ctx, "Conn has unread data, closing it") - p.Remove(ctx, cn, BadConnError{}) + // Fast path: cache processor reference to avoid repeated field access + if processor := p.cfg.ConnectionProcessor; processor != nil { + shouldPool, shouldRemove, err = processor.ProcessConnectionOnPut(ctx, cn) + if err != nil { + internal.Logger.Printf(ctx, "Connection processor error: %v", err) + p.Remove(ctx, cn, err) return } } + // If processor says to remove the connection, do so + if shouldRemove { + p.Remove(ctx, cn, nil) + return + } + + // If processor says not to pool the connection, remove it + if !shouldPool { + p.Remove(ctx, cn, nil) + return + } + if !cn.pooled { p.Remove(ctx, cn, nil) return @@ -467,7 +552,8 @@ func (p *ConnPool) removeConn(cn *Conn) { if c == cn { p.conns = append(p.conns[:i], p.conns[i+1:]...) if cn.pooled { - p.poolSize-- + p.poolSize.Add(-1) + // Immediately check for minimum idle connections when a pooled connection is removed p.checkMinIdleConns() } break @@ -502,6 +588,7 @@ func (p *ConnPool) Stats() *Stats { Misses: atomic.LoadUint32(&p.stats.Misses), Timeouts: atomic.LoadUint32(&p.stats.Timeouts), WaitCount: atomic.LoadUint32(&p.stats.WaitCount), + Unusable: atomic.LoadUint32(&p.stats.Unusable), WaitDurationNs: p.waitDurationNs.Load(), TotalConns: uint32(p.Len()), @@ -542,7 +629,7 @@ func (p *ConnPool) Close() error { } } p.conns = nil - p.poolSize = 0 + p.poolSize.Store(0) p.idleConns = nil p.idleConnsLen = 0 p.connsMu.Unlock() @@ -550,34 +637,20 @@ func (p *ConnPool) Close() error { return firstErr } -func (p *ConnPool) isHealthyConn(cn *Conn) bool { - now := time.Now() - - if p.cfg.ConnMaxLifetime > 0 && now.Sub(cn.createdAt) >= p.cfg.ConnMaxLifetime { +func (p *ConnPool) isHealthyConn(cn *Conn, now time.Time) bool { + // slight optimization, check expiresAt first. 
+ if cn.expiresAt.Before(now) { return false } + if p.cfg.ConnMaxIdleTime > 0 && now.Sub(cn.UsedAt()) >= p.cfg.ConnMaxIdleTime { return false } - // Check connection health, but be aware of push notifications - if err := connCheck(cn.netConn); err != nil { - // If there's unexpected data, it might be push notifications (RESP3) - // However, push notification processing is now handled by the client - // before WithReader to ensure proper context is available to handlers - if err == errUnexpectedRead && p.cfg.Protocol == 3 { - // we know that there is something in the buffer, so peek at the next reply type without - // the potential to block - if replyType, err := cn.rd.PeekReplyType(); err == nil && replyType == proto.RespPush { - // For RESP3 connections with push notifications, we allow some buffered data - // The client will process these notifications before using the connection - internal.Logger.Printf(context.Background(), "push: connection has buffered data, likely push notifications - will be processed by client") - return true // Connection is healthy, client will handle notifications - } - return false // Unexpected data, not push notifications, connection is unhealthy - } else { - return false - } + // Check basic connection health + // Use GetNetConn() to safely access netConn and avoid data races + if err := connCheck(cn.getNetConn()); err != nil { + return false } cn.SetUsedAt(now) diff --git a/internal/pool/pubsub_bench_test.go b/internal/pool/pubsub_bench_test.go new file mode 100644 index 0000000000..93716adeff --- /dev/null +++ b/internal/pool/pubsub_bench_test.go @@ -0,0 +1,368 @@ +package pool_test + +import ( + "context" + "fmt" + "sync" + "testing" + "time" + + "github.com/redis/go-redis/v9/internal/pool" +) + +// PubSub Pool Benchmark Suite +// +// This file contains comprehensive benchmarks for PubSub pool operations. +// PubSub pools have different characteristics than regular pools: +// - Connections are never pooled (always created fresh) +// - Connections are always removed (never reused) +// - No pool size limits apply +// - Focus on creation/destruction performance +// +// Usage Examples: +// # Run all PubSub benchmarks +// go test -bench=BenchmarkPubSub -run='^$' internal/pool/pubsub_bench_test.go internal/pool/main_test.go +// +// # Run specific benchmark +// go test -bench=BenchmarkPubSubGetRemove -run='^$' internal/pool/pubsub_bench_test.go internal/pool/main_test.go +// +// # Compare with regular pool benchmarks +// go test -bench=. 
-run='^$' internal/pool/ + +type pubsubGetRemoveBenchmark struct { + name string +} + +func (bm pubsubGetRemoveBenchmark) String() string { + return bm.name +} + +// BenchmarkPubSubGetRemove benchmarks the core PubSub pool operation: +// Get a connection and immediately remove it (PubSub connections are never pooled) +func BenchmarkPubSubGetRemove(b *testing.B) { + ctx := context.Background() + benchmarks := []pubsubGetRemoveBenchmark{ + {"sequential"}, + {"parallel"}, + } + + for _, bm := range benchmarks { + b.Run(bm.String(), func(b *testing.B) { + pubsubPool := pool.NewPubSubPool(&pool.Config{ + DialTimeout: 1 * time.Second, + ConnMaxIdleTime: time.Hour, + }, dummyDialer) + defer pubsubPool.Close() + + b.ResetTimer() + b.ReportAllocs() + + if bm.name == "parallel" { + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + cn, err := pubsubPool.Get(ctx) + if err != nil { + b.Fatal(err) + } + pubsubPool.Remove(ctx, cn, nil) + } + }) + } else { + for i := 0; i < b.N; i++ { + cn, err := pubsubPool.Get(ctx) + if err != nil { + b.Fatal(err) + } + pubsubPool.Remove(ctx, cn, nil) + } + } + }) + } +} + +// BenchmarkPubSubConcurrentAccess benchmarks concurrent access patterns +// typical in PubSub scenarios with multiple subscribers +func BenchmarkPubSubConcurrentAccess(b *testing.B) { + ctx := context.Background() + concurrencyLevels := []int{1, 2, 4, 8, 16, 32} + + for _, concurrency := range concurrencyLevels { + b.Run(fmt.Sprintf("goroutines=%d", concurrency), func(b *testing.B) { + pubsubPool := pool.NewPubSubPool(&pool.Config{ + DialTimeout: 1 * time.Second, + ConnMaxIdleTime: time.Hour, + }, dummyDialer) + defer pubsubPool.Close() + + b.ResetTimer() + b.ReportAllocs() + + var wg sync.WaitGroup + opsPerGoroutine := b.N / concurrency + if opsPerGoroutine == 0 { + opsPerGoroutine = 1 + } + + for i := 0; i < concurrency; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < opsPerGoroutine; j++ { + cn, err := pubsubPool.Get(ctx) + if err != nil { + b.Error(err) + return + } + pubsubPool.Remove(ctx, cn, nil) + } + }() + } + wg.Wait() + }) + } +} + +// BenchmarkPubSubConnectionLifecycle benchmarks the full lifecycle +// of PubSub connections including creation, usage simulation, and cleanup +func BenchmarkPubSubConnectionLifecycle(b *testing.B) { + ctx := context.Background() + pubsubPool := pool.NewPubSubPool(&pool.Config{ + DialTimeout: 1 * time.Second, + ConnMaxIdleTime: time.Hour, + }, dummyDialer) + defer pubsubPool.Close() + + b.ResetTimer() + b.ReportAllocs() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + // Get connection + cn, err := pubsubPool.Get(ctx) + if err != nil { + b.Fatal(err) + } + + // Simulate some work (minimal to focus on pool overhead) + cn.SetUsable(true) + + // Remove connection (PubSub connections are never pooled) + pubsubPool.Remove(ctx, cn, nil) + } + }) +} + +// BenchmarkPubSubStats benchmarks statistics collection performance +func BenchmarkPubSubStats(b *testing.B) { + ctx := context.Background() + pubsubPool := pool.NewPubSubPool(&pool.Config{ + DialTimeout: 1 * time.Second, + ConnMaxIdleTime: time.Hour, + }, dummyDialer) + defer pubsubPool.Close() + + // Pre-create some connections to have meaningful stats + var connections []*pool.Conn + for i := 0; i < 10; i++ { + cn, err := pubsubPool.Get(ctx) + if err != nil { + b.Fatal(err) + } + connections = append(connections, cn) + } + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + _ = pubsubPool.Stats() + } + + // Cleanup + for _, cn := range connections { + 
pubsubPool.Remove(ctx, cn, nil) + } +} + +// BenchmarkPubSubVsRegularPool compares PubSub pool performance +// with regular pool performance for similar operations +func BenchmarkPubSubVsRegularPool(b *testing.B) { + ctx := context.Background() + + b.Run("PubSubPool", func(b *testing.B) { + pubsubPool := pool.NewPubSubPool(&pool.Config{ + DialTimeout: 1 * time.Second, + ConnMaxIdleTime: time.Hour, + }, dummyDialer) + defer pubsubPool.Close() + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + cn, err := pubsubPool.Get(ctx) + if err != nil { + b.Fatal(err) + } + pubsubPool.Remove(ctx, cn, nil) + } + }) + + b.Run("RegularPool", func(b *testing.B) { + connPool := pool.NewConnPool(&pool.Options{ + Dialer: dummyDialer, + PoolSize: 1, // Small pool to force creation/removal + PoolTimeout: time.Second, + DialTimeout: 1 * time.Second, + ConnMaxIdleTime: time.Hour, + }) + defer connPool.Close() + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + cn, err := connPool.Get(ctx) + if err != nil { + b.Fatal(err) + } + connPool.Remove(ctx, cn, nil) // Force removal like PubSub + } + }) +} + +// BenchmarkPubSubMemoryUsage benchmarks memory allocation patterns +func BenchmarkPubSubMemoryUsage(b *testing.B) { + ctx := context.Background() + pubsubPool := pool.NewPubSubPool(&pool.Config{ + DialTimeout: 1 * time.Second, + ConnMaxIdleTime: time.Hour, + }, dummyDialer) + defer pubsubPool.Close() + + b.ResetTimer() + b.ReportAllocs() + + // Focus on memory allocations + for i := 0; i < b.N; i++ { + cn, err := pubsubPool.Get(ctx) + if err != nil { + b.Fatal(err) + } + pubsubPool.Put(ctx, cn) // Put calls Remove internally for PubSub + } +} + +// BenchmarkPubSubBurstLoad benchmarks handling of burst connection requests +// typical in PubSub scenarios where many subscribers connect simultaneously +func BenchmarkPubSubBurstLoad(b *testing.B) { + ctx := context.Background() + burstSizes := []int{10, 50, 100, 500} + + for _, burstSize := range burstSizes { + b.Run(fmt.Sprintf("burst=%d", burstSize), func(b *testing.B) { + pubsubPool := pool.NewPubSubPool(&pool.Config{ + DialTimeout: 1 * time.Second, + ConnMaxIdleTime: time.Hour, + }, dummyDialer) + defer pubsubPool.Close() + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + var wg sync.WaitGroup + var connections []*pool.Conn + var mu sync.Mutex + + // Simulate burst of connections + for j := 0; j < burstSize; j++ { + wg.Add(1) + go func() { + defer wg.Done() + cn, err := pubsubPool.Get(ctx) + if err != nil { + b.Error(err) + return + } + mu.Lock() + connections = append(connections, cn) + mu.Unlock() + }() + } + wg.Wait() + + // Clean up all connections + for _, cn := range connections { + pubsubPool.Remove(ctx, cn, nil) + } + } + }) + } +} + +// BenchmarkPubSubLongRunning benchmarks long-running PubSub connections +// that stay active for extended periods +func BenchmarkPubSubLongRunning(b *testing.B) { + ctx := context.Background() + pubsubPool := pool.NewPubSubPool(&pool.Config{ + DialTimeout: 1 * time.Second, + ConnMaxIdleTime: time.Hour, + }, dummyDialer) + defer pubsubPool.Close() + + // Create long-running connections + var connections []*pool.Conn + for i := 0; i < 100; i++ { + cn, err := pubsubPool.Get(ctx) + if err != nil { + b.Fatal(err) + } + connections = append(connections, cn) + } + + b.ResetTimer() + b.ReportAllocs() + + // Benchmark operations while long-running connections exist + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + // Create short-lived connection while long-running ones 
exist + cn, err := pubsubPool.Get(ctx) + if err != nil { + b.Fatal(err) + } + pubsubPool.Remove(ctx, cn, nil) + } + }) + + // Cleanup long-running connections + for _, cn := range connections { + pubsubPool.Remove(ctx, cn, nil) + } +} + +// BenchmarkPubSubErrorHandling benchmarks error scenarios +func BenchmarkPubSubErrorHandling(b *testing.B) { + ctx := context.Background() + + // Create a pool that will be closed to trigger errors + pubsubPool := pool.NewPubSubPool(&pool.Config{ + DialTimeout: 1 * time.Second, + ConnMaxIdleTime: time.Hour, + }, dummyDialer) + + // Close the pool to trigger error conditions + pubsubPool.Close() + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + // This should return an error since pool is closed + _, err := pubsubPool.Get(ctx) + if err == nil { + b.Fatal("Expected error from closed pool") + } + } +} diff --git a/internal/pool/pubsub_pool.go b/internal/pool/pubsub_pool.go new file mode 100644 index 0000000000..21226f7463 --- /dev/null +++ b/internal/pool/pubsub_pool.go @@ -0,0 +1,155 @@ +package pool + +import ( + "context" + "net" + "sync" + "sync/atomic" + "time" + + "github.com/redis/go-redis/v9/internal/interfaces" +) + +// Config contains configuration for PubSub connections. +type Config struct { + PoolFIFO bool + PoolSize int + PoolTimeout time.Duration + DialTimeout time.Duration + MinIdleConns int + MaxIdleConns int + MaxActiveConns int + ConnMaxIdleTime time.Duration + ConnMaxLifetime time.Duration + + // ConnectionProcessor handles protocol-specific connection processing + ConnectionProcessor interfaces.ConnectionProcessor +} + +// PubSubPool manages connections specifically for PubSub operations. +// Unlike regular connections, PubSub connections are not pooled and are +// immediately closed when no longer needed. +type PubSubPool struct { + cfg *Config + dialer func(context.Context) (net.Conn, error) + + mu sync.RWMutex + connections map[*Conn]struct{} + closed bool + + stats struct { + created uint32 + active uint32 + closed uint32 + } +} + +func NewPubSubPool(cfg *Config, dialer func(context.Context) (net.Conn, error)) *PubSubPool { + return &PubSubPool{ + cfg: cfg, + dialer: dialer, + connections: make(map[*Conn]struct{}), + } +} + +// Get creates a new PubSub connection. PubSub connections are never pooled. +func (p *PubSubPool) Get(ctx context.Context) (*Conn, error) { + p.mu.RLock() + if p.closed { + p.mu.RUnlock() + return nil, ErrClosed + } + p.mu.RUnlock() + + // Create new connection + netConn, err := p.dialer(ctx) + if err != nil { + return nil, err + } + + cn := NewConn(netConn) + cn.pooled = false + + // Process connection if processor is available + if processor := p.cfg.ConnectionProcessor; processor != nil { + if err := processor.ProcessConnectionOnGet(ctx, cn); err != nil { + _ = cn.Close() + return nil, err + } + } + + // Track the connection + p.mu.Lock() + if p.closed { + p.mu.Unlock() + _ = cn.Close() + return nil, ErrClosed + } + p.connections[cn] = struct{}{} + p.mu.Unlock() + + atomic.AddUint32(&p.stats.created, 1) + atomic.AddUint32(&p.stats.active, 1) + + return cn, nil +} + +// Put closes the PubSub connection immediately. PubSub connections are never reused. +func (p *PubSubPool) Put(ctx context.Context, cn *Conn) { + p.Remove(ctx, cn, nil) +} + +// Remove closes and removes the PubSub connection. 
+func (p *PubSubPool) Remove(ctx context.Context, cn *Conn, err error) { + p.mu.Lock() + if _, exists := p.connections[cn]; exists { + delete(p.connections, cn) + atomic.AddUint32(&p.stats.active, ^uint32(0)) // decrement + atomic.AddUint32(&p.stats.closed, 1) + } + p.mu.Unlock() + + // Process connection before closing if processor is available + if processor := p.cfg.ConnectionProcessor; processor != nil { + _, _, _ = processor.ProcessConnectionOnPut(ctx, cn) + } + + _ = cn.Close() +} + +// Stats returns PubSub connection statistics. +func (p *PubSubPool) Stats() PubSubStats { + return PubSubStats{ + Created: atomic.LoadUint32(&p.stats.created), + Active: atomic.LoadUint32(&p.stats.active), + Closed: atomic.LoadUint32(&p.stats.closed), + } +} + +// Close closes all active PubSub connections. +func (p *PubSubPool) Close() error { + p.mu.Lock() + defer p.mu.Unlock() + + if p.closed { + return ErrClosed + } + p.closed = true + + // Close all active connections + for cn := range p.connections { + _ = cn.Close() + atomic.AddUint32(&p.stats.active, ^uint32(0)) // decrement + atomic.AddUint32(&p.stats.closed, 1) + } + p.connections = nil + + return nil +} + +// PubSubStats contains statistics for PubSub connections. +type PubSubStats struct { + Created uint32 // Total PubSub connections created + Active uint32 // Currently active PubSub connections + Closed uint32 // Total PubSub connections closed +} diff --git a/internal/pool/simple_pool_test.go b/internal/pool/simple_pool_test.go new file mode 100644 index 0000000000..4494e424fd --- /dev/null +++ b/internal/pool/simple_pool_test.go @@ -0,0 +1,243 @@ +package pool + +import ( + "context" + "net" + "sync" + "testing" + "time" +) + +func TestPoolConnectionLimit(t *testing.T) { + ctx := context.Background() + + opt := &Options{ + Dialer: func(ctx context.Context) (net.Conn, error) { + return &net.TCPConn{}, nil + }, + PoolSize: 1000, + MinIdleConns: 50, + PoolTimeout: 3 * time.Second, + DialTimeout: 1 * time.Second, + } + p := NewConnPool(opt) + defer p.Close() + + var wg sync.WaitGroup + for i := 0; i < opt.PoolSize; i++ { + wg.Add(1) + go func() { + defer wg.Done() + _, _ = p.Get(ctx) + }() + } + wg.Wait() + + stats := p.Stats() + + // Current pool implementation has issues: + // 1. It creates PoolSize + MinIdleConns connections instead of max(PoolSize, MinIdleConns) + // 2. 
It maintains idle connections even when all should be in use + // + // Expected behavior (what the test should pass with): + // - IdleConns should be 0 (all connections are held by goroutines) + // - TotalConns should be PoolSize (1000) + // + // Current actual behavior: + // - IdleConns = MinIdleConns (50) - incorrect + // - TotalConns = PoolSize + MinIdleConns (1050) - incorrect + + t.Logf("Current stats: IdleConns=%d, TotalConns=%d", stats.IdleConns, stats.TotalConns) + t.Logf("Expected stats: IdleConns=0, TotalConns=%d", opt.PoolSize) + + // TODO: Fix pool implementation to make these assertions pass + if stats.IdleConns != 0 { + t.Errorf("Expected IdleConns to be 0, got %d (pool implementation bug)", stats.IdleConns) + } + if stats.TotalConns != uint32(opt.PoolSize) { + t.Errorf("Expected TotalConns to be %d, got %d (pool implementation bug)", opt.PoolSize, stats.TotalConns) + } +} + +func TestPoolBasicGetPut(t *testing.T) { + ctx := context.Background() + + opt := &Options{ + Dialer: func(ctx context.Context) (net.Conn, error) { + return &net.TCPConn{}, nil + }, + PoolSize: 10, + MinIdleConns: 2, + PoolTimeout: 1 * time.Second, + DialTimeout: 1 * time.Second, + } + p := NewConnPool(opt) + defer p.Close() + + // Get a connection + conn, err := p.Get(ctx) + if err != nil { + t.Fatalf("Failed to get connection: %v", err) + } + + // Put it back + p.Put(ctx, conn) + + // Verify stats + stats := p.Stats() + if stats.TotalConns == 0 { + t.Error("Expected at least one connection in pool") + } +} + +func TestPoolConcurrentGetPut(t *testing.T) { + ctx := context.Background() + + opt := &Options{ + Dialer: func(ctx context.Context) (net.Conn, error) { + return &net.TCPConn{}, nil + }, + PoolSize: 50, + MinIdleConns: 5, + PoolTimeout: 2 * time.Second, + DialTimeout: 1 * time.Second, + } + p := NewConnPool(opt) + defer p.Close() + + const numGoroutines = 100 + const numOperations = 10 + + var wg sync.WaitGroup + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < numOperations; j++ { + conn, err := p.Get(ctx) + if err != nil { + t.Errorf("Failed to get connection: %v", err) + return + } + // Simulate some work + time.Sleep(time.Microsecond) + p.Put(ctx, conn) + } + }() + } + wg.Wait() + + // Verify pool is still functional + stats := p.Stats() + if stats.TotalConns == 0 { + t.Error("Expected connections in pool after concurrent operations") + } + + // Pool should not exceed the maximum size + if stats.TotalConns > uint32(opt.PoolSize) { + t.Errorf("Pool exceeded maximum size: got %d, max %d", stats.TotalConns, opt.PoolSize) + } +} + +func TestPoolUsableConnections(t *testing.T) { + ctx := context.Background() + + opt := &Options{ + Dialer: func(ctx context.Context) (net.Conn, error) { + return &net.TCPConn{}, nil + }, + PoolSize: 5, + MinIdleConns: 0, // No minimum idle connections to simplify test + PoolTimeout: 1 * time.Second, + DialTimeout: 1 * time.Second, + } + p := NewConnPool(opt) + defer p.Close() + + // Test basic usable connection behavior + conn, err := p.Get(ctx) + if err != nil { + t.Fatalf("Failed to get connection: %v", err) + } + + // Note: Currently connections are created with usable=false by default + // This is expected behavior for the hitless upgrade system where connections + // need to be explicitly marked as usable after initialization + t.Logf("Connection usable status: %v", conn.IsUsable()) + + // Manually mark connection as usable for testing + conn.SetUsable(true) + if !conn.IsUsable() { + t.Error("Connection should be 
usable after SetUsable(true)") + } + + // Mark connection as unusable and put it back + conn.SetUsable(false) + p.Put(ctx, conn) + + // Note: Due to current pool implementation issues with unusable connections, + // we'll just verify the pool doesn't crash and can still create new connections + conn2, err := p.Get(ctx) + if err != nil { + t.Fatalf("Failed to get connection after putting back unusable connection: %v", err) + } + + // Clean up + p.Put(ctx, conn2) +} + +func TestPoolWaitBehavior(t *testing.T) { + ctx := context.Background() + + opt := &Options{ + Dialer: func(ctx context.Context) (net.Conn, error) { + return &net.TCPConn{}, nil + }, + PoolSize: 1, + PoolTimeout: 3 * time.Second, + } + p := NewConnPool(opt) + defer p.Close() + + wait := make(chan struct{}) + conn, err := p.Get(ctx) + if err != nil { + t.Fatalf("Failed to get first connection: %v", err) + } + + t.Logf("After first Get: Len=%d, IdleLen=%d, Stats=%+v", p.Len(), p.IdleLen(), p.Stats()) + + go func() { + t.Logf("Goroutine: calling Get()") + conn2, err := p.Get(ctx) // Keep reference to see what happens + if err != nil { + t.Logf("Goroutine: Get() failed: %v", err) + } else { + t.Logf("Goroutine: Get() succeeded, conn ID: %d", conn2.GetID()) + } + t.Logf("Goroutine: Get() completed") + wait <- struct{}{} + }() + + time.Sleep(time.Second) + t.Logf("After sleep, before Put: Len=%d, IdleLen=%d, Stats=%+v", p.Len(), p.IdleLen(), p.Stats()) + + t.Logf("Before Put: conn.IsUsable()=%v", conn.IsUsable()) + p.Put(ctx, conn) + t.Logf("After Put: Len=%d, IdleLen=%d, Stats=%+v", p.Len(), p.IdleLen(), p.Stats()) + + <-wait + t.Logf("After goroutine completion: Len=%d, IdleLen=%d, Stats=%+v", p.Len(), p.IdleLen(), p.Stats()) + + stats := p.Stats() + t.Logf("Final stats: IdleConns=%d, TotalConns=%d", stats.IdleConns, stats.TotalConns) + + // This is what the original test expects: + // The connection should still be tracked as "in use" even if reference is discarded + if stats.IdleConns != 0 { + t.Errorf("Expected IdleConns to be 0, got %d", stats.IdleConns) + } + if stats.TotalConns != 1 { + t.Errorf("Expected TotalConns to be 1, got %d", stats.TotalConns) + } +} diff --git a/internal/util/math.go b/internal/util/math.go new file mode 100644 index 0000000000..e707c47a64 --- /dev/null +++ b/internal/util/math.go @@ -0,0 +1,17 @@ +package util + +// Max returns the maximum of two integers +func Max(a, b int) int { + if a > b { + return a + } + return b +} + +// Min returns the minimum of two integers +func Min(a, b int) int { + if a < b { + return a + } + return b +} diff --git a/options.go b/options.go index a60519c9d0..aa1ae5e4fc 100644 --- a/options.go +++ b/options.go @@ -14,9 +14,11 @@ import ( "time" "github.com/redis/go-redis/v9/auth" + "github.com/redis/go-redis/v9/hitless" + "github.com/redis/go-redis/v9/internal/interfaces" "github.com/redis/go-redis/v9/internal/pool" - "github.com/redis/go-redis/v9/push" "github.com/redis/go-redis/v9/internal/proto" + "github.com/redis/go-redis/v9/push" ) // Limiter is the interface of a rate limiter or a circuit breaker. @@ -151,8 +153,11 @@ type Options struct { // - true for FIFO pool // - false for LIFO pool. // + // NOTE: If you are using HitlessUpgrades, this will be ignored and pool will always be FIFO. + // // Note that FIFO has slightly higher overhead compared to LIFO, // but it helps closing idle connections faster reducing the pool size. + // default: false PoolFIFO bool // PoolSize is the base number of socket connections. 
@@ -244,8 +249,19 @@ type Options struct { // When a node is marked as failing, it will be avoided for this duration. // Default is 15 seconds. FailingTimeoutSeconds int + + // HitlessUpgradeConfig provides custom configuration for hitless upgrades. + // When HitlessUpgradeConfig.Enabled is not "disabled", the client will handle + // cluster upgrade notifications gracefully and manage connection/pool state + // transitions seamlessly. Requires Protocol: 3 (RESP3) for push notifications. + // If nil, hitless upgrades are in "auto" mode and will be enabled if the server supports it. + HitlessUpgradeConfig *HitlessUpgradeConfig } +// HitlessUpgradeConfig provides configuration options for hitless upgrades. +// This is an alias to hitless.Config for convenience. +type HitlessUpgradeConfig = hitless.Config + func (opt *Options) init() { if opt.Addr == "" { opt.Addr = "localhost:6379" @@ -320,6 +336,16 @@ func (opt *Options) init() { case 0: opt.MaxRetryBackoff = 512 * time.Millisecond } + + opt.HitlessUpgradeConfig = opt.HitlessUpgradeConfig.ApplyDefaultsWithPoolSize(opt.PoolSize) + + // auto-detect endpoint type if not specified + endpointType := opt.HitlessUpgradeConfig.EndpointType + if endpointType == "" || endpointType == hitless.EndpointTypeAuto { + // Auto-detect endpoint type if not specified + endpointType = hitless.DetectEndpointType(opt.Addr, opt.TLSConfig != nil) + } + opt.HitlessUpgradeConfig.EndpointType = endpointType } func (opt *Options) clone() *Options { @@ -327,6 +353,12 @@ func (opt *Options) clone() *Options { return &clone } +// NewDialer returns a function that will be used as the default dialer +// when none is specified in Options.Dialer. +func (opt *Options) NewDialer() func(context.Context, string, string) (net.Conn, error) { + return NewDialer(opt) +} + // NewDialer returns a function that will be used as the default dialer // when none is specified in Options.Dialer. 
func NewDialer(opt *Options) func(context.Context, string, string) (net.Conn, error) { @@ -609,15 +641,9 @@ func getUserPassword(u *url.URL) (string, string) { return user, password } -func newConnPool( - opt *Options, - dialer func(ctx context.Context, network, addr string) (net.Conn, error), -) *pool.ConnPool { - return pool.NewConnPool(&pool.Options{ - Dialer: func(ctx context.Context) (net.Conn, error) { - return dialer(ctx, opt.Network, opt.Addr) - }, - PoolFIFO: opt.PoolFIFO, +func newConnPoolConfig(opt *Options) *pool.Config { + return &pool.Config{ + PoolFIFO: opt.PoolFIFO || opt.HitlessUpgradeConfig.IsEnabled(), // Always FIFO with hitless upgrades PoolSize: opt.PoolSize, PoolTimeout: opt.PoolTimeout, DialTimeout: opt.DialTimeout, @@ -626,9 +652,29 @@ func newConnPool( MaxActiveConns: opt.MaxActiveConns, ConnMaxIdleTime: opt.ConnMaxIdleTime, ConnMaxLifetime: opt.ConnMaxLifetime, - // Pass protocol version for push notification optimization - Protocol: opt.Protocol, - ReadBufferSize: opt.ReadBufferSize, - WriteBufferSize: opt.WriteBufferSize, + } +} + +func newConnPool( + opt *Options, + dialer func(ctx context.Context, network, addr string) (net.Conn, error), + processor interfaces.ConnectionProcessor, +) *pool.ConnPool { + return pool.NewConnPool(&pool.Options{ + Dialer: func(ctx context.Context) (net.Conn, error) { + return dialer(ctx, opt.Network, opt.Addr) + }, + PoolFIFO: opt.PoolFIFO || opt.HitlessUpgradeConfig.IsEnabled(), // Always FIFO with hitless upgrades + PoolSize: opt.PoolSize, + PoolTimeout: opt.PoolTimeout, + DialTimeout: opt.DialTimeout, + MinIdleConns: opt.MinIdleConns, + MaxIdleConns: opt.MaxIdleConns, + MaxActiveConns: opt.MaxActiveConns, + ConnMaxIdleTime: opt.ConnMaxIdleTime, + ConnMaxLifetime: opt.ConnMaxLifetime, + ConnectionProcessor: processor, + ReadBufferSize: opt.ReadBufferSize, + WriteBufferSize: opt.WriteBufferSize, }) } diff --git a/osscluster.go b/osscluster.go index 0453bdd102..ee76dd8494 100644 --- a/osscluster.go +++ b/osscluster.go @@ -38,6 +38,7 @@ type ClusterOptions struct { ClientName string // NewClient creates a cluster node client with provided name and options. + // If NewClient is set by the user, the user is responsible for handling hitless upgrades and push notifications. NewClient func(opt *Options) *Client // The maximum number of retries before giving up. Command is retried @@ -129,6 +130,14 @@ type ClusterOptions struct { // When a node is marked as failing, it will be avoided for this duration. // Default is 15 seconds. FailingTimeoutSeconds int + + // HitlessUpgradeConfig provides custom configuration for hitless upgrades. + // When HitlessUpgradeConfig.Enabled is not "disabled", the client will handle + // cluster upgrade notifications gracefully and manage connection/pool state + // transitions seamlessly. Requires Protocol: 3 (RESP3) for push notifications. + // If nil, hitless upgrades are in "auto" mode and will be enabled if the server supports it. + // The ClusterClient does not directly work with hitless, it is up to the clients in the Nodes map to work with hitless. + HitlessUpgradeConfig *HitlessUpgradeConfig } func (opt *ClusterOptions) init() { @@ -360,8 +369,9 @@ func (opt *ClusterOptions) clientOptions() *Options { // much use for ClusterSlots config). This means we cannot execute the // READONLY command against that node -- setting readOnly to false in such // situations in the options below will prevent that from happening. 
- readOnly: opt.ReadOnly && opt.ClusterSlots == nil, - UnstableResp3: opt.UnstableResp3, + readOnly: opt.ReadOnly && opt.ClusterSlots == nil, + UnstableResp3: opt.UnstableResp3, + HitlessUpgradeConfig: opt.HitlessUpgradeConfig, } } diff --git a/pool_pubsub_bench_test.go b/pool_pubsub_bench_test.go new file mode 100644 index 0000000000..ecd3a65a50 --- /dev/null +++ b/pool_pubsub_bench_test.go @@ -0,0 +1,375 @@ +// Pool and PubSub Benchmark Suite +// +// This file contains comprehensive benchmarks for both pool operations and PubSub initialization. +// It's designed to be run against different branches to compare performance. +// +// Usage Examples: +// # Run all benchmarks +// go test -bench=. -run='^$' -benchtime=1s pool_pubsub_bench_test.go +// +// # Run only pool benchmarks +// go test -bench=BenchmarkPool -run='^$' pool_pubsub_bench_test.go +// +// # Run only PubSub benchmarks +// go test -bench=BenchmarkPubSub -run='^$' pool_pubsub_bench_test.go +// +// # Compare between branches +// git checkout branch1 && go test -bench=. -run='^$' pool_pubsub_bench_test.go > branch1.txt +// git checkout branch2 && go test -bench=. -run='^$' pool_pubsub_bench_test.go > branch2.txt +// benchcmp branch1.txt branch2.txt +// +// # Run with memory profiling +// go test -bench=BenchmarkPoolGetPut -run='^$' -memprofile=mem.prof pool_pubsub_bench_test.go +// +// # Run with CPU profiling +// go test -bench=BenchmarkPoolGetPut -run='^$' -cpuprofile=cpu.prof pool_pubsub_bench_test.go + +package redis_test + +import ( + "context" + "fmt" + "net" + "sync" + "testing" + "time" + + "github.com/redis/go-redis/v9" + "github.com/redis/go-redis/v9/internal/pool" +) + +// dummyDialer creates a mock connection for benchmarking +func dummyDialer(ctx context.Context) (net.Conn, error) { + return &dummyConn{}, nil +} + +// dummyConn implements net.Conn for benchmarking +type dummyConn struct{} + +func (c *dummyConn) Read(b []byte) (n int, err error) { return len(b), nil } +func (c *dummyConn) Write(b []byte) (n int, err error) { return len(b), nil } +func (c *dummyConn) Close() error { return nil } +func (c *dummyConn) LocalAddr() net.Addr { return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 6379} } +func (c *dummyConn) RemoteAddr() net.Addr { + return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 6379} +} +func (c *dummyConn) SetDeadline(t time.Time) error { return nil } +func (c *dummyConn) SetReadDeadline(t time.Time) error { return nil } +func (c *dummyConn) SetWriteDeadline(t time.Time) error { return nil } + +// ============================================================================= +// POOL BENCHMARKS +// ============================================================================= + +// BenchmarkPoolGetPut benchmarks the core pool Get/Put operations +func BenchmarkPoolGetPut(b *testing.B) { + ctx := context.Background() + + poolSizes := []int{1, 2, 4, 8, 16, 32, 64, 128} + + for _, poolSize := range poolSizes { + b.Run(fmt.Sprintf("PoolSize_%d", poolSize), func(b *testing.B) { + connPool := pool.NewConnPool(&pool.Options{ + Dialer: dummyDialer, + PoolSize: poolSize, + PoolTimeout: time.Second, + DialTimeout: time.Second, + ConnMaxIdleTime: time.Hour, + MinIdleConns: 0, // Start with no idle connections + }) + defer connPool.Close() + + b.ResetTimer() + b.ReportAllocs() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + cn, err := connPool.Get(ctx) + if err != nil { + b.Fatal(err) + } + connPool.Put(ctx, cn) + } + }) + }) + } +} + +// BenchmarkPoolGetPutWithMinIdle benchmarks pool operations with 
MinIdleConns +func BenchmarkPoolGetPutWithMinIdle(b *testing.B) { + ctx := context.Background() + + configs := []struct { + poolSize int + minIdleConns int + }{ + {8, 2}, + {16, 4}, + {32, 8}, + {64, 16}, + } + + for _, config := range configs { + b.Run(fmt.Sprintf("Pool_%d_MinIdle_%d", config.poolSize, config.minIdleConns), func(b *testing.B) { + connPool := pool.NewConnPool(&pool.Options{ + Dialer: dummyDialer, + PoolSize: config.poolSize, + MinIdleConns: config.minIdleConns, + PoolTimeout: time.Second, + DialTimeout: time.Second, + ConnMaxIdleTime: time.Hour, + }) + defer connPool.Close() + + b.ResetTimer() + b.ReportAllocs() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + cn, err := connPool.Get(ctx) + if err != nil { + b.Fatal(err) + } + connPool.Put(ctx, cn) + } + }) + }) + } +} + +// BenchmarkPoolConcurrentGetPut benchmarks pool under high concurrency +func BenchmarkPoolConcurrentGetPut(b *testing.B) { + ctx := context.Background() + + connPool := pool.NewConnPool(&pool.Options{ + Dialer: dummyDialer, + PoolSize: 32, + PoolTimeout: time.Second, + DialTimeout: time.Second, + ConnMaxIdleTime: time.Hour, + MinIdleConns: 0, + }) + defer connPool.Close() + + b.ResetTimer() + b.ReportAllocs() + + // Test with different levels of concurrency + concurrencyLevels := []int{1, 2, 4, 8, 16, 32, 64} + + for _, concurrency := range concurrencyLevels { + b.Run(fmt.Sprintf("Concurrency_%d", concurrency), func(b *testing.B) { + b.SetParallelism(concurrency) + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + cn, err := connPool.Get(ctx) + if err != nil { + b.Fatal(err) + } + connPool.Put(ctx, cn) + } + }) + }) + } +} + +// ============================================================================= +// PUBSUB BENCHMARKS +// ============================================================================= + +// benchmarkClient creates a Redis client for benchmarking with mock dialer +func benchmarkClient(poolSize int) *redis.Client { + return redis.NewClient(&redis.Options{ + Addr: "localhost:6379", // Mock address + DialTimeout: time.Second, + ReadTimeout: time.Second, + WriteTimeout: time.Second, + PoolSize: poolSize, + MinIdleConns: 0, // Start with no idle connections for consistent benchmarks + }) +} + +// BenchmarkPubSubCreation benchmarks PubSub creation and subscription +func BenchmarkPubSubCreation(b *testing.B) { + ctx := context.Background() + + poolSizes := []int{1, 4, 8, 16, 32} + + for _, poolSize := range poolSizes { + b.Run(fmt.Sprintf("PoolSize_%d", poolSize), func(b *testing.B) { + client := benchmarkClient(poolSize) + defer client.Close() + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + pubsub := client.Subscribe(ctx, "test-channel") + pubsub.Close() + } + }) + } +} + +// BenchmarkPubSubPatternCreation benchmarks PubSub pattern subscription +func BenchmarkPubSubPatternCreation(b *testing.B) { + ctx := context.Background() + + poolSizes := []int{1, 4, 8, 16, 32} + + for _, poolSize := range poolSizes { + b.Run(fmt.Sprintf("PoolSize_%d", poolSize), func(b *testing.B) { + client := benchmarkClient(poolSize) + defer client.Close() + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + pubsub := client.PSubscribe(ctx, "test-*") + pubsub.Close() + } + }) + } +} + +// BenchmarkPubSubConcurrentCreation benchmarks concurrent PubSub creation +func BenchmarkPubSubConcurrentCreation(b *testing.B) { + ctx := context.Background() + client := benchmarkClient(32) + defer client.Close() + + concurrencyLevels := []int{1, 2, 4, 8, 16} + + for _, 
concurrency := range concurrencyLevels { + b.Run(fmt.Sprintf("Concurrency_%d", concurrency), func(b *testing.B) { + b.ResetTimer() + b.ReportAllocs() + + var wg sync.WaitGroup + semaphore := make(chan struct{}, concurrency) + + for i := 0; i < b.N; i++ { + wg.Add(1) + semaphore <- struct{}{} + + go func() { + defer wg.Done() + defer func() { <-semaphore }() + + pubsub := client.Subscribe(ctx, "test-channel") + pubsub.Close() + }() + } + + wg.Wait() + }) + } +} + +// BenchmarkPubSubMultipleChannels benchmarks subscribing to multiple channels +func BenchmarkPubSubMultipleChannels(b *testing.B) { + ctx := context.Background() + client := benchmarkClient(16) + defer client.Close() + + channelCounts := []int{1, 5, 10, 25, 50, 100} + + for _, channelCount := range channelCounts { + b.Run(fmt.Sprintf("Channels_%d", channelCount), func(b *testing.B) { + // Prepare channel names + channels := make([]string, channelCount) + for i := 0; i < channelCount; i++ { + channels[i] = fmt.Sprintf("channel-%d", i) + } + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + pubsub := client.Subscribe(ctx, channels...) + pubsub.Close() + } + }) + } +} + +// BenchmarkPubSubReuse benchmarks reusing PubSub connections +func BenchmarkPubSubReuse(b *testing.B) { + ctx := context.Background() + client := benchmarkClient(16) + defer client.Close() + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + // Benchmark just the creation and closing of PubSub connections + // This simulates reuse patterns without requiring actual Redis operations + pubsub := client.Subscribe(ctx, fmt.Sprintf("test-channel-%d", i)) + pubsub.Close() + } +} + +// ============================================================================= +// COMBINED BENCHMARKS +// ============================================================================= + +// BenchmarkPoolAndPubSubMixed benchmarks mixed pool stats and PubSub operations +func BenchmarkPoolAndPubSubMixed(b *testing.B) { + ctx := context.Background() + client := benchmarkClient(32) + defer client.Close() + + b.ResetTimer() + b.ReportAllocs() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + // Mix of pool stats collection and PubSub creation + if pb.Next() { + // Pool stats operation + stats := client.PoolStats() + _ = stats.Hits + stats.Misses // Use the stats to prevent optimization + } + + if pb.Next() { + // PubSub operation + pubsub := client.Subscribe(ctx, "test-channel") + pubsub.Close() + } + } + }) +} + +// BenchmarkPoolStatsCollection benchmarks pool statistics collection +func BenchmarkPoolStatsCollection(b *testing.B) { + client := benchmarkClient(16) + defer client.Close() + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + stats := client.PoolStats() + _ = stats.Hits + stats.Misses + stats.Timeouts // Use the stats to prevent optimization + } +} + +// BenchmarkPoolHighContention tests pool performance under high contention +func BenchmarkPoolHighContention(b *testing.B) { + ctx := context.Background() + client := benchmarkClient(32) + defer client.Close() + + b.ResetTimer() + b.ReportAllocs() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + // High contention Get/Put operations + pubsub := client.Subscribe(ctx, "test-channel") + pubsub.Close() + } + }) +} diff --git a/pubsub.go b/pubsub.go index 75327dd2aa..62ec214514 100644 --- a/pubsub.go +++ b/pubsub.go @@ -42,6 +42,9 @@ type PubSub struct { // Push notification processor for handling generic push notifications pushProcessor push.NotificationProcessor 
+ + // Cleanup callback for hitless upgrade tracking + onClose func() } func (c *PubSub) init() { @@ -157,6 +160,11 @@ func (c *PubSub) releaseConn(ctx context.Context, cn *pool.Conn, err error, allo if c.cn != cn { return } + + if !cn.IsUsable() || cn.ShouldHandoff() { + c.reconnect(ctx, fmt.Errorf("pubsub: connection is not usable")) + } + if isBadConn(err, allowTimeout, c.opt.Addr) { c.reconnect(ctx, err) } @@ -189,6 +197,11 @@ func (c *PubSub) Close() error { c.closed = true close(c.exit) + // Call cleanup callback if set + if c.onClose != nil { + c.onClose() + } + return c.closeTheCn(pool.ErrClosed) } diff --git a/push/handler_context.go b/push/handler_context.go index 3bcf128f18..f89f87fa1b 100644 --- a/push/handler_context.go +++ b/push/handler_context.go @@ -1,8 +1,6 @@ package push -import ( - "github.com/redis/go-redis/v9/internal/pool" -) +// No imports needed for this file // NotificationHandlerContext provides context information about where a push notification was received. // This struct allows handlers to make informed decisions based on the source of the notification @@ -35,7 +33,12 @@ type NotificationHandlerContext struct { PubSub interface{} // Conn is the specific connection on which the notification was received. - Conn *pool.Conn + // It is interface to both allow for future expansion and to avoid + // circular dependencies. The developer is responsible for type assertion. + // It can be one of the following types: + // - *pool.Conn + // - *connectionAdapter (for hitless upgrades) + Conn interface{} // IsBlocking indicates if the notification was received on a blocking connection. IsBlocking bool diff --git a/push/processor_unit_test.go b/push/processor_unit_test.go new file mode 100644 index 0000000000..ce7990489f --- /dev/null +++ b/push/processor_unit_test.go @@ -0,0 +1,315 @@ +package push + +import ( + "context" + "testing" +) + +// TestProcessorCreation tests processor creation and initialization +func TestProcessorCreation(t *testing.T) { + t.Run("NewProcessor", func(t *testing.T) { + processor := NewProcessor() + if processor == nil { + t.Fatal("NewProcessor should not return nil") + } + if processor.registry == nil { + t.Error("Processor should have a registry") + } + }) + + t.Run("NewVoidProcessor", func(t *testing.T) { + voidProcessor := NewVoidProcessor() + if voidProcessor == nil { + t.Fatal("NewVoidProcessor should not return nil") + } + }) +} + +// TestProcessorHandlerManagement tests handler registration and retrieval +func TestProcessorHandlerManagement(t *testing.T) { + processor := NewProcessor() + handler := &UnitTestHandler{name: "test-handler"} + + t.Run("RegisterHandler", func(t *testing.T) { + err := processor.RegisterHandler("TEST", handler, false) + if err != nil { + t.Errorf("RegisterHandler should not error: %v", err) + } + + // Verify handler is registered + retrievedHandler := processor.GetHandler("TEST") + if retrievedHandler != handler { + t.Error("GetHandler should return the registered handler") + } + }) + + t.Run("RegisterProtectedHandler", func(t *testing.T) { + protectedHandler := &UnitTestHandler{name: "protected-handler"} + err := processor.RegisterHandler("PROTECTED", protectedHandler, true) + if err != nil { + t.Errorf("RegisterHandler should not error for protected handler: %v", err) + } + + // Verify handler is registered + retrievedHandler := processor.GetHandler("PROTECTED") + if retrievedHandler != protectedHandler { + t.Error("GetHandler should return the protected handler") + } + }) + + t.Run("GetNonExistentHandler", 
func(t *testing.T) { + handler := processor.GetHandler("NONEXISTENT") + if handler != nil { + t.Error("GetHandler should return nil for non-existent handler") + } + }) + + t.Run("UnregisterHandler", func(t *testing.T) { + err := processor.UnregisterHandler("TEST") + if err != nil { + t.Errorf("UnregisterHandler should not error: %v", err) + } + + // Verify handler is removed + retrievedHandler := processor.GetHandler("TEST") + if retrievedHandler != nil { + t.Error("GetHandler should return nil after unregistering") + } + }) + + t.Run("UnregisterProtectedHandler", func(t *testing.T) { + err := processor.UnregisterHandler("PROTECTED") + if err == nil { + t.Error("UnregisterHandler should error for protected handler") + } + + // Verify handler is still there + retrievedHandler := processor.GetHandler("PROTECTED") + if retrievedHandler == nil { + t.Error("Protected handler should not be removed") + } + }) +} + +// TestVoidProcessorBehavior tests void processor behavior +func TestVoidProcessorBehavior(t *testing.T) { + voidProcessor := NewVoidProcessor() + handler := &UnitTestHandler{name: "test-handler"} + + t.Run("GetHandler", func(t *testing.T) { + retrievedHandler := voidProcessor.GetHandler("ANY") + if retrievedHandler != nil { + t.Error("VoidProcessor GetHandler should always return nil") + } + }) + + t.Run("RegisterHandler", func(t *testing.T) { + err := voidProcessor.RegisterHandler("TEST", handler, false) + if err == nil { + t.Error("VoidProcessor RegisterHandler should return error") + } + + // Check error type + if !IsVoidProcessorError(err) { + t.Error("Error should be a VoidProcessorError") + } + }) + + t.Run("UnregisterHandler", func(t *testing.T) { + err := voidProcessor.UnregisterHandler("TEST") + if err == nil { + t.Error("VoidProcessor UnregisterHandler should return error") + } + + // Check error type + if !IsVoidProcessorError(err) { + t.Error("Error should be a VoidProcessorError") + } + }) +} + +// TestProcessPendingNotificationsNilReader tests handling of nil reader +func TestProcessPendingNotificationsNilReader(t *testing.T) { + t.Run("ProcessorWithNilReader", func(t *testing.T) { + processor := NewProcessor() + ctx := context.Background() + handlerCtx := NotificationHandlerContext{} + + err := processor.ProcessPendingNotifications(ctx, handlerCtx, nil) + if err != nil { + t.Errorf("ProcessPendingNotifications should not error with nil reader: %v", err) + } + }) + + t.Run("VoidProcessorWithNilReader", func(t *testing.T) { + voidProcessor := NewVoidProcessor() + ctx := context.Background() + handlerCtx := NotificationHandlerContext{} + + err := voidProcessor.ProcessPendingNotifications(ctx, handlerCtx, nil) + if err != nil { + t.Errorf("VoidProcessor ProcessPendingNotifications should not error with nil reader: %v", err) + } + }) +} + +// TestWillHandleNotificationInClient tests the notification filtering logic +func TestWillHandleNotificationInClient(t *testing.T) { + testCases := []struct { + name string + notificationType string + shouldHandle bool + }{ + // Pub/Sub notifications (should be handled in client) + {"message", "message", true}, + {"pmessage", "pmessage", true}, + {"subscribe", "subscribe", true}, + {"unsubscribe", "unsubscribe", true}, + {"psubscribe", "psubscribe", true}, + {"punsubscribe", "punsubscribe", true}, + {"smessage", "smessage", true}, + {"ssubscribe", "ssubscribe", true}, + {"sunsubscribe", "sunsubscribe", true}, + + // Push notifications (should be handled by processor) + {"MOVING", "MOVING", false}, + {"MIGRATING", "MIGRATING", false}, + 
{"MIGRATED", "MIGRATED", false}, + {"FAILING_OVER", "FAILING_OVER", false}, + {"FAILED_OVER", "FAILED_OVER", false}, + {"custom", "custom", false}, + {"unknown", "unknown", false}, + {"empty", "", false}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := willHandleNotificationInClient(tc.notificationType) + if result != tc.shouldHandle { + t.Errorf("willHandleNotificationInClient(%q) = %v, want %v", tc.notificationType, result, tc.shouldHandle) + } + }) + } +} + +// TestProcessorErrorHandlingUnit tests error handling scenarios +func TestProcessorErrorHandlingUnit(t *testing.T) { + processor := NewProcessor() + + t.Run("RegisterNilHandler", func(t *testing.T) { + err := processor.RegisterHandler("TEST", nil, false) + if err == nil { + t.Error("RegisterHandler should error with nil handler") + } + + // Check error type + if !IsHandlerNilError(err) { + t.Error("Error should be a HandlerNilError") + } + }) + + t.Run("RegisterDuplicateHandler", func(t *testing.T) { + handler1 := &UnitTestHandler{name: "handler1"} + handler2 := &UnitTestHandler{name: "handler2"} + + // Register first handler + err := processor.RegisterHandler("DUPLICATE", handler1, false) + if err != nil { + t.Errorf("First RegisterHandler should not error: %v", err) + } + + // Try to register second handler with same name + err = processor.RegisterHandler("DUPLICATE", handler2, false) + if err == nil { + t.Error("RegisterHandler should error when registering duplicate handler") + } + + // Verify original handler is still there + retrievedHandler := processor.GetHandler("DUPLICATE") + if retrievedHandler != handler1 { + t.Error("Original handler should remain after failed duplicate registration") + } + }) + + t.Run("UnregisterNonExistentHandler", func(t *testing.T) { + err := processor.UnregisterHandler("NONEXISTENT") + if err != nil { + t.Errorf("UnregisterHandler should not error for non-existent handler: %v", err) + } + }) +} + +// TestProcessorConcurrentAccess tests concurrent access to processor +func TestProcessorConcurrentAccess(t *testing.T) { + processor := NewProcessor() + + t.Run("ConcurrentRegisterAndGet", func(t *testing.T) { + done := make(chan bool, 2) + + // Goroutine 1: Register handlers + go func() { + defer func() { done <- true }() + for i := 0; i < 100; i++ { + handler := &UnitTestHandler{name: "concurrent-handler"} + processor.RegisterHandler("CONCURRENT", handler, false) + processor.UnregisterHandler("CONCURRENT") + } + }() + + // Goroutine 2: Get handlers + go func() { + defer func() { done <- true }() + for i := 0; i < 100; i++ { + processor.GetHandler("CONCURRENT") + } + }() + + // Wait for both goroutines to complete + <-done + <-done + }) +} + +// TestProcessorInterfaceCompliance tests interface compliance +func TestProcessorInterfaceCompliance(t *testing.T) { + t.Run("ProcessorImplementsInterface", func(t *testing.T) { + var _ NotificationProcessor = (*Processor)(nil) + }) + + t.Run("VoidProcessorImplementsInterface", func(t *testing.T) { + var _ NotificationProcessor = (*VoidProcessor)(nil) + }) +} + +// UnitTestHandler is a test implementation of NotificationHandler +type UnitTestHandler struct { + name string + lastNotification []interface{} + errorToReturn error + callCount int +} + +func (h *UnitTestHandler) HandlePushNotification(ctx context.Context, handlerCtx NotificationHandlerContext, notification []interface{}) error { + h.callCount++ + h.lastNotification = notification + return h.errorToReturn +} + +// Helper methods for UnitTestHandler +func (h 
*UnitTestHandler) GetCallCount() int { + return h.callCount +} + +func (h *UnitTestHandler) GetLastNotification() []interface{} { + return h.lastNotification +} + +func (h *UnitTestHandler) SetErrorToReturn(err error) { + h.errorToReturn = err +} + +func (h *UnitTestHandler) Reset() { + h.callCount = 0 + h.lastNotification = nil + h.errorToReturn = nil +} diff --git a/redis.go b/redis.go index b3608c5ff8..f37d5d0f9f 100644 --- a/redis.go +++ b/redis.go @@ -10,8 +10,10 @@ import ( "time" "github.com/redis/go-redis/v9/auth" + "github.com/redis/go-redis/v9/hitless" "github.com/redis/go-redis/v9/internal" "github.com/redis/go-redis/v9/internal/hscan" + "github.com/redis/go-redis/v9/internal/interfaces" "github.com/redis/go-redis/v9/internal/pool" "github.com/redis/go-redis/v9/internal/proto" "github.com/redis/go-redis/v9/push" @@ -204,19 +206,35 @@ func (hs *hooksMixin) processTxPipelineHook(ctx context.Context, cmds []Cmder) e //------------------------------------------------------------------------------ type baseClient struct { - opt *Options - connPool pool.Pooler + opt *Options + optLock sync.RWMutex + connPool pool.Pooler + pubsubPool *pool.PubSubPool hooksMixin onClose func() error // hook called when client is closed // Push notification processing pushProcessor push.NotificationProcessor + + // Connection processor for handling connection lifecycle events + connectionProcessor interfaces.ConnectionProcessor + + // Hitless upgrade manager + hitlessManager *hitless.HitlessManager } func (c *baseClient) clone() *baseClient { - clone := *c - return &clone + clone := &baseClient{ + opt: c.opt, + connPool: c.connPool, + pubsubPool: c.pubsubPool, + onClose: c.onClose, + pushProcessor: c.pushProcessor, + connectionProcessor: c.connectionProcessor, + hitlessManager: c.hitlessManager, + } + return clone } func (c *baseClient) withTimeout(timeout time.Duration) *baseClient { @@ -411,6 +429,7 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { } } + var hitlessHandshakeErr error _, err = conn.Pipelined(ctx, func(pipe Pipeliner) error { if c.opt.DB > 0 { pipe.Select(ctx, c.opt.DB) @@ -424,12 +443,48 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { pipe.ClientSetName(ctx, c.opt.ClientName) } + // Enable maintenance notifications if hitless upgrades are configured + if c.opt.HitlessUpgradeConfig.IsEnabled() && c.opt.Protocol == 3 { + hitlessHandshakeErr = pipe.ClientMaintNotifications( + ctx, + true, + c.opt.HitlessUpgradeConfig.EndpointType.String(), + ).Err() + } + return nil }) if err != nil { return fmt.Errorf("failed to initialize connection options: %w", err) } + if hitlessHandshakeErr != nil { + c.optLock.RLock() + // handshake failed + switch c.opt.HitlessUpgradeConfig.Enabled { + case hitless.MaintNotificationsEnabled: + // enabled mode, fail the connection + return fmt.Errorf("failed to enable maintenance notifications: %w", hitlessHandshakeErr) + case hitless.MaintNotificationsAuto: + // auto mode, disable hitless upgrades and continue + c.disableHitlessUpgrades() + c.optLock.RUnlock() + c.optLock.Lock() + c.opt.HitlessUpgradeConfig.Enabled = hitless.MaintNotificationsDisabled + c.optLock.Unlock() + } + } else if c.opt.HitlessUpgradeConfig.IsEnabled() && c.opt.Protocol == 3 { + // handshake was executed successfully + // to make sure that the handshake will be executed on other connections as well if it was successfully + // executed on this connection, we will force the handshake to be executed on all connections + c.optLock.Lock() + 
c.opt.HitlessUpgradeConfig.Enabled = hitless.MaintNotificationsEnabled + c.optLock.Unlock() + } + + cn.SetUsable(true) + cn.Inited = true + if !c.opt.DisableIdentity && !c.opt.DisableIndentity { libName := "" libVer := Version() @@ -450,6 +505,9 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { return c.opt.OnConnect(ctx, conn) } + // Set the connection initialization function for potential reconnections + cn.SetInitConnFunc(c.createInitConnFunc()) + return nil } @@ -593,6 +651,43 @@ func (c *baseClient) context(ctx context.Context) context.Context { return context.Background() } +// createInitConnFunc creates a connection initialization function that can be used for reconnections. +func (c *baseClient) createInitConnFunc() func(context.Context, *pool.Conn) error { + return func(ctx context.Context, cn *pool.Conn) error { + return c.initConn(ctx, cn) + } +} + +func (c *baseClient) enableHitlessUpgrades() error { + // Create hitless manager, used for operation tracking and push notification handling + manager, err := initializeHitlessManager(c, c.opt.HitlessUpgradeConfig) + if err != nil { + return err + } + // Set the manager reference + c.hitlessManager = manager + // Create the connection processor from the hitless manager, used for connection lifecycle events + // based on the connection flags (some set by the hitless manager and its notification handlers) + hitlessConnectionProcessor := manager.CreateConnectionProcessor(c.opt.Protocol, c.dialHook) + // Set reference to the pool for connection removal on handoff failure + hitlessConnectionProcessor.SetPool(c.connPool) + // Set the processor reference to the client + c.connectionProcessor = hitlessConnectionProcessor + return nil +} + +func (c *baseClient) disableHitlessUpgrades() error { + if c.hitlessManager != nil { + c.hitlessManager.Close() + c.hitlessManager = nil + } + if c.connectionProcessor != nil { + c.connectionProcessor.Shutdown(context.Background()) + c.connectionProcessor = nil + } + return nil +} + // Close closes the client, releasing any open resources. 
// // It is rare to Close a Client, as the Client is meant to be @@ -607,6 +702,11 @@ func (c *baseClient) Close() error { if err := c.connPool.Close(); err != nil && firstErr == nil { firstErr = err } + if c.pubsubPool != nil { + if err := c.pubsubPool.Close(); err != nil && firstErr == nil { + firstErr = err + } + } return firstErr } @@ -810,11 +910,32 @@ func NewClient(opt *Options) *Client { // Initialize push notification processor using shared helper // Use void processor for RESP2 connections (push notifications not available) c.pushProcessor = initializePushProcessor(opt) - - // Update options with the initialized push processor for connection pool + // Update options with the initialized push processor opt.PushNotificationProcessor = c.pushProcessor - c.connPool = newConnPool(opt, c.dialHook) + // Initialize hitless upgrades first if enabled to get the connection processor + if opt.HitlessUpgradeConfig.IsEnabled() { + if opt.Protocol != 3 { + internal.Logger.Printf(context.Background(), "hitless: RESP3 protocol required for hitless upgrades, but Protocol is %d", opt.Protocol) + } else { + err := c.enableHitlessUpgrades() + if err != nil { + internal.Logger.Printf(context.Background(), "hitless: failed to initialize hitless upgrades: %v", err) + if opt.HitlessUpgradeConfig.Enabled == hitless.MaintNotificationsEnabled { + // panic so we fail fast without breaking existing clients api + panic(fmt.Errorf("failed to enable hitless upgrades: %w", err)) + } + } + } + } + + // Create connection pool with the processor (nil if hitless upgrades disabled) + c.connPool = newConnPool(opt, c.dialHook, c.connectionProcessor) + + // Create separate PubSub pool + c.pubsubPool = pool.NewPubSubPool(newConnPoolConfig(opt), func(ctx context.Context) (net.Conn, error) { + return c.dialHook(ctx, opt.Network, opt.Addr) + }) return &c } @@ -851,6 +972,12 @@ func (c *Client) Options() *Options { return c.opt } +// GetHitlessManager returns the hitless manager instance for monitoring and control. +// Returns nil if hitless upgrades are not enabled. +func (c *Client) GetHitlessManager() *hitless.HitlessManager { + return c.hitlessManager +} + // initializePushProcessor initializes the push notification processor for any client type. // This is a shared helper to avoid duplication across NewClient, NewFailoverClient, and NewSentinelClient. func initializePushProcessor(opt *Options) push.NotificationProcessor { @@ -887,6 +1014,8 @@ type PoolStats pool.Stats // PoolStats returns connection pool stats. 
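GetHitlessManager above is the read-only entry point for observability: it returns nil when hitless upgrades are disabled, when the protocol is not RESP3, or after the auto-mode handshake has fallen back. A short monitoring sketch (an editorial example that only reuses accessors exercised by the tests later in this PR, GetState and hitless.StateIdle; the helper name is hypothetical):

package main

import (
	"fmt"

	"github.com/redis/go-redis/v9"
	"github.com/redis/go-redis/v9/hitless"
)

// reportHitlessState is a hypothetical helper showing how an application
// could surface the upgrade state exposed by GetHitlessManager.
func reportHitlessState(client *redis.Client) {
	manager := client.GetHitlessManager()
	if manager == nil {
		// Hitless upgrades are disabled, or the handshake fell back in auto mode.
		fmt.Println("hitless upgrades: not active")
		return
	}
	if manager.GetState() == hitless.StateIdle {
		fmt.Println("hitless upgrades: idle, no migration or failover in progress")
		return
	}
	fmt.Printf("hitless upgrades: state=%s\n", manager.GetState())
}

func main() {
	// Hitless upgrades are not configured here, so the helper reports "not active".
	client := redis.NewClient(&redis.Options{Addr: "localhost:6379", Protocol: 3})
	defer client.Close()
	reportHitlessState(client)
}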
func (c *Client) PoolStats() *PoolStats { stats := c.connPool.Stats() + pubsubStats := c.pubsubPool.Stats() + stats.PubSubStats = pubsubStats return (*PoolStats)(stats) } @@ -921,11 +1050,25 @@ func (c *Client) TxPipeline() Pipeliner { func (c *Client) pubSub() *PubSub { pubsub := &PubSub{ opt: c.opt, - newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) { - return c.newConn(ctx) + cn, err := c.pubsubPool.Get(ctx) + if err != nil { + return nil, err + } + + // will return nil if already initialized + err = c.initConn(ctx, cn) + if err != nil { + c.pubsubPool.Remove(ctx, cn, err) + return nil, err + } + + return cn, nil + }, + closeConn: func(cn *pool.Conn) error { + c.pubsubPool.Remove(context.Background(), cn, nil) + return nil }, - closeConn: c.connPool.CloseConn, pushProcessor: c.pushProcessor, } pubsub.init() @@ -1113,6 +1256,19 @@ func (c *baseClient) pushNotificationHandlerContext(cn *pool.Conn) push.Notifica return push.NotificationHandlerContext{ Client: c, ConnPool: c.connPool, - Conn: cn, + Conn: &connectionAdapter{conn: cn}, // Wrap in adapter for easier interface access + } +} + +// initializeHitlessManager initializes hitless upgrade manager for a client. +func initializeHitlessManager(client *baseClient, config *HitlessUpgradeConfig) (*hitless.HitlessManager, error) { + // Create client adapter + clientAdapterInstance := NewClientAdapter(client) + + // Create hitless manager directly + manager, err := hitless.NewHitlessManager(clientAdapterInstance, config) + if err != nil { + return nil, err } + return manager, nil } diff --git a/sentinel.go b/sentinel.go index 2aa61a7e30..00588ac079 100644 --- a/sentinel.go +++ b/sentinel.go @@ -16,8 +16,8 @@ import ( "github.com/redis/go-redis/v9/internal" "github.com/redis/go-redis/v9/internal/pool" "github.com/redis/go-redis/v9/internal/rand" - "github.com/redis/go-redis/v9/push" "github.com/redis/go-redis/v9/internal/util" + "github.com/redis/go-redis/v9/push" ) //------------------------------------------------------------------------------ @@ -139,6 +139,14 @@ type FailoverOptions struct { FailingTimeoutSeconds int UnstableResp3 bool + + // Hitless is not supported for FailoverClients at the moment + // HitlessUpgradeConfig provides custom configuration for hitless upgrades. + // When HitlessUpgradeConfig.Enabled is not "disabled", the client will handle + // upgrade notifications gracefully and manage connection/pool state transitions + // seamlessly. Requires Protocol: 3 (RESP3) for push notifications. + // If nil, hitless upgrades are disabled. 
+ //HitlessUpgradeConfig *HitlessUpgradeConfig } func (opt *FailoverOptions) clientOptions() *Options { @@ -455,8 +463,6 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client { opt.Dialer = masterReplicaDialer(failover) opt.init() - var connPool *pool.ConnPool - rdb := &Client{ baseClient: &baseClient{ opt: opt, @@ -468,15 +474,22 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client { // Use void processor by default for RESP2 connections rdb.pushProcessor = initializePushProcessor(opt) - connPool = newConnPool(opt, rdb.dialHook) - rdb.connPool = connPool + rdb.connPool = newConnPool(opt, rdb.dialHook, rdb.connectionProcessor) + + // Create separate PubSub pool + rdb.pubsubPool = pool.NewPubSubPool(newConnPoolConfig(opt), func(ctx context.Context) (net.Conn, error) { + return rdb.dialHook(ctx, opt.Network, opt.Addr) + }) + rdb.onClose = rdb.wrappedOnClose(failover.Close) failover.mu.Lock() failover.onFailover = func(ctx context.Context, addr string) { - _ = connPool.Filter(func(cn *pool.Conn) bool { - return cn.RemoteAddr().String() != addr - }) + if connPool, ok := rdb.connPool.(*pool.ConnPool); ok { + _ = connPool.Filter(func(cn *pool.Conn) bool { + return cn.RemoteAddr().String() != addr + }) + } } failover.mu.Unlock() @@ -542,7 +555,12 @@ func NewSentinelClient(opt *Options) *SentinelClient { dial: c.baseClient.dial, process: c.baseClient.process, }) - c.connPool = newConnPool(opt, c.dialHook) + c.connPool = newConnPool(opt, c.dialHook, c.connectionProcessor) + + // Create separate PubSub pool + c.pubsubPool = pool.NewPubSubPool(newConnPoolConfig(opt), func(ctx context.Context) (net.Conn, error) { + return c.dialHook(ctx, opt.Network, opt.Addr) + }) return c } @@ -569,11 +587,26 @@ func (c *SentinelClient) Process(ctx context.Context, cmd Cmder) error { func (c *SentinelClient) pubSub() *PubSub { pubsub := &PubSub{ opt: c.opt, - newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) { - return c.newConn(ctx) + cn, err := c.pubsubPool.Get(ctx) + if err != nil { + return nil, err + } + + // will return nil if already initialized + err = c.initConn(ctx, cn) + if err != nil { + c.pubsubPool.Remove(ctx, cn, err) + return nil, err + } + + return cn, nil + }, + closeConn: func(cn *pool.Conn) error { + c.pubsubPool.Remove(context.Background(), cn, nil) + return nil }, - closeConn: c.connPool.CloseConn, + pushProcessor: c.pushProcessor, } pubsub.init() return pubsub diff --git a/timeout_example_test.go b/timeout_example_test.go new file mode 100644 index 0000000000..50a452bdac --- /dev/null +++ b/timeout_example_test.go @@ -0,0 +1,364 @@ +package redis_test + +import ( + "context" + "testing" + "time" + + "github.com/redis/go-redis/v9" + "github.com/redis/go-redis/v9/hitless" + "github.com/redis/go-redis/v9/push" +) + +// mockConnectionAdapter implements ConnectionWithRelaxedTimeout for testing +type mockConnectionAdapter struct { + relaxedReadTimeout time.Duration + relaxedWriteTimeout time.Duration + relaxedDeadline time.Time +} + +func (m *mockConnectionAdapter) SetRelaxedTimeout(readTimeout, writeTimeout time.Duration) { + m.relaxedReadTimeout = readTimeout + m.relaxedWriteTimeout = writeTimeout + m.relaxedDeadline = time.Time{} // No deadline +} + +func (m *mockConnectionAdapter) SetRelaxedTimeoutWithDeadline(readTimeout, writeTimeout time.Duration, deadline time.Time) { + m.relaxedReadTimeout = readTimeout + m.relaxedWriteTimeout = writeTimeout + m.relaxedDeadline = deadline +} + +func (m *mockConnectionAdapter) ClearRelaxedTimeout() { + 
m.relaxedReadTimeout = 0 + m.relaxedWriteTimeout = 0 + m.relaxedDeadline = time.Time{} +} + +func (m *mockConnectionAdapter) HasRelaxedTimeout() bool { + return m.relaxedReadTimeout > 0 || m.relaxedWriteTimeout > 0 +} + +// TestTimeoutAdjustmentDemo demonstrates how timeouts are adjusted during hitless upgrades. +func TestTimeoutAdjustmentDemo(t *testing.T) { + // Create a client with hitless upgrades enabled + client := redis.NewClient(&redis.Options{ + Addr: "localhost:6379", + Protocol: 3, + ReadTimeout: 5 * time.Second, // Original read timeout + WriteTimeout: 3 * time.Second, // Original write timeout + HitlessUpgradeConfig: &redis.HitlessUpgradeConfig{ + Enabled: hitless.MaintNotificationsEnabled, + RelaxedTimeout: 20 * time.Second, // 20s timeout during migrations + HandoffTimeout: 15 * time.Second, + LogLevel: 2, // Info level to see timeout adjustments + }, + }) + defer client.Close() + + // Get the hitless integration + integration := client.GetHitlessManager() + if integration == nil { + t.Skip("Hitless upgrades not available (likely RESP2 or disabled)") + } + + ctx := context.Background() + + // Check initial state and timeouts + t.Logf("Initial state: %s", integration.GetState()) + t.Logf("Initial client read timeout: %v", client.Options().ReadTimeout) + t.Logf("Initial client write timeout: %v", client.Options().WriteTimeout) + + // Simulate a MIGRATING notification by calling the integration directly + // Note: In real usage, this would be called automatically by push notification handlers + t.Log("Simulating MIGRATING notification...") + + // Create a mock connection that implements ConnectionWithRelaxedTimeout + mockConn := &mockConnectionAdapter{ + relaxedReadTimeout: 0, + relaxedWriteTimeout: 0, + } + + // Create a mock handler context (in real usage, this comes from the push notification system) + handlerCtx := push.NotificationHandlerContext{ + Client: client, + ConnPool: nil, // Not needed for this test + Conn: mockConn, // Mock connection that supports relaxed timeouts + } + + // Simulate MIGRATING push notification (no slot information needed) + migratingNotification := []interface{}{"MIGRATING", "node1"} // Simple format without slot + + // Get the MIGRATING handler + handler := client.GetPushNotificationHandler("MIGRATING") + if handler == nil { + t.Fatal("MIGRATING handler not found") + } + + // Handle the MIGRATING notification + err := handler.HandlePushNotification(ctx, handlerCtx, migratingNotification) + if err != nil { + t.Fatalf("Failed to handle MIGRATING notification: %v", err) + } + + // Check state and timeouts after migration starts + t.Logf("State after MIGRATING: %s", integration.GetState()) + t.Logf("Client read timeout during migration: %v", client.Options().ReadTimeout) + t.Logf("Client write timeout during migration: %v", client.Options().WriteTimeout) + + // With per-connection timeouts, global client timeouts should remain unchanged + // Only the specific connection that received the notification gets the relaxed timeout + if client.Options().ReadTimeout != 5*time.Second { + t.Errorf("Expected global read timeout to remain 5s, got %v", client.Options().ReadTimeout) + } + if client.Options().WriteTimeout != 3*time.Second { + t.Errorf("Expected global write timeout to remain 3s, got %v", client.Options().WriteTimeout) + } + + // State should remain idle since we're not doing global state management for migration/failover + if integration.GetState() != hitless.StateIdle { + t.Logf("Note: State is %s (per-connection approach doesn't change global 
state)", integration.GetState()) + } + + // Simulate a MIGRATED notification + t.Log("Simulating MIGRATED notification...") + // Simulate MIGRATED push notification + migratedNotification := []interface{}{"MIGRATED", "1"} // slot 1 completed + + // Get the MIGRATED handler + migratedHandler := client.GetPushNotificationHandler("MIGRATED") + if migratedHandler == nil { + t.Fatal("MIGRATED handler not found") + } + + // Handle the MIGRATED notification + err = migratedHandler.HandlePushNotification(ctx, handlerCtx, migratedNotification) + if err != nil { + t.Fatalf("Failed to handle MIGRATED notification: %v", err) + } + + // Check state and timeouts after migration completes + t.Logf("State after MIGRATED: %s", integration.GetState()) + t.Logf("Client read timeout after migration: %v", client.Options().ReadTimeout) + t.Logf("Client write timeout after migration: %v", client.Options().WriteTimeout) + + // Global client timeouts should still be unchanged (per-connection approach) + if client.Options().ReadTimeout != 5*time.Second { + t.Errorf("Expected global read timeout to remain 5s, got %v", client.Options().ReadTimeout) + } + if client.Options().WriteTimeout != 3*time.Second { + t.Errorf("Expected global write timeout to remain 3s, got %v", client.Options().WriteTimeout) + } + + t.Log("βœ… Timeout adjustment demonstration completed successfully") +} + +// TestTimeoutAdjustmentWithFailover demonstrates timeout adjustment during failover. +func TestTimeoutAdjustmentWithFailover(t *testing.T) { + client := redis.NewClient(&redis.Options{ + Addr: "localhost:6379", + Protocol: 3, + ReadTimeout: 4 * time.Second, + WriteTimeout: 2 * time.Second, + HitlessUpgradeConfig: &redis.HitlessUpgradeConfig{ + Enabled: hitless.MaintNotificationsEnabled, + RelaxedTimeout: 25 * time.Second, // 25s timeout during failover + HandoffTimeout: 15 * time.Second, + LogLevel: 2, + }, + }) + defer client.Close() + + integration := client.GetHitlessManager() + if integration == nil { + t.Skip("Hitless upgrades not available") + } + + ctx := context.Background() + + // Start failover + t.Log("Starting failover simulation...") + + // Create a mock connection that implements ConnectionWithRelaxedTimeout + mockConn := &mockConnectionAdapter{ + relaxedReadTimeout: 0, + relaxedWriteTimeout: 0, + } + + // Create a mock handler context + handlerCtx := push.NotificationHandlerContext{ + Client: client, + ConnPool: nil, // Not needed for this test + Conn: mockConn, // Mock connection that supports relaxed timeouts + } + + // Simulate FAILING_OVER push notification (no slot information needed) + failingOverNotification := []interface{}{"FAILING_OVER", "node2"} // Simple format without slot + + // Get the FAILING_OVER handler + handler := client.GetPushNotificationHandler("FAILING_OVER") + if handler == nil { + t.Fatal("FAILING_OVER handler not found") + } + + // Handle the FAILING_OVER notification + err := handler.HandlePushNotification(ctx, handlerCtx, failingOverNotification) + if err != nil { + t.Fatalf("Failed to handle FAILING_OVER notification: %v", err) + } + + // Check that relaxed timeouts were applied to the connection + // Note: Hitless upgrades apply relaxed timeouts per-connection, not globally to the client + expectedTimeout := 25 * time.Second // RelaxedTimeout + + if !mockConn.HasRelaxedTimeout() { + t.Error("Expected relaxed timeout to be applied to connection") + } + + if mockConn.relaxedReadTimeout != expectedTimeout { + t.Errorf("Expected connection read timeout %v during failover, got %v", expectedTimeout, 
mockConn.relaxedReadTimeout) + } + + if mockConn.relaxedWriteTimeout != expectedTimeout { + t.Errorf("Expected connection write timeout %v during failover, got %v", expectedTimeout, mockConn.relaxedWriteTimeout) + } + + t.Logf("Failover state: %s", integration.GetState()) + t.Logf("Connection relaxed timeouts: read=%v, write=%v", + mockConn.relaxedReadTimeout, mockConn.relaxedWriteTimeout) + + // Complete failover + // Simulate FAILED_OVER push notification + failedOverNotification := []interface{}{"FAILED_OVER", "node2"} // Simple format without slot + + // Get the FAILED_OVER handler + failedOverHandler := client.GetPushNotificationHandler("FAILED_OVER") + if failedOverHandler == nil { + t.Fatal("FAILED_OVER handler not found") + } + + // Handle the FAILED_OVER notification + err = failedOverHandler.HandlePushNotification(ctx, handlerCtx, failedOverNotification) + if err != nil { + t.Fatalf("Failed to handle FAILED_OVER notification: %v", err) + } + + // Verify that relaxed timeouts were cleared from the connection + if mockConn.HasRelaxedTimeout() { + t.Error("Expected relaxed timeout to be cleared from connection after FAILED_OVER") + } + + t.Logf("Final state: %s", integration.GetState()) + t.Logf("Connection timeouts cleared: relaxed=%v", mockConn.HasRelaxedTimeout()) +} + +// TestMultipleOperationsTimeoutManagement demonstrates timeout management with overlapping operations. +func TestMultipleOperationsTimeoutManagement(t *testing.T) { + client := redis.NewClient(&redis.Options{ + Addr: "localhost:6379", + Protocol: 3, + ReadTimeout: 5 * time.Second, + WriteTimeout: 3 * time.Second, + HitlessUpgradeConfig: &redis.HitlessUpgradeConfig{ + Enabled: hitless.MaintNotificationsEnabled, + RelaxedTimeout: 15 * time.Second, // 15s timeout during operations + HandoffTimeout: 15 * time.Second, + LogLevel: 2, + }, + }) + defer client.Close() + + integration := client.GetHitlessManager() + if integration == nil { + t.Skip("Hitless upgrades not available") + } + + ctx := context.Background() + + // Start migration + t.Log("Starting migration...") + + // Create mock connections that implement ConnectionWithRelaxedTimeout + mockConn1 := &mockConnectionAdapter{ + relaxedReadTimeout: 0, + relaxedWriteTimeout: 0, + } + mockConn2 := &mockConnectionAdapter{ + relaxedReadTimeout: 0, + relaxedWriteTimeout: 0, + } + + // Create mock handler contexts for different connections + migrationCtx := push.NotificationHandlerContext{ + Client: client, + ConnPool: nil, // Not needed for this test + Conn: mockConn1, // Mock connection 1 + } + failoverCtx := push.NotificationHandlerContext{ + Client: client, + ConnPool: nil, // Not needed for this test + Conn: mockConn2, // Mock connection 2 + } + + // Simulate MIGRATING push notification (no slot information needed) + migratingNotification := []interface{}{"MIGRATING", "node3"} // Simple format without slot + + // Get the MIGRATING handler + migratingHandler := client.GetPushNotificationHandler("MIGRATING") + if migratingHandler == nil { + t.Fatal("MIGRATING handler not found") + } + + // Handle the MIGRATING notification + err := migratingHandler.HandlePushNotification(ctx, migrationCtx, migratingNotification) + if err != nil { + t.Fatalf("Failed to handle MIGRATING notification: %v", err) + } + + // Check that relaxed timeouts were applied to the first connection + expectedTimeout := 15 * time.Second // RelaxedTimeout + if !mockConn1.HasRelaxedTimeout() { + t.Error("Expected relaxed timeout to be applied to migration connection") + } + if 
mockConn1.relaxedReadTimeout != expectedTimeout { + t.Errorf("Expected migration connection timeout %v, got %v", expectedTimeout, mockConn1.relaxedReadTimeout) + } + + // Start failover while migration is in progress + t.Log("Starting failover while migration is in progress...") + // Simulate FAILING_OVER push notification + failingOverNotification2 := []interface{}{"FAILING_OVER", "node4"} // Simple format without slot + + // Get the FAILING_OVER handler + failingOverHandler2 := client.GetPushNotificationHandler("FAILING_OVER") + if failingOverHandler2 == nil { + t.Fatal("FAILING_OVER handler not found") + } + + // Handle the FAILING_OVER notification + err = failingOverHandler2.HandlePushNotification(ctx, failoverCtx, failingOverNotification2) + if err != nil { + t.Fatalf("Failed to handle FAILING_OVER notification: %v", err) + } + + // Check that relaxed timeouts were applied to the second connection too + if !mockConn2.HasRelaxedTimeout() { + t.Error("Expected relaxed timeout to be applied to failover connection") + } + if mockConn2.relaxedReadTimeout != expectedTimeout { + t.Errorf("Expected failover connection timeout %v, got %v", expectedTimeout, mockConn2.relaxedReadTimeout) + } + + // Note: In the current implementation, timeout management is handled + // per-connection by the hitless system. Each connection that receives + // a MIGRATING/FAILING_OVER notification gets relaxed timeouts applied. + + // Verify that hitless manager is tracking operations + state := integration.GetCurrentState() + t.Logf("Current hitless state: %v", state) + + t.Logf("Connection 1 (migration) has relaxed timeout: %v", mockConn1.HasRelaxedTimeout()) + t.Logf("Connection 2 (failover) has relaxed timeout: %v", mockConn2.HasRelaxedTimeout()) + + t.Log("✅ Multiple operations timeout management completed successfully") +}
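The three demos above repeat the same pattern: look up the handler with GetPushNotificationHandler, feed it a simulated notification, and assert that the mock connection received relaxed timeouts while the client-level timeouts stay untouched. A table-driven variant of that pattern (a sketch that assumes it sits next to the tests above in timeout_example_test.go, since it reuses mockConnectionAdapter, the existing imports, and the same config values):

func TestRelaxedTimeoutAppliedPerNotification(t *testing.T) {
	client := redis.NewClient(&redis.Options{
		Addr:     "localhost:6379",
		Protocol: 3,
		HitlessUpgradeConfig: &redis.HitlessUpgradeConfig{
			Enabled:        hitless.MaintNotificationsEnabled,
			RelaxedTimeout: 20 * time.Second,
			HandoffTimeout: 15 * time.Second,
		},
	})
	defer client.Close()

	if client.GetHitlessManager() == nil {
		t.Skip("Hitless upgrades not available")
	}

	ctx := context.Background()
	cases := []struct {
		name         string
		notification []interface{}
	}{
		{"MIGRATING", []interface{}{"MIGRATING", "node1"}},
		{"FAILING_OVER", []interface{}{"FAILING_OVER", "node1"}},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			mockConn := &mockConnectionAdapter{}
			handlerCtx := push.NotificationHandlerContext{
				Client: client,
				Conn:   mockConn, // per-connection relaxed timeouts land here
			}

			handler := client.GetPushNotificationHandler(tc.name)
			if handler == nil {
				t.Fatalf("%s handler not found", tc.name)
			}
			if err := handler.HandlePushNotification(ctx, handlerCtx, tc.notification); err != nil {
				t.Fatalf("failed to handle %s notification: %v", tc.name, err)
			}
			if !mockConn.HasRelaxedTimeout() {
				t.Errorf("expected relaxed timeout to be applied after %s", tc.name)
			}
		})
	}
}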
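One caller-facing note on the pubSub changes in redis.go and sentinel.go earlier in this diff: subscriber connections are now drawn from a dedicated PubSub pool and run through initConn when fetched, but the public API is unchanged, so existing subscription code keeps working as-is. A minimal sketch using the standard go-redis PubSub API (the address is an assumption):

package main

import (
	"context"
	"fmt"

	"github.com/redis/go-redis/v9"
)

func main() {
	ctx := context.Background()
	client := redis.NewClient(&redis.Options{Addr: "localhost:6379"})
	defer client.Close()

	// Subscribe transparently obtains its connection from the PubSub pool.
	pubsub := client.Subscribe(ctx, "news")
	defer pubsub.Close()

	msg, err := pubsub.ReceiveMessage(ctx)
	if err != nil {
		fmt.Println("receive failed:", err)
		return
	}
	fmt.Println(msg.Channel, msg.Payload)
}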