55// Created by Guilherme Souza on 09/09/24.
66//
77
8+ import ConcurrencyExtras
89import InlineSnapshotTesting
910import TestHelpers
1011import XCTest
@@ -196,6 +197,58 @@ final class RealtimeChannelTests: XCTestCase {
196197 // The subscription is still in progress when we clean up
197198 }
198199
200+ func testConcurrentSubscribeRunsSingleOperation( ) async throws {
201+ let socket = ConcurrentSubscribeRealtimeClient ( )
202+ let channel = RealtimeChannelV2 (
203+ topic: " test-topic " ,
204+ config: RealtimeChannelConfig (
205+ broadcast: BroadcastJoinConfig ( ) ,
206+ presence: PresenceJoinConfig ( ) ,
207+ isPrivate: false
208+ ) ,
209+ socket: socket,
210+ logger: nil
211+ )
212+
213+ async let firstSubscribe = channel. subscribeWithError ( )
214+ async let secondSubscribe = channel. subscribeWithError ( )
215+
216+ try await waitForJoin (
217+ on: socket,
218+ expectedCount: 1
219+ )
220+ XCTAssertEqual ( socket. joinPushCount, 1 )
221+
222+ await channel. onMessage (
223+ RealtimeMessageV2 (
224+ joinRef: nil ,
225+ ref: nil ,
226+ topic: channel. topic,
227+ event: ChannelEvent . system,
228+ payload: [ " status " : " ok " ]
229+ )
230+ )
231+
232+ try await firstSubscribe
233+ try await secondSubscribe
234+
235+ XCTAssertEqual ( socket. joinPushCount, 1 )
236+ }
237+
238+ private func waitForJoin(
239+ on socket: ConcurrentSubscribeRealtimeClient ,
240+ expectedCount: Int
241+ ) async throws {
242+ for _ in 0 ..< 50 {
243+ if socket. joinPushCount == expectedCount {
244+ return
245+ }
246+ try await Task . sleep ( nanoseconds: 10_000_000 )
247+ }
248+
249+ XCTFail ( " Timed out waiting for join push " )
250+ }
251+
199252 func testHttpSendThrowsWhenAccessTokenIsMissing( ) async {
200253 let httpClient = HTTPClientMock ( )
201254 let ( client, _) = FakeWebSocket . fakes ( )
@@ -439,7 +492,9 @@ final class RealtimeChannelTests: XCTestCase {
439492 XCTFail ( " Expected httpSend to throw an error on 503 status " )
440493 } catch {
441494 // Should fall back to localized status text
442- XCTAssertTrue ( error. localizedDescription. contains ( " 503 " ) || error. localizedDescription. contains ( " unavailable " ) )
495+ XCTAssertTrue (
496+ error. localizedDescription. contains ( " 503 " )
497+ || error. localizedDescription. contains ( " unavailable " ) )
443498 }
444499 }
445500}
@@ -455,3 +510,50 @@ private struct BroadcastPayload: Decodable {
455510 let `private` : Bool
456511 }
457512}
513+
514+ private final class ConcurrentSubscribeRealtimeClient : RealtimeClientProtocol , @unchecked Sendable {
515+ private let _pushedMessages = LockIsolated < [ RealtimeMessageV2 ] > ( [ ] )
516+ private let _status = LockIsolated < RealtimeClientStatus > ( . connected)
517+
518+ let options : RealtimeClientOptions
519+ let http : any HTTPClientType
520+ let broadcastURL = URL ( string: " https://localhost:54321/realtime/v1/api/broadcast " ) !
521+
522+ init ( ) {
523+ self . options = RealtimeClientOptions (
524+ headers: [ " apikey " : " test-key " ] ,
525+ timeoutInterval: 5.0
526+ )
527+ self . http = HTTPClientMock ( )
528+ }
529+
530+ var status : RealtimeClientStatus {
531+ _status. value
532+ }
533+
534+ var pushedMessages : [ RealtimeMessageV2 ] {
535+ _pushedMessages. value
536+ }
537+
538+ var joinPushCount : Int {
539+ pushedMessages. filter { $0. event == ChannelEvent . join } . count
540+ }
541+
542+ func connect( ) async {
543+ _status. setValue ( . connected)
544+ }
545+
546+ func push( _ message: RealtimeMessageV2 ) {
547+ _pushedMessages. withValue { $0. append ( message) }
548+ }
549+
550+ func _getAccessToken( ) async -> String ? {
551+ nil
552+ }
553+
554+ func makeRef( ) -> String {
555+ UUID ( ) . uuidString
556+ }
557+
558+ func _remove( _: any RealtimeChannelProtocol ) { }
559+ }
0 commit comments