diff --git a/Sources/Valkey/Connection/ValkeyConnection.swift b/Sources/Valkey/Connection/ValkeyConnection.swift index c4cbbbcd..be74124f 100644 --- a/Sources/Valkey/Connection/ValkeyConnection.swift +++ b/Sources/Valkey/Connection/ValkeyConnection.swift @@ -25,7 +25,7 @@ import NIOTransportServices /// A single connection to a Valkey database. @available(valkeySwift 1.0, *) -public final actor ValkeyConnection: ValkeyClientProtocol, Sendable { +public final actor ValkeyConnection: ValkeyClientProtocol, Sendable, Identifiable { nonisolated public let unownedExecutor: UnownedSerialExecutor /// Request ID generator @@ -367,15 +367,13 @@ public final actor ValkeyConnection: ValkeyClientProtocol, Sendable { /// create a BSD sockets based bootstrap private static func createSocketsBootstrap(eventLoopGroup: EventLoopGroup) -> ClientBootstrap { ClientBootstrap(group: eventLoopGroup) - .channelOption(ChannelOptions.allowRemoteHalfClosure, value: true) } #if canImport(Network) /// create a NIOTransportServices bootstrap using Network.framework private static func createTSBootstrap(eventLoopGroup: EventLoopGroup, tlsOptions: NWProtocolTLS.Options?) -> NIOTSConnectionBootstrap? { guard - let bootstrap = NIOTSConnectionBootstrap(validatingGroup: eventLoopGroup)? - .channelOption(ChannelOptions.allowRemoteHalfClosure, value: true) + let bootstrap = NIOTSConnectionBootstrap(validatingGroup: eventLoopGroup) else { return nil } diff --git a/Sources/Valkey/Documentation.docc/Pubsub.md b/Sources/Valkey/Documentation.docc/Pubsub.md index 1094a8cd..928236f0 100644 --- a/Sources/Valkey/Documentation.docc/Pubsub.md +++ b/Sources/Valkey/Documentation.docc/Pubsub.md @@ -85,7 +85,7 @@ try await connection.clientTracking( #### Subscribing to Invalidation Events -Once tracking is enabled you can subscribe to invalidation events using ``ValkeyConnection/subscribeKeyInvalidations(process:)``. The AsyncSequence passed to the `process` closure is a list of keys that have been invalidated. +Once tracking is enabled you can subscribe to invalidation events using ``ValkeyConnection/subscribeKeyInvalidations(isolation:process:)``. The AsyncSequence passed to the `process` closure is a list of keys that have been invalidated. ```swift try await connection.subscribeKeyInvalidations { keys in diff --git a/Sources/Valkey/Node/ValkeyNodeClient.swift b/Sources/Valkey/Node/ValkeyNodeClient.swift index fa097f90..8adccfc5 100644 --- a/Sources/Valkey/Node/ValkeyNodeClient.swift +++ b/Sources/Valkey/Node/ValkeyNodeClient.swift @@ -123,9 +123,15 @@ extension ValkeyNodeClient { return try await operation(connection) } - private func leaseConnection() async throws -> ValkeyConnection { + @usableFromInline + func leaseConnection() async throws -> ValkeyConnection { try await self.connectionPool.leaseConnection() } + + @usableFromInline + func releaseConnection(_ connection: ValkeyConnection) { + self.connectionPool.releaseConnection(connection) + } } /// Extend ValkeyNode so we can call commands directly from it diff --git a/Sources/Valkey/Subscriptions/ValkeyClient+subscribe.swift b/Sources/Valkey/Subscriptions/ValkeyClient+subscribe.swift new file mode 100644 index 00000000..e17278ff --- /dev/null +++ b/Sources/Valkey/Subscriptions/ValkeyClient+subscribe.swift @@ -0,0 +1,193 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the valkey-swift open source project +// +// Copyright (c) 2025 the valkey-swift project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of valkey-swift project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import NIOCore +import Synchronization + +@available(valkeySwift 1.0, *) +extension ValkeyClient { + @inlinable + func withSubscriptionConnection( + isolation: isolated (any Actor)? = #isolation, + operation: (ValkeyConnection) async throws -> sending Value + ) async throws -> sending Value { + try await self.subscriptionConnection.withValue { + try await operation($0) + } acquire: { + try await self.node.leaseConnection() + } release: { + self.node.releaseConnection($0) + } + } + + /// Subscribe to list of channels and run closure with subscription + /// + /// When the closure is exited the channels are automatically unsubscribed from. It is + /// possible to have multiple subscriptions running on the same connection and unsubscribe + /// commands will only be sent to Valkey when there are no subscriptions active for that + /// channel + /// + /// - Parameters: + /// - channels: list of channels to subscribe to + /// - isolation: Actor isolation + /// - process: Closure that is called with subscription async sequence + /// - Returns: Return value of closure + @inlinable + public func subscribe( + to channels: String..., + isolation: isolated (any Actor)? = #isolation, + process: (ValkeySubscription) async throws -> sending Value + ) async throws -> Value { + try await self.subscribe(to: channels, process: process) + } + + @inlinable + /// Subscribe to list of channels and run closure with subscription + /// + /// When the closure is exited the channels are automatically unsubscribed from. It is + /// possible to have multiple subscriptions running on the same connection and unsubscribe + /// commands will only be sent to Valkey when there are no subscriptions active for that + /// channel + /// + /// - Parameters: + /// - channels: list of channels to subscribe to + /// - isolation: Actor isolation + /// - process: Closure that is called with subscription async sequence + /// - Returns: Return value of closure + public func subscribe( + to channels: [String], + isolation: isolated (any Actor)? = #isolation, + process: (ValkeySubscription) async throws -> sending Value + ) async throws -> Value { + try await self.subscribe( + command: SUBSCRIBE(channels: channels), + filters: channels.map { .channel($0) }, + process: process + ) + } + + /// Subscribe to list of channel patterns and run closure with subscription + /// + /// When the closure is exited the patterns are automatically unsubscribed from. It is + /// possible to have multiple subscriptions running on the same connection and unsubscribe + /// commands will only be sent to Valkey when there are no subscriptions active for that + /// pattern + /// + /// - Parameters: + /// - patterns: list of channel patterns to subscribe to + /// - isolation: Actor isolation + /// - process: Closure that is called with subscription async sequence + /// - Returns: Return value of closure + @inlinable + public func psubscribe( + to patterns: String..., + isolation: isolated (any Actor)? = #isolation, + process: (ValkeySubscription) async throws -> sending Value + ) async throws -> Value { + try await self.psubscribe(to: patterns, process: process) + } + + /// Subscribe to list of pattern matching channels and run closure with subscription + /// + /// When the closure is exited the patterns are automatically unsubscribed from. It is + /// possible to have multiple subscriptions running on the same connection and unsubscribe + /// commands will only be sent to Valkey when there are no subscriptions active for that + /// pattern + /// + /// - Parameters: + /// - patterns: list of channel patterns to subscribe to + /// - isolation: Actor isolation + /// - process: Closure that is called with subscription async sequence + /// - Returns: Return value of closure + @inlinable + public func psubscribe( + to patterns: [String], + isolation: isolated (any Actor)? = #isolation, + process: (ValkeySubscription) async throws -> sending Value + ) async throws -> Value { + try await self.subscribe( + command: PSUBSCRIBE(patterns: patterns), + filters: patterns.map { .pattern($0) }, + process: process + ) + } + + /// Subscribe to key invalidation channel required for client-side caching + /// + /// See https://valkey.io/topics/client-side-caching/ for more details + /// + /// When the closure is exited the channel is automatically unsubscribed from. It is + /// possible to have multiple subscriptions running on the same connection and unsubscribe + /// commands will only be sent to Valkey when there are no subscriptions active for that + /// channel + /// + /// - Parameters: + /// - isolation: Actor isolation + /// - process: Closure that is called with async sequence of key invalidations + /// - Returns: Return value of closure + @inlinable + public func subscribeKeyInvalidations( + isolation: isolated (any Actor)? = #isolation, + process: (AsyncMapSequence) async throws -> sending Value + ) async throws -> Value { + try await self.subscribe(to: [ValkeySubscriptions.invalidateChannel]) { subscription in + let keys = subscription.map { ValkeyKey($0.message) } + return try await process(keys) + } + } + + @inlinable + func subscribe( + command: some ValkeyCommand, + filters: [ValkeySubscriptionFilter], + isolation: isolated (any Actor)? = #isolation, + process: (ValkeySubscription) async throws -> sending Value + ) async throws -> Value { + try await withThrowingTaskGroup(of: Void.self, isolation: isolation) { group in + let (stream, cont) = ValkeySubscription.makeStream() + group.addTask { + while true { + do { + try Task.checkCancellation() + return try await self.withSubscriptionConnection { connection in + try await connection.subscribe(command: command, filters: filters) { subscription in + // push messages on connection subscription to client subscription + for try await message in subscription { + cont.yield(message) + } + } + cont.finish() + } + } catch let error as ValkeyClientError { + // if connection closes for some reason don't exit loop so it opens a new connection + switch error.errorCode { + case .connectionClosed, .connectionClosedDueToCancellation, .connectionClosing: + self.subscriptionConnection.reset() + break + default: + cont.finish(throwing: error) + return + } + } catch { + cont.finish(throwing: error) + return + } + } + } + let value = try await process(stream) + group.cancelAll() + return value + } + } +} diff --git a/Sources/Valkey/Subscriptions/ValkeyConnection+subscribe.swift b/Sources/Valkey/Subscriptions/ValkeyConnection+subscribe.swift index 4ec0de54..d3446c12 100644 --- a/Sources/Valkey/Subscriptions/ValkeyConnection+subscribe.swift +++ b/Sources/Valkey/Subscriptions/ValkeyConnection+subscribe.swift @@ -55,18 +55,12 @@ extension ValkeyConnection { isolation: isolated (any Actor)? = #isolation, process: (ValkeySubscription) async throws -> sending Value ) async throws -> sending Value { - let command = SUBSCRIBE(channels: channels) - let (id, stream) = try await subscribe(command: command, filters: channels.map { .channel($0) }) - let value: Value - do { - value = try await process(stream) - try Task.checkCancellation() - } catch { - _ = try? await unsubscribe(id: id) - throw error - } - _ = try await unsubscribe(id: id) - return value + try await self.subscribe( + command: SUBSCRIBE(channels: channels), + filters: channels.map { .channel($0) }, + isolation: isolation, + process: process + ) } /// Subscribe to list of channel patterns and run closure with subscription @@ -108,18 +102,12 @@ extension ValkeyConnection { isolation: isolated (any Actor)? = #isolation, process: (ValkeySubscription) async throws -> sending Value ) async throws -> sending Value { - let command = PSUBSCRIBE(patterns: patterns) - let (id, stream) = try await subscribe(command: command, filters: patterns.map { .pattern($0) }) - let value: Value - do { - value = try await process(stream) - try Task.checkCancellation() - } catch { - _ = try? await unsubscribe(id: id) - throw error - } - _ = try await unsubscribe(id: id) - return value + try await self.subscribe( + command: PSUBSCRIBE(patterns: patterns), + filters: patterns.map { .pattern($0) }, + isolation: isolation, + process: process + ) } /// Subscribe to list of shard channels and run closure with subscription @@ -130,17 +118,17 @@ extension ValkeyConnection { /// pattern /// /// - Parameters: - /// - shardchannel: list of shard channels to subscribe to + /// - shardchannels: list of shard channels to subscribe to /// - isolation: Actor isolation /// - process: Closure that is called with subscription async sequence /// - Returns: Return value of closure @inlinable public func ssubscribe( - to shardchannel: String..., + to shardchannels: String..., isolation: isolated (any Actor)? = #isolation, process: (ValkeySubscription) async throws -> sending Value ) async throws -> sending Value { - try await self.ssubscribe(to: shardchannel, process: process) + try await self.ssubscribe(to: shardchannels, process: process) } /// Subscribe to list of shard channels and run closure with subscription @@ -151,28 +139,22 @@ extension ValkeyConnection { /// pattern /// /// - Parameters: - /// - shardchannel: list of shard channels to subscribe to + /// - shardchannels: list of shard channels to subscribe to /// - isolation: Actor isolation /// - process: Closure that is called with subscription async sequence /// - Returns: Return value of closure @inlinable public func ssubscribe( - to shardchannel: [String], + to shardchannels: [String], isolation: isolated (any Actor)? = #isolation, process: (ValkeySubscription) async throws -> sending Value ) async throws -> sending Value { - let command = SSUBSCRIBE(shardchannels: shardchannel) - let (id, stream) = try await subscribe(command: command, filters: shardchannel.map { .shardChannel($0) }) - let value: Value - do { - value = try await process(stream) - try Task.checkCancellation() - } catch { - _ = try? await unsubscribe(id: id) - throw error - } - _ = try await unsubscribe(id: id) - return value + try await self.subscribe( + command: SSUBSCRIBE(shardchannels: shardchannels), + filters: shardchannels.map { .shardChannel($0) }, + isolation: isolation, + process: process + ) } /// Subscribe to key invalidation channel required for client-side caching @@ -185,18 +167,43 @@ extension ValkeyConnection { /// channel /// /// - Parameters: + /// - isolation: Actor isolation /// - process: Closure that is called with async sequence of key invalidations /// - Returns: Return value of closure @inlinable public func subscribeKeyInvalidations( + isolation: isolated (any Actor)? = #isolation, process: (AsyncMapSequence) async throws -> sending Value ) async throws -> sending Value { - try await self.subscribe(to: [ValkeySubscriptions.invalidateChannel]) { subscription in + try await self.subscribe(to: [ValkeySubscriptions.invalidateChannel], isolation: isolation) { subscription in let keys = subscription.map { ValkeyKey($0.message) } return try await process(keys) } } + @inlinable + func subscribe( + command: some ValkeyCommand, + filters: [ValkeySubscriptionFilter], + isolation: isolated (any Actor)? = #isolation, + process: (ValkeySubscription) async throws -> sending Value + ) async throws -> sending Value { + let (id, stream) = try await subscribe(command: command, filters: filters) + let value: Value + do { + value = try await process(stream) + try Task.checkCancellation() + } catch { + // call unsubscrobe to avoid it being cancelled + _ = await Task { + try await unsubscribe(id: id) + }.result + throw error + } + _ = try await unsubscribe(id: id) + return value + } + @usableFromInline func subscribe( command: some ValkeyCommand, diff --git a/Sources/Valkey/Subscriptions/ValkeySubscription.swift b/Sources/Valkey/Subscriptions/ValkeySubscription.swift index ac72591e..34e40661 100644 --- a/Sources/Valkey/Subscriptions/ValkeySubscription.swift +++ b/Sources/Valkey/Subscriptions/ValkeySubscription.swift @@ -38,11 +38,14 @@ public struct ValkeySubscription: AsyncSequence, Sendable { /// The type that the sequence produces. public typealias Element = ValkeySubscriptionMessage + @usableFromInline typealias BaseAsyncSequence = AsyncThrowingStream + @usableFromInline typealias Continuation = BaseAsyncSequence.Continuation let base: BaseAsyncSequence + @usableFromInline static func makeStream() -> (Self, Self.Continuation) { let (stream, continuation) = BaseAsyncSequence.makeStream() return (.init(base: stream), continuation) diff --git a/Sources/Valkey/Utils/AsyncInitializedReference.swift b/Sources/Valkey/Utils/AsyncInitializedReference.swift new file mode 100644 index 00000000..4f41caf2 --- /dev/null +++ b/Sources/Valkey/Utils/AsyncInitializedReference.swift @@ -0,0 +1,142 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the valkey-swift project +// +// Copyright (c) 2025 the valkey-swift authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See valkey-swift/CONTRIBUTORS.txt for the list of valkey-swift authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import Synchronization + +/// Stores a reference to a single instance of a type that is initialized asynchronously +/// +/// It ensures only ever one version of the object is initialized, even if it requested +/// twice during initialization. Once it is available the object includes a reference count +/// so we can clean it up once nobody references it. +@available(valkeySwift 1.0, *) +@usableFromInline +struct AsyncInitializedReferencedObject: ~Copyable, Sendable { + @usableFromInline + enum Action { + case use(Value) + case acquire + } + + @usableFromInline + enum State { + case uninitialized + case acquiring([CheckedContinuation]) + case available(Value, Int) + } + @usableFromInline + let state: Mutex + + init() { + self.state = .init(.uninitialized) + } + + @inlinable + func acquire(isolation: isolated (any Actor)? = #isolation, _ operation: () async throws -> Value) async throws -> Value { + let action: Action = try await withCheckedThrowingContinuation { (cont: CheckedContinuation) in + self.state.withLock { state in + switch state { + case .uninitialized: + state = .acquiring([]) + cont.resume(returning: .acquire) + case .acquiring(var continuations): + continuations.append(cont) + state = .acquiring(continuations) + case .available(let connection, let count): + state = .available(connection, count + 1) + cont.resume(returning: .use(connection)) + } + } + } + switch action { + case .acquire: + do { + let connection = try await operation() + return self.state.withLock { state in + guard case .acquiring(let continuations) = state else { + preconditionFailure("State should still be acquiring") + } + for cont in continuations { + cont.resume(returning: .use(connection)) + } + state = .available(connection, continuations.count + 1) + return connection + } + } catch is CancellationError { + self.state.withLock { state in + guard case .acquiring(var continuations) = state else { + preconditionFailure("Can't have state set to none, while acquiring connection") + } + if let lastContinuation = continuations.popLast() { + state = .acquiring(continuations) + lastContinuation.resume(returning: .acquire) + } else { + state = .uninitialized + } + } + throw CancellationError() + } catch { + return try self.state.withLock { state in + guard case .acquiring(let continuations) = state else { + preconditionFailure("Can't have state set to none, while acquiring connection") + } + for cont in continuations { + cont.resume(throwing: error) + } + state = .uninitialized + throw error + } + } + case .use(let connection): + return connection + } + } + + @inlinable + func release(id: Value.ID, _ operation: (Value) -> Void) { + self.state.withLock { state in + switch state { + case .uninitialized, .acquiring: + break + case .available(let connection, let count): + guard connection.id == id else { return } + assert(count > 0, "Cannot have a count of active references to connection less than one") + if count == 1 { + state = .uninitialized + operation(connection) + } else { + state = .available(connection, count - 1) + } + } + } + } + + @inlinable + func withValue( + isolation: isolated (any Actor)? = #isolation, + _ operation: (Value) async throws -> sending Returning, + acquire acquireOperation: () async throws -> Value, + release releaseOperation: (Value) -> Void + ) async throws -> sending Returning { + let value = try await self.acquire(acquireOperation) + defer { + self.release(id: value.id, releaseOperation) + } + return try await operation(value) + } + + @usableFromInline + func reset() { + self.state.withLock { $0 = .uninitialized } + } +} diff --git a/Sources/Valkey/ValkeyClient.swift b/Sources/Valkey/ValkeyClient.swift index c1b73159..3d753856 100644 --- a/Sources/Valkey/ValkeyClient.swift +++ b/Sources/Valkey/ValkeyClient.swift @@ -41,6 +41,9 @@ public final class ValkeyClient: Sendable { let logger: Logger /// running atomic let runningAtomic: Atomic + /// subscription state + @usableFromInline + let subscriptionConnection: AsyncInitializedReferencedObject /// Creates a new Valkey client /// @@ -84,6 +87,7 @@ public final class ValkeyClient: Sendable { self.logger = logger self.runningAtomic = .init(false) self.node = self.nodeClientFactory.makeConnectionPool(serverAddress: address) + self.subscriptionConnection = .init() } } diff --git a/Tests/IntegrationTests/ValkeyTests.swift b/Tests/IntegrationTests/ValkeyTests.swift index 533efa46..a541f142 100644 --- a/Tests/IntegrationTests/ValkeyTests.swift +++ b/Tests/IntegrationTests/ValkeyTests.swift @@ -488,6 +488,119 @@ struct GeneratedCommands { } } + @Test + @available(valkeySwift 1.0, *) + func testCancelSubscription() async throws { + let (stream, cont) = AsyncStream.makeStream(of: Void.self) + var logger = Logger(label: "Subscriptions") + logger.logLevel = .trace + try await withValkeyClient(.hostname(valkeyHostname, port: 6379), logger: logger) { client in + await withThrowingTaskGroup(of: Void.self) { group in + group.addTask { + try await client.withConnection { connection in + try await connection.subscribe(to: "testCancelSubscriptions") { subscription in + cont.finish() + for try await _ in subscription { + } + } + #expect(await connection.isSubscriptionsEmpty()) + } + } + await stream.first { _ in true } + group.cancelAll() + } + } + } + + @Test + @available(valkeySwift 1.0, *) + func testClientSubscriptions() async throws { + let (stream, cont) = AsyncStream.makeStream(of: Void.self) + var logger = Logger(label: "Subscriptions") + logger.logLevel = .trace + try await withValkeyClient(.hostname(valkeyHostname, port: 6379), logger: logger) { client in + try await withThrowingTaskGroup(of: Void.self) { group in + group.addTask { + try await client.subscribe(to: "testSubscriptions") { subscription in + cont.finish() + var iterator = subscription.makeAsyncIterator() + await #expect(throws: Never.self) { try await iterator.next().map { String(buffer: $0.message) } == "hello" } + await #expect(throws: Never.self) { try await iterator.next().map { String(buffer: $0.message) } == "goodbye" } + } + } + try await client.withConnection { connection in + await stream.first { _ in true } + _ = try await connection.publish(channel: "testSubscriptions", message: "hello") + _ = try await connection.publish(channel: "testSubscriptions", message: "goodbye") + } + try await group.waitForAll() + } + } + } + + @Test + @available(valkeySwift 1.0, *) + func testClientMultipleSubscriptions() async throws { + let (stream, cont) = AsyncStream.makeStream(of: Void.self) + var logger = Logger(label: "Subscriptions") + logger.logLevel = .trace + try await withValkeyClient(.hostname(valkeyHostname, port: 6379), logger: logger) { client in + try await withThrowingTaskGroup(of: Void.self) { group in + let count = 50 + for i in 0..: Sendable { + let value: Value + + init(_ value: consuming Value) { + self.value = value + } + } + + @Test + @available(valkeySwift 1.0, *) + func testReferenceCount() async throws { + let referencedObject = Box(AsyncInitializedReferencedObject()) + let test = try await referencedObject.value.acquire { + return Test() + } + let test2 = try await referencedObject.value.acquire { + return Test() + } + // verify we get the same object twice + #expect(test.id == test2.id) + // Verify we have two references + referencedObject.value.state.withLock { state in + switch state { + case .available(_, let count): + #expect(count == 2) + case .acquiring, .uninitialized: + Issue.record("Should have a connection") + } + } + } + + @Test + @available(valkeySwift 1.0, *) + func testReleaseGetsCalledOnce() async throws { + let referencedObject = Box(AsyncInitializedReferencedObject()) + let test = try await referencedObject.value.acquire { + return Test() + } + let test2 = try await referencedObject.value.acquire { + return Test() + } + let called = Atomic(0) + referencedObject.value.release(id: test.id) { _ in + called.add(1, ordering: .relaxed) + } + referencedObject.value.release(id: test2.id) { _ in + called.add(2, ordering: .relaxed) + } + #expect(called.load(ordering: .relaxed) == 2) + } + + @Test + @available(valkeySwift 1.0, *) + func testMultipleConcurrentAcquire() async throws { + try await withThrowingTaskGroup(of: Void.self) { group in + let referencedObject = Box(AsyncInitializedReferencedObject()) + for _ in 0..<500 { + group.addTask { + let test = try await referencedObject.value.acquire { + try await Task.sleep(for: .milliseconds(50)) + return Test() + } + let test2 = try await referencedObject.value.acquire { + try await Task.sleep(for: .milliseconds(50)) + return Test() + } + #expect(test.id == test2.id) + + referencedObject.value.release(id: test.id) { _ in } + referencedObject.value.release(id: test2.id) { _ in } + } + } + for _ in 0..<100 { + group.addTask { + let test = try await referencedObject.value.acquire { + try await Task.sleep(for: .milliseconds(50)) + return Test() + } + referencedObject.value.release(id: test.id) { _ in } + } + } + try await group.waitForAll() + referencedObject.value.state.withLock { state in + if case .uninitialized = state { + } else { + Issue.record("Subscription channel should have been relesed") + } + } + } + } + @Test + @available(valkeySwift 1.0, *) + func testCancellationWhileAcquiring() async throws { + let referencedObject = Box(AsyncInitializedReferencedObject()) + try await withThrowingTaskGroup(of: Void.self) { group in + let (stream1, cont1) = AsyncStream.makeStream(of: Void.self) + let (stream2, cont2) = AsyncStream.makeStream(of: Void.self) + // Run acquire three times, with the first one throwing a cancellation error. Use + // AsyncStream to ensure all acquires are active at the same time + group.addTask { + _ = try await referencedObject.value.acquire { + cont1.finish() + await stream2.first { _ in true } + throw CancellationError() + } + } + await stream1.first { _ in true } + group.addTask { + _ = try await referencedObject.value.acquire { + try await Task.sleep(for: .milliseconds(50)) + return Test() + } + } + group.addTask { + _ = try await referencedObject.value.acquire { + try await Task.sleep(for: .milliseconds(50)) + return Test() + } + } + try await Task.sleep(for: .milliseconds(50)) + cont2.finish() + } + // Verify we have two connections + let value: Test = try #require( + referencedObject.value.state.withLock { state in + switch state { + case .available(let value, let count): + #expect(count == 2) + return value + case .acquiring, .uninitialized: + Issue.record("Should have a connection") + return nil + } + } + ) + // Verify once we run release twice we have no connection + referencedObject.value.release(id: value.id) { _ in } + referencedObject.value.release(id: value.id) { _ in } + referencedObject.value.state.withLock { state in + switch state { + case .uninitialized: + break + case .acquiring, .available: + Issue.record("Should have a connection") + } + } + } + + @Test + @available(valkeySwift 1.0, *) + func testWithValue() async throws { + let referencedObject = Box(AsyncInitializedReferencedObject()) + let operationCalledCount = Box(Atomic(0)) + let acquireCalledCount = Box(Atomic(0)) + let releaseCalledCount = Box(Atomic(0)) + try await withThrowingTaskGroup(of: Void.self) { group in + for _ in 0..<3 { + group.addTask { + try await referencedObject.value.withValue { _ in + try await Task.sleep(for: .milliseconds(20)) + operationCalledCount.value.add(1, ordering: .relaxed) + } acquire: { + try await Task.sleep(for: .milliseconds(20)) + acquireCalledCount.value.add(1, ordering: .relaxed) + return Test() + } release: { _ in + releaseCalledCount.value.add(1, ordering: .relaxed) + } + } + } + try await group.waitForAll() + } + #expect(operationCalledCount.value.load(ordering: .relaxed) == 3) + #expect(acquireCalledCount.value.load(ordering: .relaxed) == 1) + #expect(releaseCalledCount.value.load(ordering: .relaxed) == 1) + } +} diff --git a/Tests/ValkeyTests/ValkeySubscriptionTests.swift b/Tests/ValkeyTests/ValkeySubscriptionTests.swift index e66d6e95..32c45cdf 100644 --- a/Tests/ValkeyTests/ValkeySubscriptionTests.swift +++ b/Tests/ValkeyTests/ValkeySubscriptionTests.swift @@ -15,6 +15,7 @@ import Logging import NIOCore import NIOEmbedded +import Synchronization import Testing @testable import Valkey @@ -729,9 +730,16 @@ struct SubscriptionTests { try await channel.writeInbound(RESPToken(.push([.bulkString("subscribe"), .bulkString("test"), .number(1)])).base) // push message try await channel.writeInbound(RESPToken(.push([.bulkString("message"), .bulkString("test"), .bulkString("Testing!")])).base) + } try await group.next() group.cancelAll() + + // respond to unsubscribe after cancellation + let outbound = try await channel.waitForOutboundWrite(as: ByteBuffer.self) + #expect(outbound == RESPToken(.command(["UNSUBSCRIBE", "test"])).base) + // push unsubcribe + try await channel.writeInbound(RESPToken(.push([.bulkString("unsubscribe"), .bulkString("test"), .number(1)])).base) } #expect(await connection.isSubscriptionsEmpty()) }