Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 147 additions & 0 deletions Sources/Realtime/ConnectionManager.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
//
// ConnectionManager.swift
// Supabase
//
// Created by Guilherme Souza on 19/11/25.
//

import Foundation

actor ConnectionManager {
enum State {
case disconnected
case connecting(Task<Void, any Error>)
case connected(any WebSocket)
case reconnecting(Task<Void, any Error>, reason: String)
}

private let (stateStream, stateContinuation) = AsyncStream<State>.makeStream()
private(set) var state: State = .disconnected

private let transport: WebSocketTransport
private let url: URL
private let headers: [String: String]
private let reconnectDelay: TimeInterval
private let logger: (any SupabaseLogger)?

/// Get current connection if connected, nil otherwise.
var connection: (any WebSocket)? {
if case .connected(let conn) = state {
return conn
}
return nil
}

var stateChanges: AsyncStream<State> { stateStream }

init(
transport: @escaping WebSocketTransport,
url: URL,
headers: [String: String],
reconnectDelay: TimeInterval,
logger: (any SupabaseLogger)?
) {
self.transport = transport
self.url = url
self.headers = headers
self.reconnectDelay = reconnectDelay
self.logger = logger
}

func connect() async throws {
logger?.debug("current state: \(state)")

switch state {
case .connected:
logger?.debug("Already connected")

case .connecting(let task):
logger?.debug("Connection already in progress, waiting...")
try await task.value

case .disconnected:
logger?.debug("Initiating new connection")
try await performConnection()

case .reconnecting(let task, _):
logger?.debug("Reconnection in progress, waiting...")
try await task.value
}
}

func disconnect(reason: String? = nil) {
logger?.debug("current state: \(state)")

switch state {
case .connected(let conn):
logger?.debug("Disconnecting from WebSocket: \(reason ?? "no reason")")
conn.close(code: nil, reason: reason)
updateState(.disconnected)

case .connecting(let task), .reconnecting(let task, _):
logger?.debug("Cancelling connection attempt: \(reason ?? "no reason")")
task.cancel()
updateState(.disconnected)

case .disconnected:
logger?.debug("Already disconnected")
}
}

/// Handle connection error and initiate reconnect.
///
/// - Parameter error: The error that caused the connection failure
func handleError(_ error: any Error) {
guard case .connected = state else {
logger?.debug("Ignoring error in non-connected state: \(error)")
return
}

logger?.debug("Connection error, initiating reconnect: \(error.localizedDescription)")
initiateReconnect(reason: "error: \(error.localizedDescription)")
}

/// Handle connection close.
///
/// - Parameters:
/// - code: WebSocket close code
/// - reason: WebSocket close reason
func handleClose(code: Int?, reason: String?) {
let closeReason = "code: \(code?.description ?? "none"), reason: \(reason ?? "none")"
logger?.debug("Connection closed: \(closeReason)")

disconnect(reason: reason)
}

private func performConnection() async throws {
let connectionTask = Task {
let conn = try await transport(url, headers)
try Task.checkCancellation()
updateState(.connected(conn))
}

updateState(.connecting(connectionTask))

do {
return try await connectionTask.value
} catch {
updateState(.disconnected)
throw error
}
}

private func initiateReconnect(reason: String) {
let reconnectTask = Task {
try await Task.sleep(nanoseconds: UInt64(reconnectDelay * 1_000_000_000))
logger?.debug("Attempting to reconnect...")
try await performConnection()
}

updateState(.reconnecting(reconnectTask, reason: reason))
}

private func updateState(_ state: State) {
self.state = state
self.stateContinuation.yield(state)
}
}
142 changes: 59 additions & 83 deletions Sources/Realtime/RealtimeClientV2.swift
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import ConcurrencyExtras
import Foundation
import Helpers

#if canImport(FoundationNetworking)
import FoundationNetworking
Expand Down Expand Up @@ -42,11 +43,8 @@ public final class RealtimeClientV2: Sendable, RealtimeClientProtocol {
/// Long-running task for listening for incoming messages from WebSocket.
var messageTask: Task<Void, Never>?

var connectionTask: Task<Void, Never>?
var channels: [String: RealtimeChannelV2] = [:]
var sendBuffer: [@Sendable () -> Void] = []

var conn: (any WebSocket)?
var sendBuffer: [@Sendable (RealtimeClientV2) -> Void] = []
}

let url: URL
Expand All @@ -56,8 +54,10 @@ public final class RealtimeClientV2: Sendable, RealtimeClientProtocol {
let http: any HTTPClientType
let apikey: String

let connectionManager: ConnectionManager

var conn: (any WebSocket)? {
mutableState.conn
get async { await connectionManager.connection }
}

/// All managed channels indexed by their topics.
Expand Down Expand Up @@ -164,6 +164,18 @@ public final class RealtimeClientV2: Sendable, RealtimeClientProtocol {
$0.accessToken = String(accessToken)
}
}

self.connectionManager = ConnectionManager(
transport: wsTransport,
url: Self.realtimeWebSocketURL(
baseURL: Self.realtimeBaseURL(url: url),
apikey: options.apikey,
logLevel: options.logLevel
),
headers: options.headers.dictionary,
reconnectDelay: options.reconnectDelay,
logger: options.logger
)
}

