|
| 1 | +package hitless |
| 2 | + |
| 3 | +import ( |
| 4 | + "context" |
| 5 | + "sync" |
| 6 | + "sync/atomic" |
| 7 | + "time" |
| 8 | + |
| 9 | + "github.com/redis/go-redis/v9/internal" |
| 10 | +) |
| 11 | + |
| 12 | +// CircuitBreakerState represents the state of a circuit breaker |
| 13 | +type CircuitBreakerState int32 |
| 14 | + |
| 15 | +const ( |
| 16 | + // CircuitBreakerClosed - normal operation, requests allowed |
| 17 | + CircuitBreakerClosed CircuitBreakerState = iota |
| 18 | + // CircuitBreakerOpen - failing fast, requests rejected |
| 19 | + CircuitBreakerOpen |
| 20 | + // CircuitBreakerHalfOpen - testing if service recovered |
| 21 | + CircuitBreakerHalfOpen |
| 22 | +) |
| 23 | + |
| 24 | +func (s CircuitBreakerState) String() string { |
| 25 | + switch s { |
| 26 | + case CircuitBreakerClosed: |
| 27 | + return "closed" |
| 28 | + case CircuitBreakerOpen: |
| 29 | + return "open" |
| 30 | + case CircuitBreakerHalfOpen: |
| 31 | + return "half-open" |
| 32 | + default: |
| 33 | + return "unknown" |
| 34 | + } |
| 35 | +} |
| 36 | + |
| 37 | +// CircuitBreaker implements the circuit breaker pattern for endpoint-specific failure handling |
| 38 | +type CircuitBreaker struct { |
| 39 | + // Configuration |
| 40 | + failureThreshold int // Number of failures before opening |
| 41 | + resetTimeout time.Duration // How long to stay open before testing |
| 42 | + maxRequests int // Max requests allowed in half-open state |
| 43 | + |
| 44 | + // State tracking (atomic for lock-free access) |
| 45 | + state atomic.Int32 // CircuitBreakerState |
| 46 | + failures atomic.Int64 // Current failure count |
| 47 | + successes atomic.Int64 // Success count in half-open state |
| 48 | + requests atomic.Int64 // Request count in half-open state |
| 49 | + lastFailureTime atomic.Int64 // Unix timestamp of last failure |
| 50 | + lastSuccessTime atomic.Int64 // Unix timestamp of last success |
| 51 | + |
| 52 | + // Endpoint identification |
| 53 | + endpoint string |
| 54 | + config *Config |
| 55 | +} |
| 56 | + |
| 57 | +// newCircuitBreaker creates a new circuit breaker for an endpoint |
| 58 | +func newCircuitBreaker(endpoint string, config *Config) *CircuitBreaker { |
| 59 | + // Use sensible defaults if not configured |
| 60 | + failureThreshold := 10 |
| 61 | + resetTimeout := 500 * time.Millisecond |
| 62 | + maxRequests := 10 |
| 63 | + |
| 64 | + // These could be added to Config in the future without breaking API |
| 65 | + // For now, use internal defaults that work well |
| 66 | + |
| 67 | + return &CircuitBreaker{ |
| 68 | + failureThreshold: failureThreshold, |
| 69 | + resetTimeout: resetTimeout, |
| 70 | + maxRequests: maxRequests, |
| 71 | + endpoint: endpoint, |
| 72 | + config: config, |
| 73 | + state: atomic.Int32{}, // Defaults to CircuitBreakerClosed (0) |
| 74 | + } |
| 75 | +} |
| 76 | + |
| 77 | +// IsOpen returns true if the circuit breaker is open (rejecting requests) |
| 78 | +func (cb *CircuitBreaker) IsOpen() bool { |
| 79 | + state := CircuitBreakerState(cb.state.Load()) |
| 80 | + |
| 81 | + if state == CircuitBreakerOpen { |
| 82 | + // Check if we should transition to half-open |
| 83 | + if cb.shouldAttemptReset() { |
| 84 | + if cb.state.CompareAndSwap(int32(CircuitBreakerOpen), int32(CircuitBreakerHalfOpen)) { |
| 85 | + cb.requests.Store(0) |
| 86 | + cb.successes.Store(0) |
| 87 | + if cb.config != nil && cb.config.LogLevel.InfoOrAbove() { |
| 88 | + internal.Logger.Printf(context.Background(), |
| 89 | + "hitless: circuit breaker for %s transitioning to half-open", cb.endpoint) |
| 90 | + } |
| 91 | + return false // Now in half-open state, allow requests |
| 92 | + } |
| 93 | + } |
| 94 | + return true // Still open |
| 95 | + } |
| 96 | + |
| 97 | + return false |
| 98 | +} |
| 99 | + |
| 100 | +// shouldAttemptReset checks if enough time has passed to attempt reset |
| 101 | +func (cb *CircuitBreaker) shouldAttemptReset() bool { |
| 102 | + lastFailure := time.Unix(cb.lastFailureTime.Load(), 0) |
| 103 | + return time.Since(lastFailure) >= cb.resetTimeout |
| 104 | +} |
| 105 | + |
| 106 | +// Execute runs the given function with circuit breaker protection |
| 107 | +func (cb *CircuitBreaker) Execute(fn func() error) error { |
| 108 | + // Fast path: if circuit is open, fail immediately |
| 109 | + if cb.IsOpen() { |
| 110 | + return ErrCircuitBreakerOpen |
| 111 | + } |
| 112 | + |
| 113 | + state := CircuitBreakerState(cb.state.Load()) |
| 114 | + |
| 115 | + // In half-open state, limit the number of requests |
| 116 | + if state == CircuitBreakerHalfOpen { |
| 117 | + requests := cb.requests.Add(1) |
| 118 | + if requests > int64(cb.maxRequests) { |
| 119 | + cb.requests.Add(-1) // Revert the increment |
| 120 | + return ErrCircuitBreakerOpen |
| 121 | + } |
| 122 | + } |
| 123 | + |
| 124 | + // Execute the function |
| 125 | + err := fn() |
| 126 | + |
| 127 | + if err != nil { |
| 128 | + cb.recordFailure() |
| 129 | + return err |
| 130 | + } |
| 131 | + |
| 132 | + cb.recordSuccess() |
| 133 | + return nil |
| 134 | +} |
| 135 | + |
| 136 | +// recordFailure records a failure and potentially opens the circuit |
| 137 | +func (cb *CircuitBreaker) recordFailure() { |
| 138 | + cb.lastFailureTime.Store(time.Now().Unix()) |
| 139 | + failures := cb.failures.Add(1) |
| 140 | + |
| 141 | + state := CircuitBreakerState(cb.state.Load()) |
| 142 | + |
| 143 | + switch state { |
| 144 | + case CircuitBreakerClosed: |
| 145 | + if failures >= int64(cb.failureThreshold) { |
| 146 | + if cb.state.CompareAndSwap(int32(CircuitBreakerClosed), int32(CircuitBreakerOpen)) { |
| 147 | + if cb.config != nil && cb.config.LogLevel.WarnOrAbove() { |
| 148 | + internal.Logger.Printf(context.Background(), |
| 149 | + "hitless: circuit breaker opened for endpoint %s after %d failures", |
| 150 | + cb.endpoint, failures) |
| 151 | + } |
| 152 | + } |
| 153 | + } |
| 154 | + case CircuitBreakerHalfOpen: |
| 155 | + // Any failure in half-open state immediately opens the circuit |
| 156 | + if cb.state.CompareAndSwap(int32(CircuitBreakerHalfOpen), int32(CircuitBreakerOpen)) { |
| 157 | + if cb.config != nil && cb.config.LogLevel.WarnOrAbove() { |
| 158 | + internal.Logger.Printf(context.Background(), |
| 159 | + "hitless: circuit breaker reopened for endpoint %s due to failure in half-open state", |
| 160 | + cb.endpoint) |
| 161 | + } |
| 162 | + } |
| 163 | + } |
| 164 | +} |
| 165 | + |
| 166 | +// recordSuccess records a success and potentially closes the circuit |
| 167 | +func (cb *CircuitBreaker) recordSuccess() { |
| 168 | + cb.lastSuccessTime.Store(time.Now().Unix()) |
| 169 | + |
| 170 | + state := CircuitBreakerState(cb.state.Load()) |
| 171 | + |
| 172 | + if state == CircuitBreakerClosed { |
| 173 | + // Reset failure count on success in closed state |
| 174 | + cb.failures.Store(0) |
| 175 | + } else if state == CircuitBreakerHalfOpen { |
| 176 | + successes := cb.successes.Add(1) |
| 177 | + |
| 178 | + // If we've had enough successful requests, close the circuit |
| 179 | + if successes >= int64(cb.maxRequests) { |
| 180 | + if cb.state.CompareAndSwap(int32(CircuitBreakerHalfOpen), int32(CircuitBreakerClosed)) { |
| 181 | + cb.failures.Store(0) |
| 182 | + if cb.config != nil && cb.config.LogLevel.InfoOrAbove() { |
| 183 | + internal.Logger.Printf(context.Background(), |
| 184 | + "hitless: circuit breaker closed for endpoint %s after %d successful requests", |
| 185 | + cb.endpoint, successes) |
| 186 | + } |
| 187 | + } |
| 188 | + } |
| 189 | + } |
| 190 | +} |
| 191 | + |
| 192 | +// GetState returns the current state of the circuit breaker |
| 193 | +func (cb *CircuitBreaker) GetState() CircuitBreakerState { |
| 194 | + return CircuitBreakerState(cb.state.Load()) |
| 195 | +} |
| 196 | + |
| 197 | +// GetStats returns current statistics for monitoring |
| 198 | +func (cb *CircuitBreaker) GetStats() CircuitBreakerStats { |
| 199 | + return CircuitBreakerStats{ |
| 200 | + Endpoint: cb.endpoint, |
| 201 | + State: cb.GetState(), |
| 202 | + Failures: cb.failures.Load(), |
| 203 | + Successes: cb.successes.Load(), |
| 204 | + Requests: cb.requests.Load(), |
| 205 | + LastFailureTime: time.Unix(cb.lastFailureTime.Load(), 0), |
| 206 | + LastSuccessTime: time.Unix(cb.lastSuccessTime.Load(), 0), |
| 207 | + } |
| 208 | +} |
| 209 | + |
| 210 | +// CircuitBreakerStats provides statistics about a circuit breaker |
| 211 | +type CircuitBreakerStats struct { |
| 212 | + Endpoint string |
| 213 | + State CircuitBreakerState |
| 214 | + Failures int64 |
| 215 | + Successes int64 |
| 216 | + Requests int64 |
| 217 | + LastFailureTime time.Time |
| 218 | + LastSuccessTime time.Time |
| 219 | +} |
| 220 | + |
| 221 | +// CircuitBreakerManager manages circuit breakers for multiple endpoints |
| 222 | +type CircuitBreakerManager struct { |
| 223 | + breakers sync.Map // map[string]*CircuitBreaker |
| 224 | + config *Config |
| 225 | +} |
| 226 | + |
| 227 | +// newCircuitBreakerManager creates a new circuit breaker manager |
| 228 | +func newCircuitBreakerManager(config *Config) *CircuitBreakerManager { |
| 229 | + return &CircuitBreakerManager{ |
| 230 | + config: config, |
| 231 | + } |
| 232 | +} |
| 233 | + |
| 234 | +// GetCircuitBreaker returns the circuit breaker for an endpoint, creating it if necessary |
| 235 | +func (cbm *CircuitBreakerManager) GetCircuitBreaker(endpoint string) *CircuitBreaker { |
| 236 | + if breaker, ok := cbm.breakers.Load(endpoint); ok { |
| 237 | + return breaker.(*CircuitBreaker) |
| 238 | + } |
| 239 | + |
| 240 | + // Create new circuit breaker |
| 241 | + newBreaker := newCircuitBreaker(endpoint, cbm.config) |
| 242 | + actual, _ := cbm.breakers.LoadOrStore(endpoint, newBreaker) |
| 243 | + return actual.(*CircuitBreaker) |
| 244 | +} |
| 245 | + |
| 246 | +// GetAllStats returns statistics for all circuit breakers |
| 247 | +func (cbm *CircuitBreakerManager) GetAllStats() []CircuitBreakerStats { |
| 248 | + var stats []CircuitBreakerStats |
| 249 | + cbm.breakers.Range(func(key, value interface{}) bool { |
| 250 | + breaker := value.(*CircuitBreaker) |
| 251 | + stats = append(stats, breaker.GetStats()) |
| 252 | + return true |
| 253 | + }) |
| 254 | + return stats |
| 255 | +} |
| 256 | + |
| 257 | +// Reset resets all circuit breakers (useful for testing) |
| 258 | +func (cbm *CircuitBreakerManager) Reset() { |
| 259 | + cbm.breakers.Range(func(key, value interface{}) bool { |
| 260 | + breaker := value.(*CircuitBreaker) |
| 261 | + breaker.state.Store(int32(CircuitBreakerClosed)) |
| 262 | + breaker.failures.Store(0) |
| 263 | + breaker.successes.Store(0) |
| 264 | + breaker.requests.Store(0) |
| 265 | + breaker.lastFailureTime.Store(0) |
| 266 | + breaker.lastSuccessTime.Store(0) |
| 267 | + return true |
| 268 | + }) |
| 269 | +} |
0 commit comments