diff --git a/Sources/Realtime/ConnectionManager.swift b/Sources/Realtime/ConnectionManager.swift new file mode 100644 index 00000000..59529885 --- /dev/null +++ b/Sources/Realtime/ConnectionManager.swift @@ -0,0 +1,147 @@ +// +// ConnectionManager.swift +// Supabase +// +// Created by Guilherme Souza on 19/11/25. +// + +import Foundation + +actor ConnectionManager { + enum State { + case disconnected + case connecting(Task) + case connected(any WebSocket) + case reconnecting(Task, reason: String) + } + + private let (stateStream, stateContinuation) = AsyncStream.makeStream() + private(set) var state: State = .disconnected + + private let transport: WebSocketTransport + private let url: URL + private let headers: [String: String] + private let reconnectDelay: TimeInterval + private let logger: (any SupabaseLogger)? + + /// Get current connection if connected, nil otherwise. + var connection: (any WebSocket)? { + if case .connected(let conn) = state { + return conn + } + return nil + } + + var stateChanges: AsyncStream { stateStream } + + init( + transport: @escaping WebSocketTransport, + url: URL, + headers: [String: String], + reconnectDelay: TimeInterval, + logger: (any SupabaseLogger)? + ) { + self.transport = transport + self.url = url + self.headers = headers + self.reconnectDelay = reconnectDelay + self.logger = logger + } + + func connect() async throws { + logger?.debug("current state: \(state)") + + switch state { + case .connected: + logger?.debug("Already connected") + + case .connecting(let task): + logger?.debug("Connection already in progress, waiting...") + try await task.value + + case .disconnected: + logger?.debug("Initiating new connection") + try await performConnection() + + case .reconnecting(let task, _): + logger?.debug("Reconnection in progress, waiting...") + try await task.value + } + } + + func disconnect(reason: String? = nil) { + logger?.debug("current state: \(state)") + + switch state { + case .connected(let conn): + logger?.debug("Disconnecting from WebSocket: \(reason ?? "no reason")") + conn.close(code: nil, reason: reason) + updateState(.disconnected) + + case .connecting(let task), .reconnecting(let task, _): + logger?.debug("Cancelling connection attempt: \(reason ?? "no reason")") + task.cancel() + updateState(.disconnected) + + case .disconnected: + logger?.debug("Already disconnected") + } + } + + /// Handle connection error and initiate reconnect. + /// + /// - Parameter error: The error that caused the connection failure + func handleError(_ error: any Error) { + guard case .connected = state else { + logger?.debug("Ignoring error in non-connected state: \(error)") + return + } + + logger?.debug("Connection error, initiating reconnect: \(error.localizedDescription)") + initiateReconnect(reason: "error: \(error.localizedDescription)") + } + + /// Handle connection close. + /// + /// - Parameters: + /// - code: WebSocket close code + /// - reason: WebSocket close reason + func handleClose(code: Int?, reason: String?) { + let closeReason = "code: \(code?.description ?? "none"), reason: \(reason ?? "none")" + logger?.debug("Connection closed: \(closeReason)") + + disconnect(reason: reason) + } + + private func performConnection() async throws { + let connectionTask = Task { + let conn = try await transport(url, headers) + try Task.checkCancellation() + updateState(.connected(conn)) + } + + updateState(.connecting(connectionTask)) + + do { + return try await connectionTask.value + } catch { + updateState(.disconnected) + throw error + } + } + + private func initiateReconnect(reason: String) { + let reconnectTask = Task { + try await Task.sleep(nanoseconds: UInt64(reconnectDelay * 1_000_000_000)) + logger?.debug("Attempting to reconnect...") + try await performConnection() + } + + updateState(.reconnecting(reconnectTask, reason: reason)) + } + + private func updateState(_ state: State) { + self.state = state + self.stateContinuation.yield(state) + } +} diff --git a/Sources/Realtime/RealtimeClientV2.swift b/Sources/Realtime/RealtimeClientV2.swift index f99c5173..57a4bebc 100644 --- a/Sources/Realtime/RealtimeClientV2.swift +++ b/Sources/Realtime/RealtimeClientV2.swift @@ -7,6 +7,7 @@ import ConcurrencyExtras import Foundation +import Helpers #if canImport(FoundationNetworking) import FoundationNetworking @@ -42,11 +43,8 @@ public final class RealtimeClientV2: Sendable, RealtimeClientProtocol { /// Long-running task for listening for incoming messages from WebSocket. var messageTask: Task? - var connectionTask: Task? var channels: [String: RealtimeChannelV2] = [:] - var sendBuffer: [@Sendable () -> Void] = [] - - var conn: (any WebSocket)? + var sendBuffer: [@Sendable (RealtimeClientV2) -> Void] = [] } let url: URL @@ -56,8 +54,10 @@ public final class RealtimeClientV2: Sendable, RealtimeClientProtocol { let http: any HTTPClientType let apikey: String + let connectionManager: ConnectionManager + var conn: (any WebSocket)? { - mutableState.conn + get async { await connectionManager.connection } } /// All managed channels indexed by their topics. @@ -164,6 +164,18 @@ public final class RealtimeClientV2: Sendable, RealtimeClientProtocol { $0.accessToken = String(accessToken) } } + + self.connectionManager = ConnectionManager( + transport: wsTransport, + url: Self.realtimeWebSocketURL( + baseURL: Self.realtimeBaseURL(url: url), + apikey: options.apikey, + logLevel: options.logLevel + ), + headers: options.headers.dictionary, + reconnectDelay: options.reconnectDelay, + logger: options.logger + ) } deinit { @@ -182,58 +194,28 @@ public final class RealtimeClientV2: Sendable, RealtimeClientProtocol { } func connect(reconnect: Bool) async { - if status == .disconnected { - let connectionTask = Task { - if reconnect { - try? await _clock.sleep(for: options.reconnectDelay) + options.logger?.debug(reconnect ? "Reconnecting..." : "Connecting...") - if Task.isCancelled { - options.logger?.debug("Reconnect cancelled, returning") - return - } - } + do { + status = .connecting + try await connectionManager.connect() - if status == .connected { - options.logger?.debug("WebsSocket already connected") - return - } + options.logger?.debug("Connected to realtime WebSocket") - status = .connecting + listenForMessages() + startHeartbeating() - do { - let conn = try await wsTransport( - Self.realtimeWebSocketURL( - baseURL: Self.realtimeBaseURL(url: url), - apikey: options.apikey, - logLevel: options.logLevel - ), - options.headers.dictionary - ) - mutableState.withValue { $0.conn = conn } - onConnected(reconnect: reconnect) - } catch { - onError(error) - } - } + status = .connected - mutableState.withValue { - $0.connectionTask = connectionTask + if reconnect { + rejoinChannels() } - } - _ = await statusChange.first { @Sendable in $0 == .connected } - } - - private func onConnected(reconnect: Bool) { - status = .connected - options.logger?.debug("Connected to realtime WebSocket") - listenForMessages() - startHeartbeating() - if reconnect { - rejoinChannels() + flushSendBuffer() + } catch { + options.logger?.error("Connection failed: \(error)") + status = .disconnected } - - flushSendBuffer() } private func onDisconnected() { @@ -244,22 +226,6 @@ public final class RealtimeClientV2: Sendable, RealtimeClientProtocol { reconnect() } - private func onError(_ error: (any Error)?) { - options.logger? - .debug( - "WebSocket error \(error?.localizedDescription ?? ""). Trying again in \(options.reconnectDelay)" - ) - reconnect() - } - - private func onClose(code: Int?, reason: String?) { - options.logger?.debug( - "WebSocket closed. Code: \(code?.description ?? ""), Reason: \(reason ?? "")" - ) - - reconnect() - } - private func reconnect(disconnectReason: String? = nil) { Task { disconnect(reason: disconnectReason) @@ -367,7 +333,7 @@ public final class RealtimeClientV2: Sendable, RealtimeClientProtocol { mutableState.withValue { $0.messageTask?.cancel() $0.messageTask = Task { [weak self] in - guard let self, let conn = self.conn else { return } + guard let self, let conn = await self.conn else { return } do { for await event in conn.events { @@ -387,11 +353,19 @@ public final class RealtimeClientV2: Sendable, RealtimeClientProtocol { } case .close(let code, let reason): - onClose(code: code, reason: reason) + options.logger?.debug( + "WebSocket closed. Code: \(code?.description ?? ""), Reason: \(reason)" + ) + + await connectionManager.handleClose(code: code, reason: reason) } } } catch { - onError(error) + options.logger? + .debug( + "WebSocket error \(error.localizedDescription). Trying again in \(options.reconnectDelay)" + ) + await connectionManager.handleError(error) } } } @@ -455,14 +429,14 @@ public final class RealtimeClientV2: Sendable, RealtimeClientProtocol { public func disconnect(code: Int? = nil, reason: String? = nil) { options.logger?.debug("Closing WebSocket connection") - conn?.close(code: code, reason: reason) + Task { + await connectionManager.disconnect(reason: reason ?? "Client disconnect") + } mutableState.withValue { $0.ref = 0 $0.messageTask?.cancel() $0.heartbeatTask?.cancel() - $0.connectionTask?.cancel() - $0.conn = nil } status = .disconnected @@ -526,27 +500,29 @@ public final class RealtimeClientV2: Sendable, RealtimeClientProtocol { /// /// If the socket is not connected, the message gets enqueued within a local buffer, and sent out when a connection is next established. public func push(_ message: RealtimeMessageV2) { - let callback = { @Sendable [weak self] in - do { - // Check cancellation before sending, because this push may have been cancelled before a connection was established. - try Task.checkCancellation() - let data = try JSONEncoder().encode(message) - self?.conn?.send(String(decoding: data, as: UTF8.self)) - } catch { - self?.options.logger?.error( + let callback = { @Sendable (_ client: RealtimeClientV2) in + _ = Task { + do { + // Check cancellation before sending, because this push may have been cancelled before a connection was established. + try Task.checkCancellation() + let data = try JSONEncoder().encode(message) + await client.conn?.send(String(decoding: data, as: UTF8.self)) + } catch { + client.options.logger?.error( """ Failed to send message: \(message) - + Error: \(error) """ - ) + ) + } } } if status == .connected { - callback() + callback(self) } else { mutableState.withValue { $0.sendBuffer.append(callback) @@ -556,7 +532,7 @@ public final class RealtimeClientV2: Sendable, RealtimeClientProtocol { private func flushSendBuffer() { mutableState.withValue { - $0.sendBuffer.forEach { $0() } + $0.sendBuffer.forEach { $0(self) } $0.sendBuffer = [] } } diff --git a/Tests/RealtimeTests/ConnectionManagerTests.swift b/Tests/RealtimeTests/ConnectionManagerTests.swift new file mode 100644 index 00000000..1f45c2e0 --- /dev/null +++ b/Tests/RealtimeTests/ConnectionManagerTests.swift @@ -0,0 +1,238 @@ +// +// ConnectionManagerTests.swift +// Supabase +// +// Created by Guilherme Souza on 19/11/25. +// + +import ConcurrencyExtras +import XCTest + +@testable import Realtime + +final class ConnectionManagerTests: XCTestCase { + private enum TestError: LocalizedError { + case sample + + var errorDescription: String? { "sample error" } + } + + var sut: ConnectionManager! + var ws: FakeWebSocket! + var transportCallCount = 0 + var lastConnectURL: URL? + var lastConnectHeaders: [String: String]? + + override func setUp() { + super.setUp() + + transportCallCount = 0 + lastConnectURL = nil + lastConnectHeaders = nil + (ws, _) = FakeWebSocket.fakes() + } + + override func tearDown() { + sut = nil + ws = nil + super.tearDown() + } + + private func makeSUT( + url: URL = URL(string: "ws://localhost")!, + headers: [String: String] = [:], + reconnectDelay: TimeInterval = 0.1, + transport: WebSocketTransport? = nil + ) -> ConnectionManager { + ConnectionManager( + transport: transport ?? { url, headers in + self.transportCallCount += 1 + self.lastConnectURL = url + self.lastConnectHeaders = headers + return self.ws! + }, + url: url, + headers: headers, + reconnectDelay: reconnectDelay, + logger: nil + ) + } + + func testConnectTransitionsThroughConnectingAndConnectedStates() async throws { + sut = makeSUT(headers: ["apikey": "key"]) + + let connectingExpectation = expectation(description: "connecting state observed") + let connectedExpectation = expectation(description: "connected state observed") + + let stateObserver = Task { + for await state in await sut.stateChanges { + switch state { + case .connecting: + connectingExpectation.fulfill() + case .connected: + connectedExpectation.fulfill() + return + default: + break + } + } + } + + let initiallyConnected = await sut.connection != nil + XCTAssertFalse(initiallyConnected) + try await sut.connect() + + let isConnected = await sut.connection != nil + XCTAssertTrue(isConnected) + XCTAssertEqual(transportCallCount, 1) + XCTAssertEqual(lastConnectURL?.absoluteString, "ws://localhost") + XCTAssertEqual(lastConnectHeaders, ["apikey": "key"]) + + await fulfillment(of: [connectingExpectation, connectedExpectation], timeout: 1) + stateObserver.cancel() + } + + func testConnectWhenAlreadyConnectedDoesNotReconnect() async throws { + sut = makeSUT() + + try await sut.connect() + XCTAssertEqual(transportCallCount, 1) + + try await sut.connect() + + let stillConnected = await sut.connection != nil + XCTAssertTrue(stillConnected) + XCTAssertEqual(transportCallCount, 1, "Second connect should reuse existing connection") + } + + func testConnectWhileConnectingWaitsForExistingTask() async throws { + sut = makeSUT( + transport: { _, _ in + self.transportCallCount += 1 + try await Task.sleep(nanoseconds: 200_000_000) + return self.ws! + } + ) + + let firstConnect = Task { + try await sut.connect() + } + + let secondConnectFinished = LockIsolated(false) + let secondConnect = Task { + try await sut.connect() + secondConnectFinished.setValue(true) + } + + try await Task.sleep(nanoseconds: 50_000_000) + XCTAssertFalse(secondConnectFinished.value) + XCTAssertEqual( + transportCallCount, 1, + "Transport should be invoked only once while first connect is in progress") + + try await firstConnect.value + try await secondConnect.value + + XCTAssertTrue(secondConnectFinished.value) + let isConnected = await sut.connection != nil + XCTAssertTrue(isConnected) + XCTAssertEqual(transportCallCount, 1) + } + + func testDisconnectFromConnectedClosesWebSocketAndUpdatesState() async throws { + sut = makeSUT() + try await sut.connect() + + await sut.disconnect(reason: "test reason") + + let isConnected = await sut.connection != nil + XCTAssertFalse(isConnected) + guard case .close(let closeCode, let closeReason)? = ws.sentEvents.last else { + return XCTFail("Expected close event to be sent") + } + XCTAssertNil(closeCode) + XCTAssertEqual(closeReason, "test reason") + } + + func testDisconnectCancelsOngoingConnectionAttempt() async throws { + let wasCancelled = LockIsolated(false) + + sut = makeSUT( + transport: { _, _ in + self.transportCallCount += 1 + return try await withTaskCancellationHandler { + try await Task.sleep(nanoseconds: 5_000_000_000) + return self.ws! + } onCancel: { + wasCancelled.setValue(true) + } + } + ) + + let connectTask = Task { + try? await sut.connect() + } + + try await Task.sleep(nanoseconds: 50_000_000) + await sut.disconnect(reason: "stop") + + await Task.yield() + XCTAssertTrue(wasCancelled.value, "Cancellation handler should run when disconnecting") + let isConnected = await sut.connection != nil + XCTAssertFalse(isConnected) + + connectTask.cancel() + } + + func testHandleErrorInitiatesReconnectAndEventuallyReconnects() async throws { + let reconnectingExpectation = expectation(description: "reconnecting state observed") + let secondConnectionExpectation = expectation(description: "second connection attempt") + + let connectionCount = LockIsolated(0) + + sut = makeSUT( + reconnectDelay: 0.01, + transport: { _, _ in + connectionCount.withValue { $0 += 1 } + if connectionCount.value == 2 { + secondConnectionExpectation.fulfill() + } + return self.ws! + } + ) + + let stateObserver = Task { + for await state in await sut.stateChanges { + if case .reconnecting(_, let reason) = state, reason.contains("sample error") { + reconnectingExpectation.fulfill() + return + } + } + } + + try await sut.connect() + await sut.handleError(TestError.sample) + + await fulfillment(of: [reconnectingExpectation, secondConnectionExpectation], timeout: 2) + XCTAssertEqual(connectionCount.value, 2, "Reconnection should trigger a second transport call") + let isConnected = await sut.connection != nil + XCTAssertTrue(isConnected) + + stateObserver.cancel() + } + + func testHandleCloseDelegatesToDisconnect() async throws { + sut = makeSUT() + try await sut.connect() + + await sut.handleClose(code: 4001, reason: "server closing") + + let isConnected = await sut.connection != nil + XCTAssertFalse(isConnected) + guard case .close(let closeCode, let closeReason)? = ws.sentEvents.last else { + return XCTFail("Expected close event to be sent") + } + XCTAssertNil(closeCode) + XCTAssertEqual(closeReason, "server closing") + } +} diff --git a/Tests/RealtimeTests/RealtimeTests.swift b/Tests/RealtimeTests/RealtimeTests.swift index 055cb1c0..39a04be1 100644 --- a/Tests/RealtimeTests/RealtimeTests.swift +++ b/Tests/RealtimeTests/RealtimeTests.swift @@ -91,130 +91,122 @@ import XCTest await client.connect() } -// func testBehavior() async throws { -// let channel = sut.channel("public:messages") -// var subscriptions: Set = [] -// -// channel.onPostgresChange(InsertAction.self, table: "messages") { _ in -// } -// .store(in: &subscriptions) -// -// channel.onPostgresChange(UpdateAction.self, table: "messages") { _ in -// } -// .store(in: &subscriptions) -// -// channel.onPostgresChange(DeleteAction.self, table: "messages") { _ in -// } -// .store(in: &subscriptions) -// -// let socketStatuses = LockIsolated([RealtimeClientStatus]()) -// -// sut.onStatusChange { status in -// socketStatuses.withValue { $0.append(status) } -// } -// .store(in: &subscriptions) -// -// // Set up server to respond to heartbeats -// server.onEvent = { @Sendable [server] event in -// guard let msg = event.realtimeMessage else { return } -// -// if msg.event == "heartbeat" { -// server?.send( -// RealtimeMessageV2( -// joinRef: msg.joinRef, -// ref: msg.ref, -// topic: "phoenix", -// event: "phx_reply", -// payload: ["response": [:]] -// ) -// ) -// } -// } -// -// await waitUntil { -// socketStatuses.value.count >= 3 -// } -// -// XCTAssertEqual( -// Array(socketStatuses.value.prefix(3)), -// [.disconnected, .connecting, .connected] -// ) -// -// let messageTask = sut.mutableState.messageTask -// XCTAssertNotNil(messageTask) -// -// let heartbeatTask = sut.mutableState.heartbeatTask -// XCTAssertNotNil(heartbeatTask) -// -// let channelStatuses = LockIsolated([RealtimeChannelStatus]()) -// channel.onStatusChange { status in -// channelStatuses.withValue { -// $0.append(status) -// } -// } -// .store(in: &subscriptions) -// -// let subscribeTask = Task { -// try await channel.subscribeWithError() -// } -// await Task.yield() -// server.send(.messagesSubscribed) -// -// // Wait until it subscribes to assert WS events -// do { -// try await subscribeTask.value -// } catch { -// XCTFail("Expected .subscribed but got error: \(error)") -// } -// XCTAssertEqual(channelStatuses.value, [.unsubscribed, .subscribing, .subscribed]) -// -// assertInlineSnapshot(of: client.sentEvents.map(\.json), as: .json) { -// #""" -// [ -// { -// "text" : { -// "event" : "phx_join", -// "join_ref" : "1", -// "payload" : { -// "access_token" : "custom.access.token", -// "config" : { -// "broadcast" : { -// "ack" : false, -// "self" : false -// }, -// "postgres_changes" : [ -// { -// "event" : "INSERT", -// "schema" : "public", -// "table" : "messages" -// }, -// { -// "event" : "UPDATE", -// "schema" : "public", -// "table" : "messages" -// }, -// { -// "event" : "DELETE", -// "schema" : "public", -// "table" : "messages" -// } -// ], -// "presence" : { -// "enabled" : false, -// "key" : "" -// }, -// "private" : false -// }, -// "version" : "realtime-swift\/0.0.0" -// }, -// "ref" : "1", -// "topic" : "realtime:public:messages" -// } -// } -// ] -// """# -// } -// } + func testBehavior() async throws { + let channel = sut.channel("public:messages") + var subscriptions: Set = [] + + channel.onPostgresChange(InsertAction.self, table: "messages") { _ in + } + .store(in: &subscriptions) + + channel.onPostgresChange(UpdateAction.self, table: "messages") { _ in + } + .store(in: &subscriptions) + + channel.onPostgresChange(DeleteAction.self, table: "messages") { _ in + } + .store(in: &subscriptions) + + let socketStatuses = LockIsolated([RealtimeClientStatus]()) + + sut.onStatusChange { status in + socketStatuses.withValue { $0.append(status) } + } + .store(in: &subscriptions) + + // Set up server to respond to heartbeats + server.onEvent = { @Sendable [server] event in + guard let msg = event.realtimeMessage else { return } + + if msg.event == "heartbeat" { + server?.send( + RealtimeMessageV2( + joinRef: msg.joinRef, + ref: msg.ref, + topic: "phoenix", + event: "phx_reply", + payload: ["response": [:]] + ) + ) + } else if msg.event == "phx_join" { + server?.send(.messagesSubscribed) + } + } + + let channelStatuses = LockIsolated([RealtimeChannelStatus]()) + channel.onStatusChange { status in + channelStatuses.withValue { + $0.append(status) + } + } + .store(in: &subscriptions) + + // Wait until it subscribes to assert WS events + do { + try await channel.subscribeWithError() + } catch { + XCTFail("Expected .subscribed but got error: \(error)") + } + XCTAssertEqual(channelStatuses.value, [.unsubscribed, .subscribing, .subscribed]) + + XCTAssertEqual( + Array(socketStatuses.value.prefix(3)), + [.disconnected, .connecting, .connected] + ) + + let messageTask = sut.mutableState.messageTask + XCTAssertNotNil(messageTask) + + let heartbeatTask = sut.mutableState.heartbeatTask + XCTAssertNotNil(heartbeatTask) + + assertInlineSnapshot(of: client.sentEvents.map(\.json), as: .json) { + #""" + [ + { + "text" : { + "event" : "phx_join", + "join_ref" : "1", + "payload" : { + "access_token" : "custom.access.token", + "config" : { + "broadcast" : { + "ack" : false, + "self" : false + }, + "postgres_changes" : [ + { + "event" : "INSERT", + "schema" : "public", + "table" : "messages" + }, + { + "event" : "UPDATE", + "schema" : "public", + "table" : "messages" + }, + { + "event" : "DELETE", + "schema" : "public", + "table" : "messages" + } + ], + "presence" : { + "enabled" : false, + "key" : "" + }, + "private" : false + }, + "version" : "realtime-swift\/0.0.0" + }, + "ref" : "1", + "topic" : "realtime:public:messages" + } + } + ] + """# + } + } func testSubscribeTimeout() async throws { let channel = sut.channel("public:messages")