Skip to content

Commit e0d47f7

Browse files
committed
Fix [p]unsubscribe from all
Motivation: The methods of unsubscribing from all channels / patterns were not working as expected as they need to be special-case handled. Modifications: - Change: `RedisPubSubHandler` to be special-case unsubscribe when no arguments are provided Result: Developers should now properly be able to unsubscribe from all channels / patterns with a single method call.
1 parent 42e8d4b commit e0d47f7

File tree

2 files changed

+151
-7
lines changed

2 files changed

+151
-7
lines changed

Sources/RediStack/ChannelHandlers/RedisPubSubHandler.swift

Lines changed: 71 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -150,14 +150,29 @@ extension RedisPubSubHandler {
150150

151151
private func handleUnsubscribeMessage(
152152
withSubscriptionKey subscriptionKey: String,
153-
reportedSubscriptionCount subscriptionCount: Int
153+
reportedSubscriptionCount subscriptionCount: Int,
154+
unsubscribeFromAllKey: String
154155
) {
155-
defer { self.pendingUnsubscribes.removeValue(forKey: subscriptionKey)?.succeed(subscriptionCount) }
156-
157156
guard let subscription = self.subscriptions.removeValue(forKey: subscriptionKey) else { return }
158157

159158
subscription.onUnsubscribe?(subscriptionKey, subscriptionCount)
160159
subscription.type.gauge.decrement()
160+
161+
switch self.pendingUnsubscribes.removeValue(forKey: subscriptionKey) {
162+
// we found a specific pattern/channel was being removed, so just fulfill the notification
163+
case let .some(promise):
164+
promise.succeed(subscriptionCount)
165+
166+
// if one wasn't found, this means a [p]unsubscribe all was issued
167+
case .none:
168+
// and we want to wait for the subscription count to be 0 before we resolve it's notification
169+
// this count may be from what Redis reports, or the count of subscriptions for this particular type
170+
guard
171+
subscriptionCount == 0 || self.subscriptions.count(where: { $0.type == subscription.type }) == 0
172+
else { return }
173+
// always report back the count according to Redis, it is the source of truth
174+
self.pendingUnsubscribes.removeValue(forKey: unsubscribeFromAllKey)?.succeed(subscriptionCount)
175+
}
161176
}
162177

163178
private func handleMessage(
@@ -249,6 +264,12 @@ extension RedisPubSubHandler {
249264

250265
// we send the UNSUBSCRIBE message to Redis,
251266
// and in the response we handle the actual removal of the receiver closure
267+
268+
// if there are no channels / patterns specified,
269+
// then this is a special case of unsubscribing from all patterns / channels
270+
guard !target.values.isEmpty else {
271+
return self.unsubscribeAll(for: target)
272+
}
252273

253274
return self.sendSubscriptionChange(
254275
subscriptionChangeKeyword: target.unsubscribeKeyword,
@@ -302,9 +323,21 @@ extension RedisPubSubHandler {
302323
return latestSubscriptionCount
303324
}
304325

305-
return self.context.writeAndFlush(self.wrapOutboundOut(.array(command)))
326+
return self.context
327+
.writeAndFlush(self.wrapOutboundOut(.array(command)))
306328
.flatMap { return subscriptionCountFuture }
307329
}
330+
331+
private func unsubscribeAll(for target: RedisSubscriptionTarget) -> EventLoopFuture<Int> {
332+
let command = [RESPValue(bulk: target.unsubscribeKeyword)]
333+
334+
let promise = self.context.eventLoop.makePromise(of: Int.self)
335+
self.pendingUnsubscribes.updateValue(promise, forKey: target.unsubscribeAllKey)
336+
337+
return self.context
338+
.writeAndFlush(self.wrapOutboundOut(.array(command)))
339+
.flatMap { promise.futureResult }
340+
}
308341
}
309342

310343
// MARK: ChannelHandler
@@ -376,8 +409,19 @@ extension RedisPubSubHandler: ChannelInboundHandler {
376409
case "subscribe", "psubscribe":
377410
self.handleSubscribeMessage(withSubscriptionKey: channelOrPattern, reportedSubscriptionCount: message.int!)
378411

379-
case "unsubscribe", "punsubscribe":
380-
self.handleUnsubscribeMessage(withSubscriptionKey: channelOrPattern, reportedSubscriptionCount: message.int!)
412+
case "unsubscribe":
413+
self.handleUnsubscribeMessage(
414+
withSubscriptionKey: channelOrPattern,
415+
reportedSubscriptionCount: message.int!,
416+
unsubscribeFromAllKey: kUnsubscribeAllChannelsKey
417+
)
418+
419+
case "punsubscribe":
420+
self.handleUnsubscribeMessage(
421+
withSubscriptionKey: channelOrPattern,
422+
reportedSubscriptionCount: message.int!,
423+
unsubscribeFromAllKey: kUnsubscribeAllPatternsKey
424+
)
381425

382426
// if we don't have a match, fire a channel read to forward to the next handler
383427
default: context.fireChannelRead(data)
@@ -419,6 +463,10 @@ extension RedisPubSubHandler: ChannelOutboundHandler {
419463

420464
// MARK: Private Types
421465

466+
// keys used for the pendingUnsubscribes
467+
private let kUnsubscribeAllChannelsKey = "__RS_ALL_CHS"
468+
private let kUnsubscribeAllPatternsKey = "__RS_ALL_PNS"
469+
422470
fileprivate enum SubscriptionType {
423471
case channel, pattern
424472

@@ -433,7 +481,7 @@ fileprivate enum SubscriptionType {
433481
extension RedisPubSubHandler {
434482
private typealias PendingSubscriptionChangeQueue = [String: EventLoopPromise<Int>]
435483

436-
private final class Subscription {
484+
fileprivate final class Subscription {
437485
let type: SubscriptionType
438486
let onMessage: RedisSubscriptionMessageReceiver
439487
var onSubscribe: RedisSubscriptionChangeHandler? // will be set to nil after first call
@@ -460,6 +508,13 @@ extension RedisPubSubHandler {
460508
// MARK: Subscription Management Helpers
461509

462510
extension RedisSubscriptionTarget {
511+
fileprivate var unsubscribeAllKey: String {
512+
switch self {
513+
case .channels: return kUnsubscribeAllChannelsKey
514+
case .patterns: return kUnsubscribeAllPatternsKey
515+
}
516+
}
517+
463518
fileprivate var subscriptionType: SubscriptionType {
464519
switch self {
465520
case .channels: return .channel
@@ -480,3 +535,12 @@ extension RedisSubscriptionTarget {
480535
}
481536
}
482537
}
538+
539+
extension Dictionary where Key == String, Value == RedisPubSubHandler.Subscription {
540+
func count(where isIncluded: (Value) -> Bool) -> Int {
541+
self.reduce(into: 0) {
542+
guard isIncluded($1.value) else { return }
543+
$0 += 1
544+
}
545+
}
546+
}

Tests/RediStackIntegrationTests/Commands/PubSubCommandsTests.swift

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,86 @@ final class RedisPubSubCommandsTests: RediStackIntegrationTestCase {
108108
let value = try self.connection.send(command: "QUIT").wait()
109109
XCTAssertEqual(value.string, "OK")
110110
}
111+
112+
func test_unsubscribeFromAllChannels() throws {
113+
let subscriber = try self.makeNewConnection()
114+
defer { try? subscriber.close().wait() }
115+
116+
let channels = (1...5).map { RedisChannelName("\(#function)\($0)") }
117+
118+
let expectation = self.expectation(description: "all channel subscriptions should be cancelled")
119+
expectation.expectedFulfillmentCount = channels.count
120+
121+
try subscriber.subscribe(
122+
to: channels,
123+
messageReceiver: { _, _ in },
124+
onSubscribe: nil,
125+
onUnsubscribe: { _, _ in expectation.fulfill() }
126+
).wait()
127+
128+
XCTAssertTrue(subscriber.isSubscribed)
129+
try subscriber.unsubscribe().wait()
130+
XCTAssertFalse(subscriber.isSubscribed)
131+
132+
self.waitForExpectations(timeout: 1)
133+
}
134+
135+
func test_unsubscribeFromAllPatterns() throws {
136+
let subscriber = try self.makeNewConnection()
137+
defer { try? subscriber.close().wait() }
138+
139+
let patterns = (1...3).map { ("*\(#function)\($0)") }
140+
141+
let expectation = self.expectation(description: "all pattern subscriptions should be cancelled")
142+
expectation.expectedFulfillmentCount = patterns.count
143+
144+
try subscriber.psubscribe(
145+
to: patterns,
146+
messageReceiver: { _, _ in },
147+
onSubscribe: nil,
148+
onUnsubscribe: { _, _ in expectation.fulfill() }
149+
).wait()
150+
151+
XCTAssertTrue(subscriber.isSubscribed)
152+
try subscriber.punsubscribe().wait()
153+
XCTAssertFalse(subscriber.isSubscribed)
154+
155+
self.waitForExpectations(timeout: 1)
156+
}
157+
158+
func test_unsubscribeFromAllMixed() throws {
159+
let subscriber = try self.makeNewConnection()
160+
defer { try? subscriber.close().wait() }
161+
162+
let expectation = self.expectation(description: "both unsubscribes should be completed")
163+
expectation.expectedFulfillmentCount = 2
164+
165+
XCTAssertFalse(subscriber.isSubscribed)
166+
167+
try subscriber.subscribe(
168+
to: #function,
169+
messageReceiver: { _, _ in },
170+
onSubscribe: nil,
171+
onUnsubscribe: { _, _ in expectation.fulfill() }
172+
).wait()
173+
XCTAssertTrue(subscriber.isSubscribed)
174+
175+
try subscriber.psubscribe(
176+
to: "*\(#function)",
177+
messageReceiver: { _, _ in },
178+
onSubscribe: nil,
179+
onUnsubscribe: { _, _ in expectation.fulfill() }
180+
).wait()
181+
XCTAssertTrue(subscriber.isSubscribed)
182+
183+
try subscriber.unsubscribe().wait()
184+
XCTAssertTrue(subscriber.isSubscribed)
185+
186+
try subscriber.punsubscribe().wait()
187+
XCTAssertFalse(subscriber.isSubscribed)
188+
189+
self.waitForExpectations(timeout: 1)
190+
}
111191
}
112192

113193
final class RedisPubSubCommandsPoolTests: RediStackConnectionPoolIntegrationTestCase {

0 commit comments

Comments
 (0)