Skip to content

Commit 72fe1ff

Browse files
committed
fix(realtime): serialize channel subscription
1 parent 248f674 commit 72fe1ff

File tree

3 files changed

+120
-1
lines changed

3 files changed

+120
-1
lines changed

Sources/Realtime/RealtimeChannelV2.swift

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ public final class RealtimeChannelV2: Sendable, RealtimeChannelProtocol {
3737
var clientChanges: [PostgresJoinConfig] = []
3838
var joinRef: String?
3939
var pushes: [String: PushV2] = [:]
40+
var subscribeTask: Task<Void, any Error>?
4041
}
4142

4243
@MainActor
@@ -92,7 +93,22 @@ public final class RealtimeChannelV2: Sendable, RealtimeChannelProtocol {
9293
}
9394

9495
/// Subscribes to the channel.
96+
@MainActor
9597
public func subscribeWithError() async throws {
98+
if let subscribeTask = mutableState.subscribeTask {
99+
try await subscribeTask.value
100+
return
101+
}
102+
103+
mutableState.subscribeTask = Task {
104+
defer { self.mutableState.subscribeTask = nil }
105+
try await self.performSubscribeWithRetry()
106+
}
107+
108+
return try await mutableState.subscribeTask!.value
109+
}
110+
111+
private func performSubscribeWithRetry() async throws {
96112
logger?.debug(
97113
"Starting subscription to channel '\(topic)' (attempt 1/\(socket.options.maxRetryAttempts))"
98114
)

Sources/Realtime/RealtimeClientV2.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,7 @@ public final class RealtimeClientV2: Sendable, RealtimeClientProtocol {
315315
let realtimeTopic = "realtime:\(topic)"
316316

317317
if let channel = $0.channels[realtimeTopic] {
318+
self.options.logger?.debug("Reusing existing channel for topic: \(realtimeTopic)")
318319
return channel
319320
}
320321

Tests/RealtimeTests/RealtimeChannelTests.swift

Lines changed: 103 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
// Created by Guilherme Souza on 09/09/24.
66
//
77

8+
import ConcurrencyExtras
89
import InlineSnapshotTesting
910
import TestHelpers
1011
import XCTest
@@ -196,6 +197,58 @@ final class RealtimeChannelTests: XCTestCase {
196197
// The subscription is still in progress when we clean up
197198
}
198199

200+
func testConcurrentSubscribeRunsSingleOperation() async throws {
201+
let socket = ConcurrentSubscribeRealtimeClient()
202+
let channel = RealtimeChannelV2(
203+
topic: "test-topic",
204+
config: RealtimeChannelConfig(
205+
broadcast: BroadcastJoinConfig(),
206+
presence: PresenceJoinConfig(),
207+
isPrivate: false
208+
),
209+
socket: socket,
210+
logger: nil
211+
)
212+
213+
async let firstSubscribe = channel.subscribeWithError()
214+
async let secondSubscribe = channel.subscribeWithError()
215+
216+
try await waitForJoin(
217+
on: socket,
218+
expectedCount: 1
219+
)
220+
XCTAssertEqual(socket.joinPushCount, 1)
221+
222+
await channel.onMessage(
223+
RealtimeMessageV2(
224+
joinRef: nil,
225+
ref: nil,
226+
topic: channel.topic,
227+
event: ChannelEvent.system,
228+
payload: ["status": "ok"]
229+
)
230+
)
231+
232+
try await firstSubscribe
233+
try await secondSubscribe
234+
235+
XCTAssertEqual(socket.joinPushCount, 1)
236+
}
237+
238+
private func waitForJoin(
239+
on socket: ConcurrentSubscribeRealtimeClient,
240+
expectedCount: Int
241+
) async throws {
242+
for _ in 0..<50 {
243+
if socket.joinPushCount == expectedCount {
244+
return
245+
}
246+
try await Task.sleep(nanoseconds: 10_000_000)
247+
}
248+
249+
XCTFail("Timed out waiting for join push")
250+
}
251+
199252
func testHttpSendThrowsWhenAccessTokenIsMissing() async {
200253
let httpClient = HTTPClientMock()
201254
let (client, _) = FakeWebSocket.fakes()
@@ -439,7 +492,9 @@ final class RealtimeChannelTests: XCTestCase {
439492
XCTFail("Expected httpSend to throw an error on 503 status")
440493
} catch {
441494
// Should fall back to localized status text
442-
XCTAssertTrue(error.localizedDescription.contains("503") || error.localizedDescription.contains("unavailable"))
495+
XCTAssertTrue(
496+
error.localizedDescription.contains("503")
497+
|| error.localizedDescription.contains("unavailable"))
443498
}
444499
}
445500
}
@@ -455,3 +510,50 @@ private struct BroadcastPayload: Decodable {
455510
let `private`: Bool
456511
}
457512
}
513+
514+
private final class ConcurrentSubscribeRealtimeClient: RealtimeClientProtocol, @unchecked Sendable {
515+
private let _pushedMessages = LockIsolated<[RealtimeMessageV2]>([])
516+
private let _status = LockIsolated<RealtimeClientStatus>(.connected)
517+
518+
let options: RealtimeClientOptions
519+
let http: any HTTPClientType
520+
let broadcastURL = URL(string: "https://localhost:54321/realtime/v1/api/broadcast")!
521+
522+
init() {
523+
self.options = RealtimeClientOptions(
524+
headers: ["apikey": "test-key"],
525+
timeoutInterval: 5.0
526+
)
527+
self.http = HTTPClientMock()
528+
}
529+
530+
var status: RealtimeClientStatus {
531+
_status.value
532+
}
533+
534+
var pushedMessages: [RealtimeMessageV2] {
535+
_pushedMessages.value
536+
}
537+
538+
var joinPushCount: Int {
539+
pushedMessages.filter { $0.event == ChannelEvent.join }.count
540+
}
541+
542+
func connect() async {
543+
_status.setValue(.connected)
544+
}
545+
546+
func push(_ message: RealtimeMessageV2) {
547+
_pushedMessages.withValue { $0.append(message) }
548+
}
549+
550+
func _getAccessToken() async -> String? {
551+
nil
552+
}
553+
554+
func makeRef() -> String {
555+
UUID().uuidString
556+
}
557+
558+
func _remove(_: any RealtimeChannelProtocol) {}
559+
}

0 commit comments

Comments
 (0)