Skip to content

Commit 64547be

Browse files
authored
Don't allow SUBSCRIBE command to be cancelled (#214)
* Don't allow SUBSCRIBE command to be cancelled Signed-off-by: Adam Fowler <[email protected]> * Re-instate testCancelSubscribe Signed-off-by: Adam Fowler <[email protected]> --------- Signed-off-by: Adam Fowler <[email protected]>
1 parent 3c38688 commit 64547be

File tree

2 files changed

+30
-33
lines changed

2 files changed

+30
-33
lines changed

Sources/Valkey/Subscriptions/ValkeyConnection+subscribe.swift

Lines changed: 14 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -204,37 +204,26 @@ extension ValkeyConnection {
204204
) async throws -> (Int, ValkeySubscription) {
205205
let requestID = Self.requestIDGenerator.next()
206206
let (stream, streamContinuation) = ValkeySubscription.makeStream()
207-
return try await withTaskCancellationHandler {
208-
if Task.isCancelled {
209-
throw ValkeyClientError(.cancelled)
210-
}
211-
let subscriptionID: Int = try await withCheckedThrowingContinuation { continuation in
212-
self.channelHandler.subscribe(
213-
command: command,
214-
streamContinuation: streamContinuation,
215-
filters: filters,
216-
promise: .swift(continuation),
217-
requestID: requestID
218-
)
219-
}
220-
return (subscriptionID, stream)
221-
} onCancel: {
222-
self.cancel(requestID: requestID)
207+
if Task.isCancelled {
208+
throw ValkeyClientError(.cancelled)
223209
}
210+
let subscriptionID: Int = try await withCheckedThrowingContinuation { continuation in
211+
self.channelHandler.subscribe(
212+
command: command,
213+
streamContinuation: streamContinuation,
214+
filters: filters,
215+
promise: .swift(continuation),
216+
requestID: requestID
217+
)
218+
}
219+
return (subscriptionID, stream)
224220
}
225221

226222
@usableFromInline
227223
func unsubscribe(id: Int) async throws {
228224
let requestID = Self.requestIDGenerator.next()
229-
try await withTaskCancellationHandler {
230-
if Task.isCancelled {
231-
throw ValkeyClientError(.cancelled)
232-
}
233-
try await withCheckedThrowingContinuation { continuation in
234-
self.channelHandler.unsubscribe(id: id, promise: .swift(continuation), requestID: requestID)
235-
}
236-
} onCancel: {
237-
self.cancel(requestID: requestID)
225+
try await withCheckedThrowingContinuation { continuation in
226+
self.channelHandler.unsubscribe(id: id, promise: .swift(continuation), requestID: requestID)
238227
}
239228
}
240229

Tests/ValkeyTests/ValkeySubscriptionTests.swift

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -683,17 +683,25 @@ struct SubscriptionTests {
683683

684684
try await withThrowingTaskGroup(of: Void.self) { group in
685685
group.addTask {
686-
await #expect(throws: ValkeyClientError(.cancelled)) {
687-
try await connection.subscribe(to: "test") { _ in }
686+
await #expect(throws: CancellationError.self) {
687+
try await connection.subscribe(to: "test") { subscription in
688+
for try await _ in subscription {}
689+
}
688690
}
689691
}
690-
group.addTask {
691-
let outbound = try await channel.waitForOutboundWrite(as: ByteBuffer.self)
692-
// expect SUBSCRIBE command
693-
#expect(outbound == RESPToken(.command(["SUBSCRIBE", "test"])).base)
694-
}
695-
try await group.next()
692+
var outbound = try await channel.waitForOutboundWrite(as: ByteBuffer.self)
693+
// expect SUBSCRIBE command
694+
#expect(outbound == RESPToken(.command(["SUBSCRIBE", "test"])).base)
696695
group.cancelAll()
696+
try await Task.sleep(for: .milliseconds(10))
697+
698+
// push subscribe
699+
try await channel.writeInbound(RESPToken(.push([.bulkString("subscribe"), .bulkString("test"), .number(1)])).base)
700+
// expect UNSUBSCRIBE command
701+
outbound = try await channel.waitForOutboundWrite(as: ByteBuffer.self)
702+
#expect(outbound == RESPToken(.command(["UNSUBSCRIBE", "test"])).base)
703+
// push unsubscribe
704+
try await channel.writeInbound(RESPToken(.push([.bulkString("unsubscribe"), .bulkString("test"), .number(0)])).base)
697705
}
698706
#expect(await connection.isSubscriptionsEmpty())
699707
}

0 commit comments

Comments
 (0)