deinit {
Expand All @@ -182,58 +194,28 @@ public final class RealtimeClientV2: Sendable, RealtimeClientProtocol {
}

func connect(reconnect: Bool) async {
if status == .disconnected {
let connectionTask = Task {
if reconnect {
try? await _clock.sleep(for: options.reconnectDelay)
options.logger?.debug(reconnect ? "Reconnecting..." : "Connecting...")

if Task.isCancelled {
options.logger?.debug("Reconnect cancelled, returning")
return
}
}
do {
status = .connecting
try await connectionManager.connect()

if status == .connected {
options.logger?.debug("WebsSocket already connected")
return
}
options.logger?.debug("Connected to realtime WebSocket")

status = .connecting
listenForMessages()
startHeartbeating()

do {
let conn = try await wsTransport(
Self.realtimeWebSocketURL(
baseURL: Self.realtimeBaseURL(url: url),
apikey: options.apikey,
logLevel: options.logLevel
),
options.headers.dictionary
)
mutableState.withValue { $0.conn = conn }
onConnected(reconnect: reconnect)
} catch {
onError(error)
}
}
status = .connected

mutableState.withValue {
$0.connectionTask = connectionTask
if reconnect {
rejoinChannels()
}
}

_ = await statusChange.first { @Sendable in $0 == .connected }
}

private func onConnected(reconnect: Bool) {
status = .connected
options.logger?.debug("Connected to realtime WebSocket")
listenForMessages()
startHeartbeating()
if reconnect {
rejoinChannels()
flushSendBuffer()
} catch {
options.logger?.error("Connection failed: \(error)")
status = .disconnected
}

flushSendBuffer()
}

private func onDisconnected() {
Expand All @@ -244,22 +226,6 @@ public final class RealtimeClientV2: Sendable, RealtimeClientProtocol {
reconnect()
}

private func onError(_ error: (any Error)?) {
options.logger?
.debug(
"WebSocket error \(error?.localizedDescription ?? "<none>"). Trying again in \(options.reconnectDelay)"
)
reconnect()
}

private func onClose(code: Int?, reason: String?) {
options.logger?.debug(
"WebSocket closed. Code: \(code?.description ?? "<none>"), Reason: \(reason ?? "<none>")"
)

reconnect()
}

private func reconnect(disconnectReason: String? = nil) {
Task {
disconnect(reason: disconnectReason)
Expand Down Expand Up @@ -367,7 +333,7 @@ public final class RealtimeClientV2: Sendable, RealtimeClientProtocol {
mutableState.withValue {
$0.messageTask?.cancel()
$0.messageTask = Task { [weak self] in
guard let self, let conn = self.conn else { return }
guard let self, let conn = await self.conn else { return }

do {
for await event in conn.events {
Expand All @@ -387,11 +353,19 @@ public final class RealtimeClientV2: Sendable, RealtimeClientProtocol {
}

case .close(let code, let reason):
onClose(code: code, reason: reason)
options.logger?.debug(
"WebSocket closed. Code: \(code?.description ?? "<none>"), Reason: \(reason)"
)

await connectionManager.handleClose(code: code, reason: reason)
}
}
} catch {
onError(error)
options.logger?
.debug(
"WebSocket error \(error.localizedDescription). Trying again in \(options.reconnectDelay)"
)
await connectionManager.handleError(error)
}
}
}
Expand Down Expand Up @@ -455,14 +429,14 @@ public final class RealtimeClientV2: Sendable, RealtimeClientProtocol {
public func disconnect(code: Int? = nil, reason: String? = nil) {
options.logger?.debug("Closing WebSocket connection")

conn?.close(code: code, reason: reason)
Task {
await connectionManager.disconnect(reason: reason ?? "Client disconnect")
}

mutableState.withValue {
$0.ref = 0
$0.messageTask?.cancel()
$0.heartbeatTask?.cancel()
$0.connectionTask?.cancel()
$0.conn = nil
}

status = .disconnected
Expand Down Expand Up @@ -526,27 +500,29 @@ public final class RealtimeClientV2: Sendable, RealtimeClientProtocol {
///
/// If the socket is not connected, the message gets enqueued within a local buffer, and sent out when a connection is next established.
public func push(_ message: RealtimeMessageV2) {
let callback = { @Sendable [weak self] in
do {
// Check cancellation before sending, because this push may have been cancelled before a connection was established.
try Task.checkCancellation()
let data = try JSONEncoder().encode(message)
self?.conn?.send(String(decoding: data, as: UTF8.self))
} catch {
self?.options.logger?.error(
let callback = { @Sendable (_ client: RealtimeClientV2) in
_ = Task {
do {
// Check cancellation before sending, because this push may have been cancelled before a connection was established.
try Task.checkCancellation()
let data = try JSONEncoder().encode(message)
await client.conn?.send(String(decoding: data, as: UTF8.self))
} catch {
client.options.logger?.error(
"""
Failed to send message:
\(message)

Error:
\(error)
"""
)
)
}
}
}

if status == .connected {
callback()
callback(self)
} else {
mutableState.withValue {
$0.sendBuffer.append(callback)
Expand All @@ -556,7 +532,7 @@ public final class RealtimeClientV2: Sendable, RealtimeClientProtocol {

private func flushSendBuffer() {
mutableState.withValue {
$0.sendBuffer.forEach { $0() }
$0.sendBuffer.forEach { $0(self) }
$0.sendBuffer = []
}
}
Expand Down
Loading
Loading