Skip to content
87 changes: 87 additions & 0 deletions Sources/Realtime/Auth/AuthTokenManager.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
//
// AuthTokenManager.swift
// Realtime
//
// Created on 17/01/25.
//

import Foundation

/// Manages authentication token lifecycle and distribution.
///
/// This actor provides a single source of truth for the current authentication token,
/// handling both direct token assignment and token provider callbacks.
actor AuthTokenManager {
// MARK: - Properties

private var currentToken: String?
private let tokenProvider: (@Sendable () async throws -> String?)?

// MARK: - Initialization

init(
initialToken: String?,
tokenProvider: (@Sendable () async throws -> String?)?
) {
self.currentToken = initialToken
self.tokenProvider = tokenProvider
}

// MARK: - Public API

/// Get current token, calling provider if needed.
///
/// If no current token is set, this will attempt to fetch from the token provider.
///
/// - Returns: The current authentication token, or nil if unavailable
func getCurrentToken() async -> String? {
// Return current token if available
if let token = currentToken {
return token
}

// Try to get from provider
if let provider = tokenProvider {
let token = try? await provider()
currentToken = token
return token
}

return nil
}

/// Update token and return if it changed.
///
/// - Parameter token: The new token to set, or nil to clear
/// - Returns: True if the token changed, false if it's the same
func updateToken(_ token: String?) async -> Bool {
guard token != currentToken else {
return false
}

currentToken = token
return true
}

/// Refresh token from provider if available.
///
/// This forces a call to the token provider even if a current token exists.
///
/// - Returns: The refreshed token, or current token if no provider
func refreshToken() async -> String? {
guard let provider = tokenProvider else {
return currentToken
}

let token = try? await provider()
currentToken = token
return token
}

/// Get the current token without calling the provider.
///
/// - Returns: The currently stored token, or nil
var token: String? {
currentToken
}
}
191 changes: 191 additions & 0 deletions Sources/Realtime/Connection/ConnectionStateMachine.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
//
// ConnectionStateMachine.swift
// Realtime
//
// Created on 17/01/25.
//

import Foundation
import Helpers

/// Manages WebSocket connection lifecycle with clear state transitions.
///
/// This actor ensures thread-safe connection management and prevents race conditions
/// by enforcing valid state transitions through Swift's type system.
actor ConnectionStateMachine {
/// Represents the possible states of a WebSocket connection
enum State: Sendable {
case disconnected
case connecting(Task<Void, any Error>)
case connected(any WebSocket)
case reconnecting(Task<Void, any Error>, reason: String)
}

// MARK: - Properties

private(set) var state: State = .disconnected

private let transport: WebSocketTransport
private let url: URL
private let headers: [String: String]
private let reconnectDelay: TimeInterval
private let logger: (any SupabaseLogger)?

// MARK: - Initialization

init(
transport: @escaping WebSocketTransport,
url: URL,
headers: [String: String],
reconnectDelay: TimeInterval,
logger: (any SupabaseLogger)?
) {
self.transport = transport
self.url = url
self.headers = headers
self.reconnectDelay = reconnectDelay
self.logger = logger
}

// MARK: - Public API

/// Connect to WebSocket. Returns existing connection if already connected.
///
/// This method is safe to call multiple times - it will reuse an existing connection
/// or wait for an in-progress connection attempt to complete.
///
/// - Returns: The active WebSocket connection
/// - Throws: Connection errors from the transport layer
func connect() async throws -> any WebSocket {
switch state {
case .connected(let conn):
logger?.debug("Already connected to WebSocket")
return conn

case .connecting(let task):
logger?.debug("Connection already in progress, waiting...")
try await task.value
// Recursively call to get the connection after task completes
return try await connect()

case .reconnecting(let task, _):
logger?.debug("Reconnection in progress, waiting...")
try await task.value
return try await connect()

case .disconnected:
logger?.debug("Initiating new connection")
return try await performConnection()
}
}

/// Disconnect and clean up resources.
///
/// - Parameter reason: Optional reason for disconnection
func disconnect(reason: String? = nil) {
switch state {
case .connected(let conn):
logger?.debug("Disconnecting from WebSocket: \(reason ?? "no reason")")
conn.close(code: nil, reason: reason)
state = .disconnected

case .connecting(let task), .reconnecting(let task, _):
logger?.debug("Cancelling connection attempt: \(reason ?? "no reason")")
task.cancel()
state = .disconnected

case .disconnected:
logger?.debug("Already disconnected")
}
}

/// Handle connection error and initiate reconnect.
///
/// - Parameter error: The error that caused the connection failure
func handleError(_ error: any Error) {
guard case .connected = state else {
logger?.debug("Ignoring error in non-connected state: \(error)")
return
}

logger?.debug("Connection error, initiating reconnect: \(error.localizedDescription)")
initiateReconnect(reason: "error: \(error.localizedDescription)")
}

/// Handle connection close.
///
/// - Parameters:
/// - code: WebSocket close code
/// - reason: WebSocket close reason
func handleClose(code: Int?, reason: String?) {
let closeReason = "code: \(code?.description ?? "none"), reason: \(reason ?? "none")"
logger?.debug("Connection closed: \(closeReason)")

disconnect(reason: reason)
}

/// Handle disconnection event and initiate reconnect.
func handleDisconnected() {
guard case .connected = state else { return }

logger?.debug("Connection disconnected, initiating reconnect")
initiateReconnect(reason: "disconnected")
}

/// Get current connection if connected, nil otherwise.
var connection: (any WebSocket)? {
if case .connected(let conn) = state {
return conn
}
return nil
}

/// Check if currently connected.
var isConnected: Bool {
if case .connected = state {
return true
}
return false
}

// MARK: - Private Helpers

private func performConnection() async throws -> any WebSocket {
let connectionTask = Task<Void, any Error> {
let conn = try await transport(url, headers)
state = .connected(conn)
}

state = .connecting(connectionTask)

do {
try await connectionTask.value

// Get the connection that was just set
guard case .connected(let conn) = state else {
throw RealtimeError("Connection succeeded but state is invalid")
}

return conn
} catch {
state = .disconnected
throw error
}
}

private func initiateReconnect(reason: String) {
let reconnectTask = Task<Void, any Error> {
try await Task.sleep(nanoseconds: UInt64(reconnectDelay * 1_000_000_000))

if Task.isCancelled {
logger?.debug("Reconnect cancelled")
return
}

logger?.debug("Attempting to reconnect...")
_ = try await performConnection()
}

state = .reconnecting(reconnectTask, reason: reason)
}
}
113 changes: 113 additions & 0 deletions Sources/Realtime/Connection/HeartbeatMonitor.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
//
// HeartbeatMonitor.swift
// Realtime
//
// Created on 17/01/25.
//

