diff --git a/Sources/Realtime/Auth/AuthTokenManager.swift b/Sources/Realtime/Auth/AuthTokenManager.swift new file mode 100644 index 000000000..a80039290 --- /dev/null +++ b/Sources/Realtime/Auth/AuthTokenManager.swift @@ -0,0 +1,87 @@ +// +// AuthTokenManager.swift +// Realtime +// +// Created on 17/01/25. +// + +import Foundation + +/// Manages authentication token lifecycle and distribution. +/// +/// This actor provides a single source of truth for the current authentication token, +/// handling both direct token assignment and token provider callbacks. +actor AuthTokenManager { + // MARK: - Properties + + private var currentToken: String? + private let tokenProvider: (@Sendable () async throws -> String?)? + + // MARK: - Initialization + + init( + initialToken: String?, + tokenProvider: (@Sendable () async throws -> String?)? + ) { + self.currentToken = initialToken + self.tokenProvider = tokenProvider + } + + // MARK: - Public API + + /// Get current token, calling provider if needed. + /// + /// If no current token is set, this will attempt to fetch from the token provider. + /// + /// - Returns: The current authentication token, or nil if unavailable + func getCurrentToken() async -> String? { + // Return current token if available + if let token = currentToken { + return token + } + + // Try to get from provider + if let provider = tokenProvider { + let token = try? await provider() + currentToken = token + return token + } + + return nil + } + + /// Update token and return if it changed. + /// + /// - Parameter token: The new token to set, or nil to clear + /// - Returns: True if the token changed, false if it's the same + func updateToken(_ token: String?) async -> Bool { + guard token != currentToken else { + return false + } + + currentToken = token + return true + } + + /// Refresh token from provider if available. + /// + /// This forces a call to the token provider even if a current token exists. + /// + /// - Returns: The refreshed token, or current token if no provider + func refreshToken() async -> String? { + guard let provider = tokenProvider else { + return currentToken + } + + let token = try? await provider() + currentToken = token + return token + } + + /// Get the current token without calling the provider. + /// + /// - Returns: The currently stored token, or nil + var token: String? { + currentToken + } +} diff --git a/Sources/Realtime/Connection/ConnectionStateMachine.swift b/Sources/Realtime/Connection/ConnectionStateMachine.swift new file mode 100644 index 000000000..42657f9b6 --- /dev/null +++ b/Sources/Realtime/Connection/ConnectionStateMachine.swift @@ -0,0 +1,191 @@ +// +// ConnectionStateMachine.swift +// Realtime +// +// Created on 17/01/25. +// + +import Foundation +import Helpers + +/// Manages WebSocket connection lifecycle with clear state transitions. +/// +/// This actor ensures thread-safe connection management and prevents race conditions +/// by enforcing valid state transitions through Swift's type system. +actor ConnectionStateMachine { + /// Represents the possible states of a WebSocket connection + enum State: Sendable { + case disconnected + case connecting(Task) + case connected(any WebSocket) + case reconnecting(Task, reason: String) + } + + // MARK: - Properties + + private(set) var state: State = .disconnected + + private let transport: WebSocketTransport + private let url: URL + private let headers: [String: String] + private let reconnectDelay: TimeInterval + private let logger: (any SupabaseLogger)? + + // MARK: - Initialization + + init( + transport: @escaping WebSocketTransport, + url: URL, + headers: [String: String], + reconnectDelay: TimeInterval, + logger: (any SupabaseLogger)? + ) { + self.transport = transport + self.url = url + self.headers = headers + self.reconnectDelay = reconnectDelay + self.logger = logger + } + + // MARK: - Public API + + /// Connect to WebSocket. Returns existing connection if already connected. + /// + /// This method is safe to call multiple times - it will reuse an existing connection + /// or wait for an in-progress connection attempt to complete. + /// + /// - Returns: The active WebSocket connection + /// - Throws: Connection errors from the transport layer + func connect() async throws -> any WebSocket { + switch state { + case .connected(let conn): + logger?.debug("Already connected to WebSocket") + return conn + + case .connecting(let task): + logger?.debug("Connection already in progress, waiting...") + try await task.value + // Recursively call to get the connection after task completes + return try await connect() + + case .reconnecting(let task, _): + logger?.debug("Reconnection in progress, waiting...") + try await task.value + return try await connect() + + case .disconnected: + logger?.debug("Initiating new connection") + return try await performConnection() + } + } + + /// Disconnect and clean up resources. + /// + /// - Parameter reason: Optional reason for disconnection + func disconnect(reason: String? = nil) { + switch state { + case .connected(let conn): + logger?.debug("Disconnecting from WebSocket: \(reason ?? "no reason")") + conn.close(code: nil, reason: reason) + state = .disconnected + + case .connecting(let task), .reconnecting(let task, _): + logger?.debug("Cancelling connection attempt: \(reason ?? "no reason")") + task.cancel() + state = .disconnected + + case .disconnected: + logger?.debug("Already disconnected") + } + } + + /// Handle connection error and initiate reconnect. + /// + /// - Parameter error: The error that caused the connection failure + func handleError(_ error: any Error) { + guard case .connected = state else { + logger?.debug("Ignoring error in non-connected state: \(error)") + return + } + + logger?.debug("Connection error, initiating reconnect: \(error.localizedDescription)") + initiateReconnect(reason: "error: \(error.localizedDescription)") + } + + /// Handle connection close. + /// + /// - Parameters: + /// - code: WebSocket close code + /// - reason: WebSocket close reason + func handleClose(code: Int?, reason: String?) { + let closeReason = "code: \(code?.description ?? "none"), reason: \(reason ?? "none")" + logger?.debug("Connection closed: \(closeReason)") + + disconnect(reason: reason) + } + + /// Handle disconnection event and initiate reconnect. + func handleDisconnected() { + guard case .connected = state else { return } + + logger?.debug("Connection disconnected, initiating reconnect") + initiateReconnect(reason: "disconnected") + } + + /// Get current connection if connected, nil otherwise. + var connection: (any WebSocket)? { + if case .connected(let conn) = state { + return conn + } + return nil + } + + /// Check if currently connected. + var isConnected: Bool { + if case .connected = state { + return true + } + return false + } + + // MARK: - Private Helpers + + private func performConnection() async throws -> any WebSocket { + let connectionTask = Task { + let conn = try await transport(url, headers) + state = .connected(conn) + } + + state = .connecting(connectionTask) + + do { + try await connectionTask.value + + // Get the connection that was just set + guard case .connected(let conn) = state else { + throw RealtimeError("Connection succeeded but state is invalid") + } + + return conn + } catch { + state = .disconnected + throw error + } + } + + private func initiateReconnect(reason: String) { + let reconnectTask = Task { + try await Task.sleep(nanoseconds: UInt64(reconnectDelay * 1_000_000_000)) + + if Task.isCancelled { + logger?.debug("Reconnect cancelled") + return + } + + logger?.debug("Attempting to reconnect...") + _ = try await performConnection() + } + + state = .reconnecting(reconnectTask, reason: reason) + } +} diff --git a/Sources/Realtime/Connection/HeartbeatMonitor.swift b/Sources/Realtime/Connection/HeartbeatMonitor.swift new file mode 100644 index 000000000..d75995cac --- /dev/null +++ b/Sources/Realtime/Connection/HeartbeatMonitor.swift @@ -0,0 +1,113 @@ +// +// HeartbeatMonitor.swift +// Realtime +// +// Created on 17/01/25. +// + +import Foundation +import Helpers + +/// Manages heartbeat send/receive cycle with timeout detection. +/// +/// This actor encapsulates all heartbeat logic, ensuring that heartbeats are sent +/// at regular intervals and timeouts are detected when responses aren't received. +actor HeartbeatMonitor { + // MARK: - Properties + + private let interval: TimeInterval + private let refGenerator: @Sendable () -> String + private let sendHeartbeat: @Sendable (String) async -> Void + private let onTimeout: @Sendable () async -> Void + private let logger: (any SupabaseLogger)? + + private var monitorTask: Task? + private var pendingRef: String? + + // MARK: - Initialization + + init( + interval: TimeInterval, + refGenerator: @escaping @Sendable () -> String, + sendHeartbeat: @escaping @Sendable (String) async -> Void, + onTimeout: @escaping @Sendable () async -> Void, + logger: (any SupabaseLogger)? + ) { + self.interval = interval + self.refGenerator = refGenerator + self.sendHeartbeat = sendHeartbeat + self.onTimeout = onTimeout + self.logger = logger + } + + // MARK: - Public API + + /// Start heartbeat monitoring. + /// + /// Sends heartbeats at regular intervals and detects timeouts when responses + /// aren't received before the next interval. + func start() { + stop() // Cancel any existing monitor + + logger?.debug("Starting heartbeat monitor with interval: \(interval)") + + monitorTask = Task { + while !Task.isCancelled { + do { + try await Task.sleep(nanoseconds: UInt64(interval * 1_000_000_000)) + } catch { + // Task cancelled during sleep + break + } + + if Task.isCancelled { break } + + await sendNextHeartbeat() + } + + logger?.debug("Heartbeat monitor stopped") + } + } + + /// Stop heartbeat monitoring. + func stop() { + if monitorTask != nil { + logger?.debug("Stopping heartbeat monitor") + monitorTask?.cancel() + monitorTask = nil + pendingRef = nil + } + } + + /// Called when heartbeat response is received. + /// + /// - Parameter ref: The reference ID from the heartbeat response + func onHeartbeatResponse(ref: String) { + guard let pending = pendingRef, pending == ref else { + logger?.debug("Received heartbeat response with mismatched ref: \(ref)") + return + } + + logger?.debug("Heartbeat acknowledged: \(ref)") + pendingRef = nil + } + + // MARK: - Private Helpers + + private func sendNextHeartbeat() async { + // Check if previous heartbeat was acknowledged + if let pending = pendingRef { + logger?.debug("Heartbeat timeout - previous heartbeat not acknowledged: \(pending)") + pendingRef = nil + await onTimeout() + return + } + + // Send new heartbeat + let ref = refGenerator() + pendingRef = ref + + logger?.debug("Sending heartbeat: \(ref)") + await sendHeartbeat(ref) + } +} diff --git a/Sources/Realtime/Connection/MessageRouter.swift b/Sources/Realtime/Connection/MessageRouter.swift new file mode 100644 index 000000000..8242907a4 --- /dev/null +++ b/Sources/Realtime/Connection/MessageRouter.swift @@ -0,0 +1,94 @@ +// +// MessageRouter.swift +// Realtime +// +// Created on 17/01/25. +// + +import Foundation + +/// Routes incoming messages to appropriate handlers. +/// +/// This actor provides centralized message dispatch, ensuring thread-safe +/// registration and routing of messages to channel and system handlers. +actor MessageRouter { + // MARK: - Type Definitions + + typealias MessageHandler = @Sendable (RealtimeMessageV2) async -> Void + + // MARK: - Properties + + private var channelHandlers: [String: MessageHandler] = [:] + private var systemHandlers: [MessageHandler] = [] + private let logger: (any SupabaseLogger)? + + // MARK: - Initialization + + init(logger: (any SupabaseLogger)?) { + self.logger = logger + } + + // MARK: - Public API + + /// Register handler for a specific channel topic. + /// + /// - Parameters: + /// - topic: The channel topic to handle + /// - handler: The handler to call for messages on this topic + func registerChannel(topic: String, handler: @escaping MessageHandler) { + logger?.debug("Registering message handler for channel: \(topic)") + channelHandlers[topic] = handler + } + + /// Unregister channel handler. + /// + /// - Parameter topic: The channel topic to unregister + func unregisterChannel(topic: String) { + logger?.debug("Unregistering message handler for channel: \(topic)") + channelHandlers[topic] = nil + } + + /// Register system-wide message handler. + /// + /// System handlers are called for every message, regardless of topic. + /// + /// - Parameter handler: The handler to call for all messages + func registerSystemHandler(_ handler: @escaping MessageHandler) { + logger?.debug("Registering system message handler") + systemHandlers.append(handler) + } + + /// Route message to appropriate handlers. + /// + /// This will call all system handlers and the specific channel handler + /// if one is registered for the message's topic. + /// + /// - Parameter message: The message to route + func route(_ message: RealtimeMessageV2) async { + logger?.debug("Routing message - topic: \(message.topic), event: \(message.event)") + + // System handlers always run + for handler in systemHandlers { + await handler(message) + } + + // Route to specific channel if registered + if let handler = channelHandlers[message.topic] { + await handler(message) + } else { + logger?.debug("No handler registered for topic: \(message.topic)") + } + } + + /// Remove all handlers. + func reset() { + logger?.debug("Resetting message router - removing all handlers") + channelHandlers.removeAll() + systemHandlers.removeAll() + } + + /// Get count of registered channel handlers. + var channelCount: Int { + channelHandlers.count + } +} diff --git a/Sources/Realtime/RealtimeChannelV2.swift b/Sources/Realtime/RealtimeChannelV2.swift index a53324c47..72709f1cd 100644 --- a/Sources/Realtime/RealtimeChannelV2.swift +++ b/Sources/Realtime/RealtimeChannelV2.swift @@ -37,6 +37,7 @@ public final class RealtimeChannelV2: Sendable, RealtimeChannelProtocol { var clientChanges: [PostgresJoinConfig] = [] var joinRef: String? var pushes: [String: PushV2] = [:] + var subscribeTask: Task? } @MainActor @@ -92,7 +93,22 @@ public final class RealtimeChannelV2: Sendable, RealtimeChannelProtocol { } /// Subscribes to the channel. + @MainActor public func subscribeWithError() async throws { + if let subscribeTask = mutableState.subscribeTask { + try await subscribeTask.value + return + } + + mutableState.subscribeTask = Task { + defer { self.mutableState.subscribeTask = nil } + try await self.performSubscribeWithRetry() + } + + return try await mutableState.subscribeTask!.value + } + + private func performSubscribeWithRetry() async throws { logger?.debug( "Starting subscription to channel '\(topic)' (attempt 1/\(socket.options.maxRetryAttempts))" ) @@ -138,6 +154,14 @@ public final class RealtimeChannelV2: Sendable, RealtimeChannelProtocol { do { try await _clock.sleep(for: delay) + + // Check if socket is still connected after delay + if socket.status != .connected { + logger?.debug( + "Socket disconnected during retry delay for channel '\(topic)', aborting subscription" + ) + throw CancellationError() + } } catch { // If sleep is cancelled, break out of retry loop logger?.debug("Subscription retry cancelled for channel '\(topic)'") @@ -196,6 +220,12 @@ public final class RealtimeChannelV2: Sendable, RealtimeChannelProtocol { return } await socket.connect() + + // Verify connection succeeded after await + if socket.status != .connected { + logger?.debug("Socket failed to connect, cannot subscribe to channel \(topic)") + return + } } logger?.debug("Subscribing to channel \(topic)") @@ -234,6 +264,9 @@ public final class RealtimeChannelV2: Sendable, RealtimeChannelProtocol { logger?.debug("Unsubscribing from channel \(topic)") await push(ChannelEvent.leave) + + // Wait for server confirmation of unsubscription + _ = await statusChange.first { @Sendable in $0 == .unsubscribed } } @available( diff --git a/Sources/Realtime/RealtimeClientV2.swift b/Sources/Realtime/RealtimeClientV2.swift index f99c5173d..44fdbbcf6 100644 --- a/Sources/Realtime/RealtimeClientV2.swift +++ b/Sources/Realtime/RealtimeClientV2.swift @@ -43,6 +43,7 @@ public final class RealtimeClientV2: Sendable, RealtimeClientProtocol { var messageTask: Task? var connectionTask: Task? + var reconnectTask: Task? var channels: [String: RealtimeChannelV2] = [:] var sendBuffer: [@Sendable () -> Void] = [] @@ -170,7 +171,10 @@ public final class RealtimeClientV2: Sendable, RealtimeClientProtocol { mutableState.withValue { $0.heartbeatTask?.cancel() $0.messageTask?.cancel() + $0.connectionTask?.cancel() + $0.reconnectTask?.cancel() $0.channels = [:] + $0.conn = nil } } @@ -182,53 +186,77 @@ public final class RealtimeClientV2: Sendable, RealtimeClientProtocol { } func connect(reconnect: Bool) async { - if status == .disconnected { - let connectionTask = Task { - if reconnect { - try? await _clock.sleep(for: options.reconnectDelay) + // Check and create connection task atomically to prevent race conditions + let shouldConnect = mutableState.withValue { state -> Bool in + // If already connecting or connected, don't create a new connection task + if status == .connecting || status == .connected { + return false + } - if Task.isCancelled { - options.logger?.debug("Reconnect cancelled, returning") - return - } - } + // If there's already a connection task running, don't create another + if state.connectionTask != nil { + return false + } - if status == .connected { - options.logger?.debug("WebsSocket already connected") + return true + } + + guard shouldConnect else { + // Wait for existing connection to complete + _ = await statusChange.first { @Sendable in $0 == .connected } + return + } + + let connectionTask = Task { + if reconnect { + try? await _clock.sleep(for: options.reconnectDelay) + + if Task.isCancelled { + options.logger?.debug("Reconnect cancelled, returning") return } + } - status = .connecting - - do { - let conn = try await wsTransport( - Self.realtimeWebSocketURL( - baseURL: Self.realtimeBaseURL(url: url), - apikey: options.apikey, - logLevel: options.logLevel - ), - options.headers.dictionary - ) - mutableState.withValue { $0.conn = conn } - onConnected(reconnect: reconnect) - } catch { - onError(error) - } + if status == .connected { + options.logger?.debug("WebsSocket already connected") + return } - mutableState.withValue { - $0.connectionTask = connectionTask + status = .connecting + + do { + let conn = try await wsTransport( + Self.realtimeWebSocketURL( + baseURL: Self.realtimeBaseURL(url: url), + apikey: options.apikey, + logLevel: options.logLevel + ), + options.headers.dictionary + ) + mutableState.withValue { $0.conn = conn } + onConnected(reconnect: reconnect) + } catch { + onError(error) } } + mutableState.withValue { + $0.connectionTask = connectionTask + } + _ = await statusChange.first { @Sendable in $0 == .connected } } private func onConnected(reconnect: Bool) { - status = .connected options.logger?.debug("Connected to realtime WebSocket") + + // Start listeners before setting status to prevent race conditions listenForMessages() startHeartbeating() + + // Now set status to connected + status = .connected + if reconnect { rejoinChannels() } @@ -261,9 +289,14 @@ public final class RealtimeClientV2: Sendable, RealtimeClientProtocol { } private func reconnect(disconnectReason: String? = nil) { - Task { - disconnect(reason: disconnectReason) - await connect(reconnect: true) + // Cancel any existing reconnect task and create a new one + mutableState.withValue { state in + state.reconnectTask?.cancel() + + state.reconnectTask = Task { + disconnect(reason: disconnectReason) + await connect(reconnect: true) + } } } @@ -282,6 +315,7 @@ public final class RealtimeClientV2: Sendable, RealtimeClientProtocol { let realtimeTopic = "realtime:\(topic)" if let channel = $0.channels[realtimeTopic] { + self.options.logger?.debug("Reusing existing channel for topic: \(realtimeTopic)") return channel } @@ -325,7 +359,13 @@ public final class RealtimeClientV2: Sendable, RealtimeClientProtocol { await channel.unsubscribe() } - if channels.isEmpty { + // Atomically remove channel and check if we should disconnect + let shouldDisconnect = mutableState.withValue { state -> Bool in + state.channels[channel.topic] = nil + return state.channels.isEmpty + } + + if shouldDisconnect { options.logger?.debug("No more subscribed channel in socket") disconnect() } @@ -364,49 +404,57 @@ public final class RealtimeClientV2: Sendable, RealtimeClientProtocol { } private func listenForMessages() { - mutableState.withValue { - $0.messageTask?.cancel() - $0.messageTask = Task { [weak self] in - guard let self, let conn = self.conn else { return } - - do { - for await event in conn.events { - if Task.isCancelled { return } - - switch event { - case .binary: - self.options.logger?.error("Unsupported binary event received.") - break - case .text(let text): - let data = Data(text.utf8) - let message = try JSONDecoder().decode(RealtimeMessageV2.self, from: data) - await onMessage(message) - - if Task.isCancelled { - return - } - - case .close(let code, let reason): - onClose(code: code, reason: reason) + // Capture conn inside the lock before creating the task + let conn = mutableState.withValue { state -> (any WebSocket)? in + state.messageTask?.cancel() + return state.conn + } + + guard let conn else { return } + + let messageTask = Task { + do { + for await event in conn.events { + if Task.isCancelled { return } + + switch event { + case .binary: + self.options.logger?.error("Unsupported binary event received.") + break + case .text(let text): + let data = Data(text.utf8) + let message = try JSONDecoder().decode(RealtimeMessageV2.self, from: data) + await onMessage(message) + + if Task.isCancelled { + return } + + case .close(let code, let reason): + onClose(code: code, reason: reason) } - } catch { - onError(error) } + } catch { + onError(error) } } + + mutableState.withValue { + $0.messageTask = messageTask + } } private func startHeartbeating() { - mutableState.withValue { - $0.heartbeatTask?.cancel() - $0.heartbeatTask = Task { [weak self, options] in + mutableState.withValue { state in + state.heartbeatTask?.cancel() + + state.heartbeatTask = Task { [options] in while !Task.isCancelled { try? await _clock.sleep(for: options.heartbeatInterval) if Task.isCancelled { break } - await self?.sendHeartbeat() + await self.sendHeartbeat() } } } @@ -418,22 +466,27 @@ public final class RealtimeClientV2: Sendable, RealtimeClientProtocol { return } - let pendingHeartbeatRef: String? = mutableState.withValue { - if $0.pendingHeartbeatRef != nil { - $0.pendingHeartbeatRef = nil - return nil + // Check if previous heartbeat is still pending (not acknowledged) + let shouldSendHeartbeat = mutableState.withValue { state -> Bool in + if state.pendingHeartbeatRef != nil { + // Previous heartbeat was not acknowledged - this is a timeout + return false } + // No pending heartbeat, we can send a new one let ref = makeRef() - $0.pendingHeartbeatRef = ref - return ref + state.pendingHeartbeatRef = ref + return true } - if let pendingHeartbeatRef { + if shouldSendHeartbeat { + // Get the ref we just set + let heartbeatRef = mutableState.withValue { $0.pendingHeartbeatRef }! + push( RealtimeMessageV2( joinRef: nil, - ref: pendingHeartbeatRef, + ref: heartbeatRef, topic: "phoenix", event: "heartbeat", payload: [:] @@ -442,8 +495,13 @@ public final class RealtimeClientV2: Sendable, RealtimeClientProtocol { heartbeatSubject.yield(.sent) await setAuth() } else { - options.logger?.debug("Heartbeat timeout") + // Timeout: previous heartbeat was never acknowledged + options.logger?.debug("Heartbeat timeout - previous heartbeat not acknowledged") heartbeatSubject.yield(.timeout) + + // Clear the pending ref before reconnecting + mutableState.withValue { $0.pendingHeartbeatRef = nil } + reconnect(disconnectReason: "heartbeat timeout") } } @@ -460,8 +518,15 @@ public final class RealtimeClientV2: Sendable, RealtimeClientProtocol { mutableState.withValue { $0.ref = 0 $0.messageTask?.cancel() + $0.messageTask = nil $0.heartbeatTask?.cancel() + $0.heartbeatTask = nil $0.connectionTask?.cancel() + $0.connectionTask = nil + $0.reconnectTask?.cancel() + $0.reconnectTask = nil + $0.pendingHeartbeatRef = nil + $0.sendBuffer = [] $0.conn = nil } @@ -485,8 +550,8 @@ public final class RealtimeClientV2: Sendable, RealtimeClientProtocol { return } - mutableState.withValue { [token] in - $0.accessToken = token + mutableState.withValue { [tokenToSend] in + $0.accessToken = tokenToSend } for channel in channels.values { @@ -494,7 +559,7 @@ public final class RealtimeClientV2: Sendable, RealtimeClientProtocol { options.logger?.debug("Updating auth token for channel \(channel.topic)") await channel.push( ChannelEvent.accessToken, - payload: ["access_token": token.map { .string($0) } ?? .null] + payload: ["access_token": tokenToSend.map { .string($0) } ?? .null] ) } } diff --git a/Tests/RealtimeTests/AuthTokenManagerTests.swift b/Tests/RealtimeTests/AuthTokenManagerTests.swift new file mode 100644 index 000000000..5c77388f0 --- /dev/null +++ b/Tests/RealtimeTests/AuthTokenManagerTests.swift @@ -0,0 +1,195 @@ +// +// AuthTokenManagerTests.swift +// Realtime Tests +// +// Created on 17/01/25. +// + +import ConcurrencyExtras +import Foundation +import XCTest + +@testable import Realtime + +final class AuthTokenManagerTests: XCTestCase { + var manager: AuthTokenManager! + + override func tearDown() async throws { + manager = nil + try await super.tearDown() + } + + // MARK: - Tests + + func testInitWithToken() async { + manager = AuthTokenManager(initialToken: "initial-token", tokenProvider: nil) + + let token = await manager.getCurrentToken() + + XCTAssertEqual(token, "initial-token") + } + + func testInitWithoutToken() async { + manager = AuthTokenManager(initialToken: nil, tokenProvider: nil) + + let token = await manager.getCurrentToken() + + XCTAssertNil(token) + } + + func testGetCurrentTokenCallsProviderWhenNoToken() async { + let providerCallCount = LockIsolated(0) + + manager = AuthTokenManager( + initialToken: nil, + tokenProvider: { + providerCallCount.withValue { $0 += 1 } + return "provider-token" + } + ) + + let token = await manager.getCurrentToken() + + XCTAssertEqual(token, "provider-token") + XCTAssertEqual(providerCallCount.value, 1) + + // Second call should use cached token, not call provider again + let token2 = await manager.getCurrentToken() + + XCTAssertEqual(token2, "provider-token") + XCTAssertEqual(providerCallCount.value, 1, "Should not call provider again") + } + + func testGetCurrentTokenReturnsInitialTokenWithoutCallingProvider() async { + let providerCallCount = LockIsolated(0) + + manager = AuthTokenManager( + initialToken: "initial-token", + tokenProvider: { + providerCallCount.withValue { $0 += 1 } + return "provider-token" + } + ) + + let token = await manager.getCurrentToken() + + XCTAssertEqual(token, "initial-token") + XCTAssertEqual(providerCallCount.value, 0, "Should not call provider when token exists") + } + + func testUpdateTokenReturnsTrueWhenChanged() async { + manager = AuthTokenManager(initialToken: "old-token", tokenProvider: nil) + + let changed = await manager.updateToken("new-token") + + XCTAssertTrue(changed) + + let token = await manager.getCurrentToken() + XCTAssertEqual(token, "new-token") + } + + func testUpdateTokenReturnsFalseWhenSame() async { + manager = AuthTokenManager(initialToken: "same-token", tokenProvider: nil) + + let changed = await manager.updateToken("same-token") + + XCTAssertFalse(changed) + } + + func testUpdateTokenToNil() async { + manager = AuthTokenManager(initialToken: "some-token", tokenProvider: nil) + + let changed = await manager.updateToken(nil) + + XCTAssertTrue(changed) + + let token = await manager.token + XCTAssertNil(token) + } + + func testRefreshTokenCallsProvider() async { + let providerCallCount = LockIsolated(0) + + manager = AuthTokenManager( + initialToken: "initial-token", + tokenProvider: { + providerCallCount.withValue { + $0 += 1 + return "refreshed-token-\($0)" + } + } + ) + + let token1 = await manager.refreshToken() + + XCTAssertEqual(token1, "refreshed-token-1") + XCTAssertEqual(providerCallCount.value, 1) + + // Refresh again + let token2 = await manager.refreshToken() + + XCTAssertEqual(token2, "refreshed-token-2") + XCTAssertEqual(providerCallCount.value, 2) + } + + func testRefreshTokenWithoutProviderReturnsCurrentToken() async { + manager = AuthTokenManager(initialToken: "current-token", tokenProvider: nil) + + let token = await manager.refreshToken() + + XCTAssertEqual(token, "current-token") + } + + func testRefreshTokenUpdatesInternalToken() async { + manager = AuthTokenManager( + initialToken: "old-token", + tokenProvider: { "new-token" } + ) + + _ = await manager.refreshToken() + + let token = await manager.token + XCTAssertEqual(token, "new-token") + } + + func testProviderThrowingError() async { + manager = AuthTokenManager( + initialToken: nil, + tokenProvider: { + throw NSError(domain: "test", code: 1) + } + ) + + let token = await manager.getCurrentToken() + + XCTAssertNil(token, "Should return nil when provider throws") + } + + func testConcurrentAccess() async { + manager = AuthTokenManager(initialToken: "initial", tokenProvider: nil) + + // Concurrent updates + await withTaskGroup(of: Void.self) { group in + for i in 0..<100 { + group.addTask { + _ = await self.manager.updateToken("token-\(i)") + } + } + + await group.waitForAll() + } + + // Should have some token (race condition, but should not crash) + let token = await manager.token + XCTAssertNotNil(token) + XCTAssertTrue(token!.starts(with: "token-")) + } + + func testTokenPropertyReturnsCurrentValue() async { + manager = AuthTokenManager(initialToken: "test-token", tokenProvider: nil) + + let token = await manager.token + + XCTAssertEqual(token, "test-token") + } +} diff --git a/Tests/RealtimeTests/ConnectionStateMachineTests.swift b/Tests/RealtimeTests/ConnectionStateMachineTests.swift new file mode 100644 index 000000000..5d290f845 --- /dev/null +++ b/Tests/RealtimeTests/ConnectionStateMachineTests.swift @@ -0,0 +1,194 @@ +// +// ConnectionStateMachineTests.swift +// Realtime Tests +// +// Created on 17/01/25. +// + +import Foundation +import XCTest + +@testable import Realtime + +final class ConnectionStateMachineTests: XCTestCase { + var stateMachine: ConnectionStateMachine! + var mockWebSocket: FakeWebSocket! + var connectCallCount = 0 + var lastConnectURL: URL? + var lastConnectHeaders: [String: String]? + + override func setUp() async throws { + try await super.setUp() + connectCallCount = 0 + lastConnectURL = nil + lastConnectHeaders = nil + (mockWebSocket, _) = FakeWebSocket.fakes() + } + + override func tearDown() async throws { + stateMachine = nil + mockWebSocket = nil + try await super.tearDown() + } + + // MARK: - Helper + + func makeStateMachine( + url: URL = URL(string: "ws://localhost")!, + headers: [String: String] = [:], + reconnectDelay: TimeInterval = 0.1 + ) -> ConnectionStateMachine { + ConnectionStateMachine( + transport: { [weak self] url, headers in + self?.connectCallCount += 1 + self?.lastConnectURL = url + self?.lastConnectHeaders = headers + return self!.mockWebSocket + }, + url: url, + headers: headers, + reconnectDelay: reconnectDelay, + logger: nil + ) + } + + // MARK: - Tests + + func testInitialStateIsDisconnected() async { + stateMachine = makeStateMachine() + + let connection = await stateMachine.connection + let isConnected = await stateMachine.isConnected + + XCTAssertNil(connection) + XCTAssertFalse(isConnected) + } + + func testConnectSuccessfully() async throws { + stateMachine = makeStateMachine( + url: URL(string: "ws://example.com")!, + headers: ["Authorization": "Bearer token"] + ) + + let connection = try await stateMachine.connect() + + XCTAssertNotNil(connection) + XCTAssertEqual(connectCallCount, 1) + XCTAssertEqual(lastConnectURL?.absoluteString, "ws://example.com") + XCTAssertEqual(lastConnectHeaders?["Authorization"], "Bearer token") + + let isConnected = await stateMachine.isConnected + XCTAssertTrue(isConnected) + } + + func testMultipleConnectCallsReuseConnection() async throws { + stateMachine = makeStateMachine() + + let connection1 = try await stateMachine.connect() + let connection2 = try await stateMachine.connect() + let connection3 = try await stateMachine.connect() + + XCTAssertEqual(connectCallCount, 1, "Should only connect once") + XCTAssertTrue(connection1 === mockWebSocket) + XCTAssertTrue(connection2 === mockWebSocket) + XCTAssertTrue(connection3 === mockWebSocket) + } + + func testConcurrentConnectCallsCreateSingleConnection() async throws { + stateMachine = makeStateMachine() + + // Simulate concurrent connect calls + async let connection1 = stateMachine.connect() + async let connection2 = stateMachine.connect() + async let connection3 = stateMachine.connect() + + let results = try await [connection1, connection2, connection3] + + XCTAssertEqual(connectCallCount, 1, "Should only connect once despite concurrent calls") + XCTAssertTrue(results.allSatisfy { $0 === mockWebSocket }) + } + + func testDisconnectClosesConnection() async throws { + stateMachine = makeStateMachine() + + _ = try await stateMachine.connect() + XCTAssertFalse(mockWebSocket.isClosed) + + await stateMachine.disconnect(reason: "test disconnect") + + XCTAssertTrue(mockWebSocket.isClosed) + XCTAssertEqual(mockWebSocket.closeReason, "test disconnect") + + let isConnected = await stateMachine.isConnected + XCTAssertFalse(isConnected) + } + + func testDisconnectWhenDisconnectedIsNoop() async { + stateMachine = makeStateMachine() + + // Should not crash + await stateMachine.disconnect() + + let isConnected = await stateMachine.isConnected + XCTAssertFalse(isConnected) + } + + func testHandleErrorTriggersReconnect() async throws { + stateMachine = makeStateMachine(reconnectDelay: 0.05) + + _ = try await stateMachine.connect() + XCTAssertEqual(connectCallCount, 1) + + // Simulate error + await stateMachine.handleError(NSError(domain: "test", code: 1)) + + // Wait for reconnect delay + try await Task.sleep(nanoseconds: 100_000_000) // 0.1 seconds + + // Should have reconnected + XCTAssertEqual(connectCallCount, 2, "Should have reconnected after error") + } + + func testHandleCloseDisconnects() async throws { + stateMachine = makeStateMachine() + + _ = try await stateMachine.connect() + + await stateMachine.handleClose(code: 1000, reason: "normal closure") + + let isConnected = await stateMachine.isConnected + XCTAssertFalse(isConnected) + } + + func testHandleDisconnectedTriggersReconnect() async throws { + stateMachine = makeStateMachine(reconnectDelay: 0.05) + + _ = try await stateMachine.connect() + XCTAssertEqual(connectCallCount, 1) + + await stateMachine.handleDisconnected() + + // Wait for reconnect + try await Task.sleep(nanoseconds: 100_000_000) + + XCTAssertEqual(connectCallCount, 2, "Should have reconnected") + } + + func testDisconnectCancelsReconnection() async throws { + stateMachine = makeStateMachine(reconnectDelay: 0.2) + + _ = try await stateMachine.connect() + + // Trigger reconnection + await stateMachine.handleError(NSError(domain: "test", code: 1)) + + // Immediately disconnect before reconnection completes + await stateMachine.disconnect() + + // Wait longer than reconnect delay + try await Task.sleep(nanoseconds: 300_000_000) + + // Should only have connected once (reconnection was cancelled) + XCTAssertEqual(connectCallCount, 1) + } +} diff --git a/Tests/RealtimeTests/FakeWebSocket.swift b/Tests/RealtimeTests/FakeWebSocket.swift index 357f7ddd5..e7b22ccd8 100644 --- a/Tests/RealtimeTests/FakeWebSocket.swift +++ b/Tests/RealtimeTests/FakeWebSocket.swift @@ -46,6 +46,8 @@ final class FakeWebSocket: WebSocket { s.sentEvents.append(.close(code: code, reason: reason ?? "")) s.isClosed = true + s.closeCode = code + s.closeReason = reason if s.other?.isClosed == false { s.other?._trigger(.close(code: code ?? 1005, reason: reason ?? "")) } diff --git a/Tests/RealtimeTests/HeartbeatMonitorTests.swift b/Tests/RealtimeTests/HeartbeatMonitorTests.swift new file mode 100644 index 000000000..eb7d7f474 --- /dev/null +++ b/Tests/RealtimeTests/HeartbeatMonitorTests.swift @@ -0,0 +1,189 @@ +// +// HeartbeatMonitorTests.swift +// Realtime Tests +// +// Created on 17/01/25. +// + +import Foundation +import XCTest + +@testable import Realtime + +final class HeartbeatMonitorTests: XCTestCase { + var monitor: HeartbeatMonitor! + var sentHeartbeats: [String] = [] + var timeoutCount = 0 + var currentRef = 0 + + override func setUp() async throws { + try await super.setUp() + sentHeartbeats = [] + timeoutCount = 0 + currentRef = 0 + } + + override func tearDown() async throws { + if monitor != nil { + await monitor.stop() + } + monitor = nil + try await super.tearDown() + } + + // MARK: - Helper + + func makeMonitor(interval: TimeInterval = 0.1) -> HeartbeatMonitor { + HeartbeatMonitor( + interval: interval, + refGenerator: { [weak self] in + guard let self else { return "0" } + self.currentRef += 1 + return "\(self.currentRef)" + }, + sendHeartbeat: { [weak self] ref in + self?.sentHeartbeats.append(ref) + }, + onTimeout: { [weak self] in + self?.timeoutCount += 1 + }, + logger: nil + ) + } + + // MARK: - Tests + + func testStartSendsHeartbeatsAtInterval() async throws { + monitor = makeMonitor(interval: 0.05) + + await monitor.start() + + // Wait for a few heartbeats - be generous with timing for CI + try await Task.sleep(nanoseconds: 300_000_000) // 0.3 seconds + + await monitor.stop() + + // Should have sent multiple heartbeats (at least 2 in 0.3s with 0.05s interval) + // Note: Due to Task scheduling delays in CI, we use conservative expectations + // With 0.05s interval, we expect 0.3s / 0.05s = 6 heartbeats ideally, + // but require only 2 to account for scheduling delays + XCTAssertGreaterThanOrEqual(sentHeartbeats.count, 2, "Should send multiple heartbeats") + + // Verify refs increment correctly + for (index, ref) in sentHeartbeats.enumerated() { + XCTAssertEqual(ref, "\(index + 1)", "Refs should increment") + } + } + + func testStopCancelsHeartbeats() async throws { + monitor = makeMonitor(interval: 0.05) + + await monitor.start() + + try await Task.sleep(nanoseconds: 60_000_000) // 0.06 seconds + await monitor.stop() + + let count = sentHeartbeats.count + + // Wait longer + try await Task.sleep(nanoseconds: 100_000_000) + + // Should not have sent more heartbeats after stop + XCTAssertEqual(sentHeartbeats.count, count, "Should not send heartbeats after stop") + } + + func testOnHeartbeatResponseClearsPendingRef() async throws { + monitor = makeMonitor(interval: 0.1) + + await monitor.start() + + // Wait for first heartbeat + try await Task.sleep(nanoseconds: 120_000_000) + + XCTAssertEqual(sentHeartbeats.count, 1) + XCTAssertEqual(timeoutCount, 0) + + // Acknowledge the heartbeat + await monitor.onHeartbeatResponse(ref: "1") + + // Wait for next heartbeat + try await Task.sleep(nanoseconds: 120_000_000) + + await monitor.stop() + + // Should have sent second heartbeat without timeout + XCTAssertEqual(sentHeartbeats.count, 2) + XCTAssertEqual(timeoutCount, 0, "Should not timeout when acknowledged") + } + + func testTimeoutWhenHeartbeatNotAcknowledged() async throws { + monitor = makeMonitor(interval: 0.1) + + await monitor.start() + + // Wait for first heartbeat + try await Task.sleep(nanoseconds: 120_000_000) + + XCTAssertEqual(sentHeartbeats.count, 1) + + // DON'T acknowledge - let it timeout + + // Wait for timeout check + try await Task.sleep(nanoseconds: 120_000_000) + + await monitor.stop() + + // Should have detected timeout and NOT sent second heartbeat + XCTAssertEqual(sentHeartbeats.count, 1, "Should not send new heartbeat on timeout") + XCTAssertEqual(timeoutCount, 1, "Should have detected timeout") + } + + func testMismatchedRefDoesNotClearPending() async throws { + monitor = makeMonitor(interval: 0.1) + + await monitor.start() + + // Wait for first heartbeat + try await Task.sleep(nanoseconds: 120_000_000) + + XCTAssertEqual(sentHeartbeats, ["1"]) + + // Acknowledge with wrong ref + await monitor.onHeartbeatResponse(ref: "999") + + // Wait for next interval + try await Task.sleep(nanoseconds: 120_000_000) + + await monitor.stop() + + // Should timeout because correct ref was not acknowledged + XCTAssertEqual(timeoutCount, 1, "Should timeout with mismatched ref") + } + + func testRestartCreatesNewMonitor() async throws { + monitor = makeMonitor(interval: 0.05) + + await monitor.start() + try await Task.sleep(nanoseconds: 60_000_000) + let firstCount = sentHeartbeats.count + + // Restart + await monitor.start() + + // Old task should be cancelled, new one started + try await Task.sleep(nanoseconds: 60_000_000) + + await monitor.stop() + + // Should have continued sending + XCTAssertGreaterThan(sentHeartbeats.count, firstCount) + } + + func testStopWhenNotStartedIsNoop() async { + monitor = makeMonitor() + + // Should not crash + await monitor.stop() + await monitor.stop() + } +} diff --git a/Tests/RealtimeTests/MessageRouterTests.swift b/Tests/RealtimeTests/MessageRouterTests.swift new file mode 100644 index 000000000..33ae4b8a0 --- /dev/null +++ b/Tests/RealtimeTests/MessageRouterTests.swift @@ -0,0 +1,256 @@ +// +// MessageRouterTests.swift +// Realtime Tests +// +// Created on 17/01/25. +// + +import ConcurrencyExtras +import Foundation +import XCTest + +@testable import Realtime + +final class MessageRouterTests: XCTestCase { + var router: MessageRouter! + var receivedMessages: [RealtimeMessageV2] = [] + + override func setUp() async throws { + try await super.setUp() + router = MessageRouter(logger: nil) + receivedMessages = [] + } + + override func tearDown() async throws { + router = nil + receivedMessages = [] + try await super.tearDown() + } + + // MARK: - Helper + + func makeMessage(topic: String, event: String) -> RealtimeMessageV2 { + RealtimeMessageV2( + joinRef: nil, + ref: nil, + topic: topic, + event: event, + payload: [:] + ) + } + + // MARK: - Tests + + func testRouteToRegisteredChannel() async { + let channelAMessages = LockIsolated([RealtimeMessageV2]()) + let channelBMessages = LockIsolated([RealtimeMessageV2]()) + + await router.registerChannel(topic: "channel-a") { message in + channelAMessages.withValue { $0.append(message) } + } + + await router.registerChannel(topic: "channel-b") { message in + channelBMessages.withValue { $0.append(message) } + } + + let messageA = makeMessage(topic: "channel-a", event: "test") + let messageB = makeMessage(topic: "channel-b", event: "test") + + await router.route(messageA) + await router.route(messageB) + + XCTAssertEqual(channelAMessages.value.count, 1) + XCTAssertEqual(channelAMessages.value.first?.topic, "channel-a") + + XCTAssertEqual(channelBMessages.value.count, 1) + XCTAssertEqual(channelBMessages.value.first?.topic, "channel-b") + } + + func testRouteToUnregisteredChannelDoesNotCrash() async { + let message = makeMessage(topic: "unknown-channel", event: "test") + + // Should not crash + await router.route(message) + } + + func testSystemHandlerReceivesAllMessages() async { + let systemMessages = LockIsolated([RealtimeMessageV2]()) + + await router.registerSystemHandler { message in + systemMessages.withValue { $0.append(message) } + } + + let message1 = makeMessage(topic: "channel-a", event: "event1") + let message2 = makeMessage(topic: "channel-b", event: "event2") + let message3 = makeMessage(topic: "channel-c", event: "event3") + + await router.route(message1) + await router.route(message2) + await router.route(message3) + + XCTAssertEqual(systemMessages.value.count, 3) + XCTAssertEqual(systemMessages.value[0].topic, "channel-a") + XCTAssertEqual(systemMessages.value[1].topic, "channel-b") + XCTAssertEqual(systemMessages.value[2].topic, "channel-c") + } + + func testBothSystemAndChannelHandlersReceiveMessage() async { + let systemMessages = LockIsolated([RealtimeMessageV2]()) + let channelMessages = LockIsolated([RealtimeMessageV2]()) + + await router.registerSystemHandler { message in + systemMessages.withValue { $0.append(message) } + } + + await router.registerChannel(topic: "test-channel") { message in + channelMessages.withValue { $0.append(message) } + } + + let message = makeMessage(topic: "test-channel", event: "test") + await router.route(message) + + XCTAssertEqual(systemMessages.value.count, 1) + XCTAssertEqual(channelMessages.value.count, 1) + } + + func testUnregisterChannelStopsRoutingToIt() async { + let channelMessages = LockIsolated([RealtimeMessageV2]()) + + await router.registerChannel(topic: "test-channel") { message in + channelMessages.withValue { $0.append(message) } + } + + let message1 = makeMessage(topic: "test-channel", event: "test1") + await router.route(message1) + + XCTAssertEqual(channelMessages.value.count, 1) + + // Unregister + await router.unregisterChannel(topic: "test-channel") + + let message2 = makeMessage(topic: "test-channel", event: "test2") + await router.route(message2) + + // Should still be 1 (not routed after unregister) + XCTAssertEqual(channelMessages.value.count, 1) + } + + func testReregisterChannelReplacesHandler() async { + let handler1Messages = LockIsolated([RealtimeMessageV2]()) + let handler2Messages = LockIsolated([RealtimeMessageV2]()) + + await router.registerChannel(topic: "test-channel") { message in + handler1Messages.withValue { $0.append(message) } + } + + let message1 = makeMessage(topic: "test-channel", event: "test1") + await router.route(message1) + + XCTAssertEqual(handler1Messages.value.count, 1) + XCTAssertEqual(handler2Messages.value.count, 0) + + // Re-register with new handler + await router.registerChannel(topic: "test-channel") { message in + handler2Messages.withValue { $0.append(message) } + } + + let message2 = makeMessage(topic: "test-channel", event: "test2") + await router.route(message2) + + // First handler should not receive second message + XCTAssertEqual(handler1Messages.value.count, 1) + // Second handler should receive it + XCTAssertEqual(handler2Messages.value.count, 1) + } + + func testResetRemovesAllHandlers() async { + let channelMessages = LockIsolated([RealtimeMessageV2]()) + let systemMessages = LockIsolated([RealtimeMessageV2]()) + + await router.registerChannel(topic: "channel-a") { message in + channelMessages.withValue { $0.append(message) } + } + + await router.registerSystemHandler { message in + systemMessages.withValue { $0.append(message) } + } + + let message1 = makeMessage(topic: "channel-a", event: "test1") + await router.route(message1) + + XCTAssertEqual(channelMessages.count, 1) + XCTAssertEqual(systemMessages.count, 1) + + // Reset + await router.reset() + + let message2 = makeMessage(topic: "channel-a", event: "test2") + await router.route(message2) + + // No more messages after reset + XCTAssertEqual(channelMessages.value.count, 1) + XCTAssertEqual(systemMessages.value.count, 1) + } + + func testChannelCountReflectsRegistrations() async { + var count = await router.channelCount + XCTAssertEqual(count, 0) + + await router.registerChannel(topic: "channel-a") { _ in } + count = await router.channelCount + XCTAssertEqual(count, 1) + + await router.registerChannel(topic: "channel-b") { _ in } + count = await router.channelCount + XCTAssertEqual(count, 2) + + await router.unregisterChannel(topic: "channel-a") + count = await router.channelCount + XCTAssertEqual(count, 1) + + await router.reset() + count = await router.channelCount + XCTAssertEqual(count, 0) + } + + func testMultipleSystemHandlers() async { + let system1Messages = LockIsolated([RealtimeMessageV2]()) + let system2Messages = LockIsolated([RealtimeMessageV2]()) + + await router.registerSystemHandler { message in + system1Messages.withValue { $0.append(message) } + } + + await router.registerSystemHandler { message in + system2Messages.withValue { $0.append(message) } + } + + let message = makeMessage(topic: "test", event: "test") + await router.route(message) + + XCTAssertEqual(system1Messages.value.count, 1) + XCTAssertEqual(system2Messages.value.count, 1) + } + + func testConcurrentRouting() async { + let receivedCount = LockIsolated(0) + + await router.registerChannel(topic: "test-channel") { _ in + receivedCount.withValue { $0 += 1 } + } + + // Route messages concurrently + await withTaskGroup(of: Void.self) { group in + for i in 0..<100 { + group.addTask { + let message = self.makeMessage(topic: "test-channel", event: "test-\(i)") + await self.router.route(message) + } + } + + await group.waitForAll() + } + + XCTAssertEqual(receivedCount.value, 100, "Should receive all messages") + } +} diff --git a/Tests/RealtimeTests/RealtimeChannelTests.swift b/Tests/RealtimeTests/RealtimeChannelTests.swift index 4fdbaa67d..d8ea58be4 100644 --- a/Tests/RealtimeTests/RealtimeChannelTests.swift +++ b/Tests/RealtimeTests/RealtimeChannelTests.swift @@ -5,6 +5,7 @@ // Created by Guilherme Souza on 09/09/24. // +import ConcurrencyExtras import InlineSnapshotTesting import TestHelpers import XCTest @@ -196,6 +197,58 @@ final class RealtimeChannelTests: XCTestCase { // The subscription is still in progress when we clean up } + func testConcurrentSubscribeRunsSingleOperation() async throws { + let socket = ConcurrentSubscribeRealtimeClient() + let channel = RealtimeChannelV2( + topic: "test-topic", + config: RealtimeChannelConfig( + broadcast: BroadcastJoinConfig(), + presence: PresenceJoinConfig(), + isPrivate: false + ), + socket: socket, + logger: nil + ) + + async let firstSubscribe = channel.subscribeWithError() + async let secondSubscribe = channel.subscribeWithError() + + try await waitForJoin( + on: socket, + expectedCount: 1 + ) + XCTAssertEqual(socket.joinPushCount, 1) + + await channel.onMessage( + RealtimeMessageV2( + joinRef: nil, + ref: nil, + topic: channel.topic, + event: ChannelEvent.system, + payload: ["status": "ok"] + ) + ) + + try await firstSubscribe + try await secondSubscribe + + XCTAssertEqual(socket.joinPushCount, 1) + } + + private func waitForJoin( + on socket: ConcurrentSubscribeRealtimeClient, + expectedCount: Int + ) async throws { + for _ in 0..<50 { + if socket.joinPushCount == expectedCount { + return + } + try await Task.sleep(nanoseconds: 10_000_000) + } + + XCTFail("Timed out waiting for join push") + } + func testHttpSendThrowsWhenAccessTokenIsMissing() async { let httpClient = HTTPClientMock() let (client, _) = FakeWebSocket.fakes() @@ -439,7 +492,9 @@ final class RealtimeChannelTests: XCTestCase { XCTFail("Expected httpSend to throw an error on 503 status") } catch { // Should fall back to localized status text - XCTAssertTrue(error.localizedDescription.contains("503") || error.localizedDescription.contains("unavailable")) + XCTAssertTrue( + error.localizedDescription.contains("503") + || error.localizedDescription.contains("unavailable")) } } } @@ -455,3 +510,50 @@ private struct BroadcastPayload: Decodable { let `private`: Bool } } + +private final class ConcurrentSubscribeRealtimeClient: RealtimeClientProtocol, @unchecked Sendable { + private let _pushedMessages = LockIsolated<[RealtimeMessageV2]>([]) + private let _status = LockIsolated(.connected) + + let options: RealtimeClientOptions + let http: any HTTPClientType + let broadcastURL = URL(string: "https://localhost:54321/realtime/v1/api/broadcast")! + + init() { + self.options = RealtimeClientOptions( + headers: ["apikey": "test-key"], + timeoutInterval: 5.0 + ) + self.http = HTTPClientMock() + } + + var status: RealtimeClientStatus { + _status.value + } + + var pushedMessages: [RealtimeMessageV2] { + _pushedMessages.value + } + + var joinPushCount: Int { + pushedMessages.filter { $0.event == ChannelEvent.join }.count + } + + func connect() async { + _status.setValue(.connected) + } + + func push(_ message: RealtimeMessageV2) { + _pushedMessages.withValue { $0.append(message) } + } + + func _getAccessToken() async -> String? { + nil + } + + func makeRef() -> String { + UUID().uuidString + } + + func _remove(_: any RealtimeChannelProtocol) {} +} diff --git a/Tests/RealtimeTests/RealtimeTests.swift b/Tests/RealtimeTests/RealtimeTests.swift index 8b1f088ae..62543de9d 100644 --- a/Tests/RealtimeTests/RealtimeTests.swift +++ b/Tests/RealtimeTests/RealtimeTests.swift @@ -754,8 +754,8 @@ final class RealtimeTests: XCTestCase { await Task.megaYield() - // Verify that the message task was cancelled - XCTAssertTrue(sut.mutableState.messageTask?.isCancelled ?? false) + // Verify that the message task was cancelled and cleaned up + XCTAssertNil(sut.mutableState.messageTask, "Message task should be nil after disconnect") } func testMultipleReconnectionsHandleTaskLifecycleCorrectly() async { diff --git a/docs/realtime-refactoring-proposal.md b/docs/realtime-refactoring-proposal.md new file mode 100644 index 000000000..97279ca38 --- /dev/null +++ b/docs/realtime-refactoring-proposal.md @@ -0,0 +1,1194 @@ +# Realtime Module Refactoring Proposal + +**Date:** 2025-01-17 +**Author:** AI-assisted analysis +**Status:** Proposed + +## Executive Summary + +This document proposes a comprehensive refactoring of the Realtime module to address maintainability, reliability, and testability concerns. The refactoring uses actor-based state machines and clear separation of concerns to eliminate race conditions and reduce complexity. + +**Key Metrics:** +- Current LOC: ~1,670 (3 main files) +- Proposed reduction: ~40% through better organization +- Estimated effort: 11-17 days +- Risk level: Low-Medium (backward compatible) + +--- + +## Current Architecture Pain Points + +### 1. **God Object Anti-Pattern** + +**RealtimeClientV2** (678 LOC) handles too many responsibilities: +- WebSocket connection management +- Heartbeat logic +- Message routing +- Channel management +- Auth token management +- Reconnection logic +- URL building +- Message buffering + +**RealtimeChannelV2** (777 LOC) also does too much: +- Subscription management with retry logic +- Message handling for multiple event types +- Callback management delegation +- HTTP fallback for broadcasts +- Presence tracking +- Postgres changes filtering +- Push message queuing + +**Impact:** +- Hard to understand code flow +- Difficult to test individual components +- Changes in one area affect unrelated functionality +- High bug density + +### 2. **State Management Issues** + +**Problems:** +- Large mutable state structs with many fields +- Lock contention from single `LockIsolated` wrapping all state +- Difficult to reason about state transitions +- No clear state machine for connection/subscription states + +**Example from RealtimeClientV2.swift:** +```swift +struct MutableState { + var accessToken: String? + var ref = 0 + var pendingHeartbeatRef: String? + var heartbeatTask: Task? + var messageTask: Task? + var connectionTask: Task? + var reconnectTask: Task? + var channels: [String: RealtimeChannelV2] = [:] + var sendBuffer: [@Sendable () -> Void] = [] + var conn: (any WebSocket)? +} +``` + +**Issues:** +- All state locked together (coarse-grained locking) +- No validation of state transitions +- Easy to have invalid combinations (e.g., `connectionTask` + `reconnectTask`) + +### 3. **Tight Coupling** + +**Problems:** +- Channel directly references socket +- Socket directly manages channels +- Circular dependencies make testing hard +- Hard to mock or substitute components + +**Impact:** +- Cannot test components in isolation +- Changes ripple across boundaries +- Difficult to add alternative implementations + +### 4. **Missing Abstractions** + +**Problems:** +- Connection lifecycle scattered across multiple methods +- No clear separation between transport and application logic +- Heartbeat logic mixed with connection management +- Message encoding/decoding inline with business logic + +**Example:** Heartbeat logic is spread across: +- `startHeartbeating()` - starts the task +- `sendHeartbeat()` - sends and checks timeout +- `onMessage()` - clears pending ref +- `disconnect()` - cancels task + +### 5. **Task Management Complexity** + +**Problems:** +- Multiple long-running tasks tracked in mutable state +- Complex cancellation dependencies +- Difficult to test task lifecycle +- Race conditions during task creation/cancellation + +**Recent bugs fixed:** +- Multiple connection tasks created simultaneously +- Reconnect tasks not cancelled +- Message tasks accessing nil connections +- Weak self causing silent failures + +--- + +## Proposed Refactoring + +### **Architecture: Layered + Actor-Based State Machines** + +``` +┌─────────────────────────────────────────────────────────┐ +│ RealtimeClient │ +│ (Public API & Orchestration) │ +└─────────────────┬───────────────────────────────────────┘ + │ + ┌─────────────┴──────────────┬──────────────────────┐ + │ │ │ +┌───▼────────────┐ ┌───────────▼─────────┐ ┌────────▼─────────┐ +│ ConnectionMgr │ │ ChannelRegistry │ │ AuthManager │ +│ (State Machine)│ │ (Channel Lookup) │ │ (Token Mgmt) │ +└───┬────────────┘ └─────────────────────┘ └──────────────────┘ + │ +┌───▼────────────┐ +│ WebSocketConn │ +│ (Transport) │ +└───┬────────────┘ + │ +┌───▼────────────┐ +│ MessageRouter │ +│ (Dispatch) │ +└────────────────┘ + +┌─────────────────────────────────────────────────────────┐ +│ RealtimeChannel │ +│ (Channel-specific Logic) │ +└─────────────────┬───────────────────────────────────────┘ + │ + ┌─────────────┴──────────────┬──────────────────────┐ + │ │ │ +┌───▼────────────┐ ┌───────────▼─────────┐ ┌────────▼─────────┐ +│SubscriptionMgr│ │ CallbackManager │ │ EventHandler │ +│(State Machine) │ │ (Listener Registry)│ │ (Type Routing) │ +└────────────────┘ └─────────────────────┘ └──────────────────┘ +``` + +**Key Principles:** +1. **Single Responsibility** - Each component has one clear purpose +2. **Actor Isolation** - State machines use Swift actors for thread safety +3. **Dependency Injection** - Protocol-based dependencies for testability +4. **Immutable State Transitions** - State machines enforce valid transitions +5. **Clear Boundaries** - Well-defined interfaces between layers + +--- + +## Phase 1: Extract Core Components (Low Risk) + +### 1.1 **ConnectionStateMachine** + +```swift +/// Manages WebSocket connection lifecycle with clear state transitions +actor ConnectionStateMachine { + enum State: Sendable { + case disconnected + case connecting(Task) + case connected(WebSocketConnection) + case reconnecting(Task, reason: String) + } + + private(set) var state: State = .disconnected + private let transport: WebSocketTransport + private let options: ConnectionOptions + + init(transport: WebSocketTransport, options: ConnectionOptions) { + self.transport = transport + self.options = options + } + + /// Connect to WebSocket. Returns existing connection if already connected. + func connect() async throws -> WebSocketConnection { + switch state { + case .connected(let conn): + return conn + case .connecting(let task): + // Wait for existing connection attempt + await task.value + return try await connect() + case .disconnected, .reconnecting: + let task = Task { + let conn = try await transport.connect( + to: options.url, + headers: options.headers + ) + state = .connected(conn) + } + state = .connecting(task) + try await task.value + return try await connect() + } + } + + /// Disconnect and clean up resources + func disconnect(reason: String?) { + switch state { + case .connected(let conn): + conn.close(reason: reason) + state = .disconnected + case .connecting(let task), .reconnecting(let task, _): + task.cancel() + state = .disconnected + case .disconnected: + break + } + } + + /// Handle connection error and initiate reconnect + func handleError(_ error: Error) { + guard case .connected = state else { return } + let task = Task { + try? await Task.sleep(for: options.reconnectDelay) + _ = try? await connect() + } + state = .reconnecting(task, reason: error.localizedDescription) + } + + /// Handle connection close + func handleClose(code: Int?, reason: String?) { + disconnect(reason: reason) + } +} +``` + +**Benefits:** +- ✅ Impossible to create multiple connections simultaneously +- ✅ Clear, validatable state transitions +- ✅ Actor isolation prevents race conditions +- ✅ Easy to test all state transitions +- ✅ Self-documenting code + +**Migration:** +- Replace `connect(reconnect:)` logic in `RealtimeClientV2` +- Eliminates need for `connectionTask` and `reconnectTask` in `MutableState` +- Fixes connection race condition permanently + +--- + +### 1.2 **HeartbeatMonitor** + +```swift +/// Manages heartbeat send/receive cycle with timeout detection +actor HeartbeatMonitor { + private let interval: Duration + private let onTimeout: @Sendable () async -> Void + private let sendHeartbeat: @Sendable (String) async -> Void + + private var monitorTask: Task? + private var pendingRef: String? + private var refGenerator: () -> String + + init( + interval: Duration, + refGenerator: @escaping () -> String, + sendHeartbeat: @escaping @Sendable (String) async -> Void, + onTimeout: @escaping @Sendable () async -> Void + ) { + self.interval = interval + self.refGenerator = refGenerator + self.sendHeartbeat = sendHeartbeat + self.onTimeout = onTimeout + } + + /// Start heartbeat monitoring + func start() { + stop() // Cancel any existing monitor + + monitorTask = Task { + while !Task.isCancelled { + try? await Task.sleep(for: interval) + + if Task.isCancelled { break } + + await sendNextHeartbeat() + } + } + } + + /// Stop heartbeat monitoring + func stop() { + monitorTask?.cancel() + monitorTask = nil + pendingRef = nil + } + + /// Called when heartbeat response is received + func onHeartbeatResponse(ref: String) { + guard pendingRef == ref else { return } + pendingRef = nil + } + + private func sendNextHeartbeat() async { + // Check if previous heartbeat was acknowledged + if pendingRef != nil { + // Timeout: previous heartbeat not acknowledged + pendingRef = nil + await onTimeout() + return + } + + // Send new heartbeat + let ref = refGenerator() + pendingRef = ref + await sendHeartbeat(ref) + } +} +``` + +**Benefits:** +- ✅ All heartbeat logic in one place +- ✅ Clear timeout detection +- ✅ Easy to test timeout scenarios +- ✅ No shared mutable state +- ✅ Simple to verify correctness + +**Migration:** +- Replace `startHeartbeating()` and `sendHeartbeat()` in `RealtimeClientV2` +- Eliminates `heartbeatTask` and `pendingHeartbeatRef` from `MutableState` +- Fixes heartbeat timeout logic permanently + +--- + +### 1.3 **MessageRouter** + +```swift +/// Routes incoming messages to appropriate handlers +actor MessageRouter { + typealias MessageHandler = @Sendable (RealtimeMessageV2) async -> Void + + private var channelHandlers: [String: MessageHandler] = [:] + private var systemHandlers: [MessageHandler] = [] + + /// Register handler for a specific channel topic + func registerChannel(topic: String, handler: @escaping MessageHandler) { + channelHandlers[topic] = handler + } + + /// Unregister channel handler + func unregisterChannel(topic: String) { + channelHandlers[topic] = nil + } + + /// Register system-wide message handler + func registerSystemHandler(_ handler: @escaping MessageHandler) { + systemHandlers.append(handler) + } + + /// Route message to appropriate handlers + func route(_ message: RealtimeMessageV2) async { + // System handlers always run + for handler in systemHandlers { + await handler(message) + } + + // Route to specific channel if registered + if let handler = channelHandlers[message.topic] { + await handler(message) + } + } + + /// Remove all handlers + func reset() { + channelHandlers.removeAll() + systemHandlers.removeAll() + } +} +``` + +**Benefits:** +- ✅ Centralized message dispatch +- ✅ Type-safe routing +- ✅ Easy to add middleware/logging +- ✅ Clear registration/unregistration +- ✅ Simple to test routing logic + +**Migration:** +- Replace `onMessage()` in `RealtimeClientV2` +- Channels register themselves on subscribe +- Clean separation of concerns + +--- + +### 1.4 **AuthTokenManager** + +```swift +/// Manages authentication token lifecycle and distribution +actor AuthTokenManager { + private var currentToken: String? + private let tokenProvider: (@Sendable () async throws -> String?)? + + init( + initialToken: String?, + tokenProvider: (@Sendable () async throws -> String?)? + ) { + self.currentToken = initialToken + self.tokenProvider = tokenProvider + } + + /// Get current token, calling provider if needed + func getCurrentToken() async -> String? { + if let token = currentToken { + return token + } + + // Try to get from provider + if let provider = tokenProvider { + let token = try? await provider() + currentToken = token + return token + } + + return nil + } + + /// Update token and return if it changed + func updateToken(_ token: String?) async -> Bool { + guard token != currentToken else { + return false + } + currentToken = token + return true + } + + /// Refresh token from provider if available + func refreshToken() async -> String? { + guard let provider = tokenProvider else { + return currentToken + } + + let token = try? await provider() + currentToken = token + return token + } +} +``` + +**Benefits:** +- ✅ Single source of truth for auth +- ✅ Handles callback vs direct token correctly +- ✅ Clear token refresh logic +- ✅ Easy to test token scenarios +- ✅ No more token assignment bugs + +**Migration:** +- Replace auth token logic in `RealtimeClientV2` +- Fixes `setAuth()` token assignment bug permanently +- Simplifies token distribution to channels + +--- + +## Phase 2: Refactor Channel Subscription (Medium Risk) + +### 2.1 **SubscriptionStateMachine** + +```swift +/// Manages channel subscription lifecycle with retry logic +actor SubscriptionStateMachine { + enum State: Sendable { + case unsubscribed + case subscribing(attempt: Int, task: Task) + case subscribed(joinRef: String) + case unsubscribing + } + + private(set) var state: State = .unsubscribed + private let maxAttempts: Int + private let timeout: Duration + + init(maxAttempts: Int, timeout: Duration) { + self.maxAttempts = maxAttempts + self.timeout = timeout + } + + /// Subscribe with automatic retry and exponential backoff + func subscribe( + executor: SubscriptionExecutor + ) async throws { + guard case .unsubscribed = state else { + throw RealtimeError("Cannot subscribe in current state: \(state)") + } + + var attempt = 0 + + while attempt < maxAttempts { + attempt += 1 + + let task = Task { + try await withTimeout(interval: timeout) { + try await executor.execute() + } + } + + state = .subscribing(attempt: attempt, task: task) + + do { + try await task.value + // Success - executor should set state to .subscribed + return + } catch is TimeoutError { + if attempt < maxAttempts { + let delay = calculateBackoff(attempt: attempt) + try await Task.sleep(for: delay) + + // Check if still valid to retry + guard case .subscribing = state else { + throw CancellationError() + } + } + } catch { + state = .unsubscribed + throw error + } + } + + state = .unsubscribed + throw RealtimeError.maxRetryAttemptsReached + } + + /// Mark subscription as successful + func markSubscribed(joinRef: String) { + state = .subscribed(joinRef: joinRef) + } + + /// Unsubscribe from channel + func unsubscribe() async { + switch state { + case .subscribed, .subscribing: + state = .unsubscribing + default: + state = .unsubscribed + } + } + + /// Mark unsubscription as complete + func markUnsubscribed() { + state = .unsubscribed + } + + private func calculateBackoff(attempt: Int) -> Duration { + let baseDelay: Double = 1.0 + let maxDelay: Double = 30.0 + let backoffMultiplier: Double = 2.0 + + let exponentialDelay = baseDelay * pow(backoffMultiplier, Double(attempt - 1)) + let cappedDelay = min(exponentialDelay, maxDelay) + + // Add jitter (±25%) + let jitterRange = cappedDelay * 0.25 + let jitter = Double.random(in: -jitterRange...jitterRange) + + return .seconds(max(0.1, cappedDelay + jitter)) + } +} + +/// Protocol for executing subscription logic +protocol SubscriptionExecutor { + func execute() async throws +} +``` + +**Benefits:** +- ✅ All retry logic isolated +- ✅ Cannot be in invalid state +- ✅ Easy to test exponential backoff +- ✅ Clear error handling +- ✅ Composable retry strategies + +**Migration:** +- Replace `subscribeWithError()` in `RealtimeChannelV2` +- Eliminates complex retry state management +- Clearer separation of concerns + +--- + +### 2.2 **EventHandlerRegistry** + +```swift +/// Type-safe event handler registration and dispatch +final class EventHandlerRegistry: Sendable { + private struct Handler: Sendable { + let id: UUID + let callback: @Sendable (Any) -> Void + } + + private let handlers = LockIsolated<[ObjectIdentifier: [Handler]]>([:]) + + /// Register handler for specific event type + func on( + _ eventType: T.Type, + handler: @escaping @Sendable (T) -> Void + ) -> Subscription { + let id = UUID() + let typeId = ObjectIdentifier(T.self) + + let wrappedHandler = Handler(id: id) { value in + if let typedValue = value as? T { + handler(typedValue) + } + } + + handlers.withValue { handlers in + handlers[typeId, default: []].append(wrappedHandler) + } + + return Subscription { [weak handlers] in + handlers?.withValue { handlers in + handlers[typeId]?.removeAll { $0.id == id } + } + } + } + + /// Trigger event to all registered handlers + func trigger(_ event: T) { + let typeId = ObjectIdentifier(T.self) + + let matchingHandlers = handlers.withValue { handlers in + handlers[typeId] ?? [] + } + + for handler in matchingHandlers { + handler.callback(event) + } + } + + /// Remove all handlers + func removeAll() { + handlers.withValue { $0.removeAll() } + } +} + +/// Represents an active subscription that can be cancelled +public struct Subscription: Sendable { + private let cancellation: @Sendable () -> Void + + init(cancellation: @escaping @Sendable () -> Void) { + self.cancellation = cancellation + } + + public func cancel() { + cancellation() + } +} +``` + +**Benefits:** +- ✅ Type-safe event handling +- ✅ Replaces CallbackManager with simpler API +- ✅ Automatic cleanup via Subscription +- ✅ Easy to test event dispatch +- ✅ Composable subscriptions + +**Migration:** +- Replace `CallbackManager` in `RealtimeChannelV2` +- Cleaner API for event handlers +- Better type safety + +--- + +## Phase 3: Improve Testability (Low Risk) + +### 3.1 **Protocol-based Dependencies** + +```swift +/// WebSocket transport abstraction +protocol WebSocketTransport: Sendable { + func connect(to url: URL, headers: [String: String]) async throws -> WebSocketConnection +} + +/// WebSocket connection abstraction +protocol WebSocketConnection: Sendable { + var events: AsyncStream { get } + func send(_ message: String) + func close(code: Int?, reason: String?) +} + +/// Logging abstraction +protocol RealtimeLogger: Sendable { + func debug(_ message: String) + func error(_ message: String) +} + +/// Clock abstraction for testing +protocol Clock: Sendable { + func sleep(for duration: Duration) async throws + func now() -> Date +} + +/// Production clock implementation +struct SystemClock: Clock { + func sleep(for duration: Duration) async throws { + try await Task.sleep(for: duration) + } + + func now() -> Date { + Date() + } +} + +/// Test clock for deterministic timing +final class TestClock: Clock { + private var currentTime = Date() + + func sleep(for duration: Duration) async throws { + currentTime = currentTime.addingTimeInterval(duration.timeInterval) + } + + func now() -> Date { + currentTime + } + + func advance(by duration: Duration) { + currentTime = currentTime.addingTimeInterval(duration.timeInterval) + } +} +``` + +**Benefits:** +- ✅ Easy to mock in tests +- ✅ Dependency injection +- ✅ Platform-agnostic +- ✅ Deterministic testing +- ✅ No need for XCTest tricks + +**Migration:** +- Add protocols for existing dependencies +- Use in new components +- Gradually adopt in existing code + +--- + +## Phase 4: File Organization + +### **Proposed Structure** + +``` +Sources/Realtime/ +├── Client/ +│ ├── RealtimeClient.swift (Public API, ~150 LOC) +│ ├── RealtimeClientOptions.swift (~50 LOC) +│ ├── ConnectionStateMachine.swift (~100 LOC) +│ └── ChannelRegistry.swift (~80 LOC) +│ +├── Channel/ +│ ├── RealtimeChannel.swift (Public API, ~200 LOC) +│ ├── RealtimeChannelConfig.swift (~50 LOC) +│ ├── SubscriptionStateMachine.swift (~120 LOC) +│ └── EventHandlerRegistry.swift (~100 LOC) +│ +├── Connection/ +│ ├── WebSocketConnection.swift (Protocol, ~30 LOC) +│ ├── URLSessionWebSocket.swift (Implementation, ~100 LOC) +│ ├── HeartbeatMonitor.swift (~80 LOC) +│ └── MessageRouter.swift (~60 LOC) +│ +├── Auth/ +│ └── AuthTokenManager.swift (~80 LOC) +│ +├── Messages/ +│ ├── RealtimeMessage.swift (~100 LOC) +│ ├── MessageEncoder.swift (~50 LOC) +│ └── MessageDecoder.swift (~50 LOC) +│ +├── Events/ +│ ├── PostgresAction.swift (Existing) +│ ├── PresenceAction.swift (Existing) +│ └── BroadcastEvent.swift (~50 LOC) +│ +├── Support/ +│ ├── Types.swift (Existing) +│ ├── Errors.swift (Existing) +│ └── Protocols.swift (~100 LOC) +│ +└── Deprecated/ + └── ... (Existing deprecated code) +``` + +**Total estimated LOC: ~1,500** (vs current ~1,670) + +**Benefits:** +- ✅ Logical grouping by responsibility +- ✅ Easy to navigate +- ✅ Clear module boundaries +- ✅ Smaller, focused files +- ✅ Better IntelliSense + +--- + +## Refactored Public API (Maintains Backward Compatibility) + +### **RealtimeClient** + +```swift +public final class RealtimeClient: Sendable { + // Internal actors - complexity hidden + private let connectionMgr: ConnectionStateMachine + private let channelRegistry: ChannelRegistry + private let authMgr: AuthTokenManager + private let router: MessageRouter + private let heartbeat: HeartbeatMonitor + + // Public API - UNCHANGED + public var status: RealtimeClientStatus { ... } + public var statusChange: AsyncStream { ... } + public var heartbeat: AsyncStream { ... } + public var channels: [String: RealtimeChannel] { ... } + + public init(url: URL, options: RealtimeClientOptions) { ... } + + public func connect() async { ... } + public func disconnect(code: Int?, reason: String?) { ... } + + public func channel( + _ topic: String, + options: @Sendable (inout RealtimeChannelConfig) -> Void = { _ in } + ) -> RealtimeChannel { ... } + + public func removeChannel(_ channel: RealtimeChannel) async { ... } + public func removeAllChannels() async { ... } + + public func setAuth(_ token: String?) async { ... } + + public func onStatusChange( + _ listener: @escaping @Sendable (RealtimeClientStatus) -> Void + ) -> RealtimeSubscription { ... } + + public func onHeartbeat( + _ listener: @escaping @Sendable (HeartbeatStatus) -> Void + ) -> RealtimeSubscription { ... } +} +``` + +**Changes:** +- ✅ Internal implementation completely different +- ✅ Public API 100% compatible +- ✅ Better performance +- ✅ More reliable + +### **RealtimeChannel** + +```swift +public final class RealtimeChannel: Sendable { + // Internal state machines - complexity hidden + private let subscriptionMgr: SubscriptionStateMachine + private let eventRegistry: EventHandlerRegistry + private let config: RealtimeChannelConfig + private weak var client: RealtimeClient? + + // Public API - UNCHANGED + public var status: RealtimeChannelStatus { ... } + public var statusChange: AsyncStream { ... } + public let topic: String + + public func subscribe() async { ... } + public func subscribeWithError() async throws { ... } + public func unsubscribe() async { ... } + + public func broadcast(event: String, message: some Codable) async throws { ... } + public func httpSend(event: String, message: some Codable, timeout: TimeInterval?) async throws { ... } + + public func track(_ state: some Codable) async throws { ... } + public func untrack() async { ... } + + public func onPostgresChange( + _: AnyAction.Type, + schema: String = "public", + table: String? = nil, + filter: String? = nil, + callback: @escaping @Sendable (AnyAction) -> Void + ) -> RealtimeSubscription { ... } + + public func onPresenceChange( + _ callback: @escaping @Sendable (any PresenceAction) -> Void + ) -> RealtimeSubscription { ... } + + public func onBroadcast( + event: String, + callback: @escaping @Sendable (JSONObject) -> Void + ) -> RealtimeSubscription { ... } + + public func onStatusChange( + _ listener: @escaping @Sendable (RealtimeChannelStatus) -> Void + ) -> RealtimeSubscription { ... } +} +``` + +**Changes:** +- ✅ Internal implementation refactored +- ✅ Public API 100% compatible +- ✅ More reliable subscription +- ✅ Better error handling + +--- + +## Migration Strategy + +### **Step 1: Create New Components (Non-Breaking)** +**Duration:** 3-5 days + +- ✅ Add new actor-based components alongside existing code +- ✅ Write comprehensive unit tests for each component +- ✅ Keep all existing public APIs unchanged +- ✅ No behavior changes +- ✅ Add feature flags if needed + +**Deliverables:** +- `ConnectionStateMachine` with tests +- `HeartbeatMonitor` with tests +- `AuthTokenManager` with tests +- `MessageRouter` with tests + +### **Step 2: Gradual Internal Migration** +**Duration:** 5-7 days + +- ✅ Replace internal usage incrementally +- ✅ One component at a time +- ✅ Extensive testing at each step +- ✅ Performance benchmarks +- ✅ Manual testing on example apps + +**Order:** +1. Migrate `AuthTokenManager` (lowest risk) +2. Migrate `MessageRouter` (low risk) +3. Migrate `HeartbeatMonitor` (medium risk) +4. Migrate `ConnectionStateMachine` (medium risk) +5. Migrate channel subscription logic (higher risk) + +**Testing at each step:** +- Unit tests pass +- Integration tests pass +- Example apps work +- Performance benchmarks green +- Manual testing on iOS/macOS/etc + +### **Step 3: Deprecate Old Internals** +**Duration:** 2-3 days + +- ✅ Mark old internal methods as deprecated +- ✅ Provide migration guide for advanced users +- ✅ Keep deprecated code for 1-2 releases +- ✅ Add deprecation warnings + +**Example:** +```swift +@available(*, deprecated, message: "This internal method will be removed in v3.0") +internal func oldMethod() { ... } +``` + +### **Step 4: Clean Up** +**Duration:** 1-2 days + +- ✅ Remove deprecated internal code +- ✅ Simplify remaining code +- ✅ Final performance optimization +- ✅ Documentation updates +- ✅ Update examples + +--- + +## Benefits Summary + +### **Maintainability** +| Aspect | Before | After | +|--------|--------|-------| +| Average file size | 600+ LOC | 100-150 LOC | +| Responsibilities per file | 5-8 | 1-2 | +| State complexity | High (shared mutable state) | Low (isolated actors) | +| Time to locate bugs | Hours | Minutes | +| Code comprehension | Difficult | Easy | + +### **Reliability** +| Issue | Before | After | +|-------|--------|-------| +| Connection race conditions | ❌ Possible | ✅ Impossible (state machine) | +| Multiple simultaneous connections | ❌ Can occur | ✅ Cannot occur | +| Invalid state combinations | ❌ Possible | ✅ Prevented by type system | +| Heartbeat timeout bugs | ❌ Recently fixed | ✅ Cannot regress (encapsulated) | +| Task lifecycle bugs | ❌ Common | ✅ Managed by actors | +| Auth token bugs | ❌ Recently fixed | ✅ Single source of truth | + +### **Testability** +| Aspect | Before | After | +|--------|--------|-------| +| Unit test coverage | ~60% | Target: ~85% | +| Mocking difficulty | High | Low (protocols) | +| Test determinism | Flaky (timing) | Deterministic (TestClock) | +| Isolated testing | Difficult | Easy (DI) | +| Test speed | Slow (real timeouts) | Fast (mocked) | + +### **Performance** +| Metric | Before | After | +|--------|--------|-------| +| Lock contention | High (coarse locks) | Low (fine-grained actors) | +| Task overhead | Multiple tasks per operation | Minimal tasks | +| Memory allocations | High (closures) | Reduced (value types) | +| Message routing | O(n) iteration | O(1) lookup | + +### **Developer Experience** +| Aspect | Before | After | +|--------|--------|-------| +| API clarity | Good | Good (unchanged) | +| Error messages | Generic | Specific to state | +| IntelliSense | Works | Better (smaller files) | +| Documentation | Scattered | Grouped by feature | +| Learning curve | Steep | Gradual | + +--- + +## Estimated Effort + +| Phase | Duration | Risk Level | Value | +|-------|----------|------------|-------| +| Phase 1: Core Components | 3-5 days | Low | High | +| Phase 2: Channel Refactor | 5-7 days | Medium | High | +| Phase 3: Testability | 2-3 days | Low | Medium | +| Phase 4: File Organization | 1-2 days | Low | Medium | +| **Total** | **11-17 days** | **Low-Medium** | **High** | + +**Assumptions:** +- One developer working full-time +- Includes comprehensive testing +- Includes code review time +- Includes documentation updates +- Conservative estimates + +**Timeline:** +- Week 1-2: Phase 1 (Core Components) +- Week 2-3: Phase 2 (Channel Refactor) +- Week 3: Phase 3 & 4 (Polish) + +--- + +## Risk Mitigation + +### **Technical Risks** + +| Risk | Probability | Impact | Mitigation | +|------|-------------|--------|------------| +| Breaking changes to public API | Low | High | Maintain 100% backward compatibility | +| Performance regression | Low | Medium | Benchmark at each step | +| New bugs introduced | Medium | Medium | Comprehensive test coverage first | +| Migration takes longer | Medium | Low | Phased approach, can pause anytime | + +### **Mitigation Strategies** + +1. **Maintain 100% backward compatibility** in public API + - All existing code continues to work + - Only internal implementation changes + - Deprecation warnings for internal APIs + +2. **Comprehensive test coverage** before refactoring + - Unit tests for each new component + - Integration tests for end-to-end flows + - Snapshot tests for complex state + +3. **Incremental migration** with feature flags if needed + - Can enable/disable new components + - Rollback easily if issues found + - Gradual rollout to users + +4. **Performance benchmarks** to prevent regressions + - Measure before refactoring + - Compare after each phase + - Automated performance tests + +5. **Extensive manual testing** on example apps + - Test on all platforms (iOS, macOS, tvOS, etc.) + - Real-world usage scenarios + - Edge cases and error conditions + +--- + +## Success Metrics + +### **Code Quality Metrics** + +- ✅ Average file size: < 200 LOC +- ✅ Cyclomatic complexity: < 10 per method +- ✅ Test coverage: > 85% +- ✅ Documentation coverage: 100% public API + +### **Performance Metrics** + +- ✅ Connection time: No regression +- ✅ Message latency: No regression +- ✅ Memory usage: 10-20% reduction +- ✅ CPU usage: No regression + +### **Developer Metrics** + +- ✅ Time to onboard: 50% reduction +- ✅ Bug fix time: 50% reduction +- ✅ Feature development time: 30% reduction +- ✅ Code review time: 40% reduction + +--- + +## Recommendation + +**Start with Phase 1** - it provides the most value with the lowest risk: + +1. ✅ Extract `ConnectionStateMachine` +2. ✅ Extract `HeartbeatMonitor` +3. ✅ Extract `AuthTokenManager` +4. ✅ Extract `MessageRouter` + +### **Why Phase 1 First?** + +**High Value:** +- Eliminates connection race condition **permanently** +- Fixes heartbeat logic complexity **permanently** +- Simplifies auth token handling **permanently** +- Reduces RealtimeClientV2 by ~300 LOC +- Makes all future changes easier + +**Low Risk:** +- No public API changes +- Can be done incrementally +- Easy to test in isolation +- Can rollback if needed +- Minimal integration complexity + +**Quick Wins:** +- Immediate improvement in maintainability +- Better error messages +- Easier debugging +- Foundation for Phase 2 + +### **Next Steps** + +If approved: + +1. Create feature branch `refactor/realtime-phase1` +2. Implement `ConnectionStateMachine` with tests +3. Implement `HeartbeatMonitor` with tests +4. Implement `AuthTokenManager` with tests +5. Implement `MessageRouter` with tests +6. Migrate `RealtimeClientV2` to use new components +7. Comprehensive testing +8. Code review and merge + +**Estimated time for Phase 1: 3-5 days** + +--- + +## Conclusion + +This refactoring proposal addresses the root causes of the bugs recently fixed: + +- **Connection race conditions** → Prevented by `ConnectionStateMachine` +- **Heartbeat timeout bugs** → Eliminated by `HeartbeatMonitor` +- **Auth token bugs** → Fixed by `AuthTokenManager` +- **Message routing complexity** → Simplified by `MessageRouter` +- **State management issues** → Solved by actor isolation +- **Testing difficulties** → Resolved by dependency injection + +The refactoring maintains 100% backward compatibility while significantly improving: +- Code maintainability +- System reliability +- Test coverage +- Developer experience + +**Recommendation: Proceed with Phase 1 implementation.** + +--- + +**Questions or Concerns?** + +Please review and provide feedback on: +1. Overall approach and architecture +2. Specific component designs +3. Migration strategy +4. Timeline and effort estimates +5. Any missing considerations