import Foundation
import Helpers

/// Manages heartbeat send/receive cycle with timeout detection.
///
/// This actor encapsulates all heartbeat logic, ensuring that heartbeats are sent
/// at regular intervals and timeouts are detected when responses aren't received.
actor HeartbeatMonitor {
// MARK: - Properties

private let interval: TimeInterval
private let refGenerator: @Sendable () -> String
private let sendHeartbeat: @Sendable (String) async -> Void
private let onTimeout: @Sendable () async -> Void
private let logger: (any SupabaseLogger)?

private var monitorTask: Task<Void, Never>?
private var pendingRef: String?

// MARK: - Initialization

init(
interval: TimeInterval,
refGenerator: @escaping @Sendable () -> String,
sendHeartbeat: @escaping @Sendable (String) async -> Void,
onTimeout: @escaping @Sendable () async -> Void,
logger: (any SupabaseLogger)?
) {
self.interval = interval
self.refGenerator = refGenerator
self.sendHeartbeat = sendHeartbeat
self.onTimeout = onTimeout
self.logger = logger
}

// MARK: - Public API

/// Start heartbeat monitoring.
///
/// Sends heartbeats at regular intervals and detects timeouts when responses
/// aren't received before the next interval.
func start() {
stop() // Cancel any existing monitor

logger?.debug("Starting heartbeat monitor with interval: \(interval)")

monitorTask = Task {
while !Task.isCancelled {
do {
try await Task.sleep(nanoseconds: UInt64(interval * 1_000_000_000))
} catch {
// Task cancelled during sleep
break
}

if Task.isCancelled { break }

await sendNextHeartbeat()
}

logger?.debug("Heartbeat monitor stopped")
}
}

/// Stop heartbeat monitoring.
func stop() {
if monitorTask != nil {
logger?.debug("Stopping heartbeat monitor")
monitorTask?.cancel()
monitorTask = nil
pendingRef = nil
}
}

/// Called when heartbeat response is received.
///
/// - Parameter ref: The reference ID from the heartbeat response
func onHeartbeatResponse(ref: String) {
guard let pending = pendingRef, pending == ref else {
logger?.debug("Received heartbeat response with mismatched ref: \(ref)")
return
}

logger?.debug("Heartbeat acknowledged: \(ref)")
pendingRef = nil
}

// MARK: - Private Helpers

private func sendNextHeartbeat() async {
// Check if previous heartbeat was acknowledged
if let pending = pendingRef {
logger?.debug("Heartbeat timeout - previous heartbeat not acknowledged: \(pending)")
pendingRef = nil
await onTimeout()
return
}

// Send new heartbeat
let ref = refGenerator()
pendingRef = ref

logger?.debug("Sending heartbeat: \(ref)")
await sendHeartbeat(ref)
}
}
Loading
Loading