@@ -150,14 +150,29 @@ extension RedisPubSubHandler {
150
150
151
151
private func handleUnsubscribeMessage(
152
152
withSubscriptionKey subscriptionKey: String ,
153
- reportedSubscriptionCount subscriptionCount: Int
153
+ reportedSubscriptionCount subscriptionCount: Int ,
154
+ unsubscribeFromAllKey: String
154
155
) {
155
- defer { self . pendingUnsubscribes. removeValue ( forKey: subscriptionKey) ? . succeed ( subscriptionCount) }
156
-
157
156
guard let subscription = self . subscriptions. removeValue ( forKey: subscriptionKey) else { return }
158
157
159
158
subscription. onUnsubscribe ? ( subscriptionKey, subscriptionCount)
160
159
subscription. type. gauge. decrement ( )
160
+
161
+ switch self . pendingUnsubscribes. removeValue ( forKey: subscriptionKey) {
162
+ // we found a specific pattern/channel was being removed, so just fulfill the notification
163
+ case let . some( promise) :
164
+ promise. succeed ( subscriptionCount)
165
+
166
+ // if one wasn't found, this means a [p]unsubscribe all was issued
167
+ case . none:
168
+ // and we want to wait for the subscription count to be 0 before we resolve it's notification
169
+ // this count may be from what Redis reports, or the count of subscriptions for this particular type
170
+ guard
171
+ subscriptionCount == 0 || self . subscriptions. count ( where: { $0. type == subscription. type } ) == 0
172
+ else { return }
173
+ // always report back the count according to Redis, it is the source of truth
174
+ self . pendingUnsubscribes. removeValue ( forKey: unsubscribeFromAllKey) ? . succeed ( subscriptionCount)
175
+ }
161
176
}
162
177
163
178
private func handleMessage(
@@ -249,6 +264,12 @@ extension RedisPubSubHandler {
249
264
250
265
// we send the UNSUBSCRIBE message to Redis,
251
266
// and in the response we handle the actual removal of the receiver closure
267
+
268
+ // if there are no channels / patterns specified,
269
+ // then this is a special case of unsubscribing from all patterns / channels
270
+ guard !target. values. isEmpty else {
271
+ return self . unsubscribeAll ( for: target)
272
+ }
252
273
253
274
return self . sendSubscriptionChange (
254
275
subscriptionChangeKeyword: target. unsubscribeKeyword,
@@ -302,9 +323,21 @@ extension RedisPubSubHandler {
302
323
return latestSubscriptionCount
303
324
}
304
325
305
- return self . context. writeAndFlush ( self . wrapOutboundOut ( . array( command) ) )
326
+ return self . context
327
+ . writeAndFlush ( self . wrapOutboundOut ( . array( command) ) )
306
328
. flatMap { return subscriptionCountFuture }
307
329
}
330
+
331
+ private func unsubscribeAll( for target: RedisSubscriptionTarget ) -> EventLoopFuture < Int > {
332
+ let command = [ RESPValue ( bulk: target. unsubscribeKeyword) ]
333
+
334
+ let promise = self . context. eventLoop. makePromise ( of: Int . self)
335
+ self . pendingUnsubscribes. updateValue ( promise, forKey: target. unsubscribeAllKey)
336
+
337
+ return self . context
338
+ . writeAndFlush ( self . wrapOutboundOut ( . array( command) ) )
339
+ . flatMap { promise. futureResult }
340
+ }
308
341
}
309
342
310
343
// MARK: ChannelHandler
@@ -376,8 +409,19 @@ extension RedisPubSubHandler: ChannelInboundHandler {
376
409
case " subscribe " , " psubscribe " :
377
410
self . handleSubscribeMessage ( withSubscriptionKey: channelOrPattern, reportedSubscriptionCount: message. int!)
378
411
379
- case " unsubscribe " , " punsubscribe " :
380
- self . handleUnsubscribeMessage ( withSubscriptionKey: channelOrPattern, reportedSubscriptionCount: message. int!)
412
+ case " unsubscribe " :
413
+ self . handleUnsubscribeMessage (
414
+ withSubscriptionKey: channelOrPattern,
415
+ reportedSubscriptionCount: message. int!,
416
+ unsubscribeFromAllKey: kUnsubscribeAllChannelsKey
417
+ )
418
+
419
+ case " punsubscribe " :
420
+ self . handleUnsubscribeMessage (
421
+ withSubscriptionKey: channelOrPattern,
422
+ reportedSubscriptionCount: message. int!,
423
+ unsubscribeFromAllKey: kUnsubscribeAllPatternsKey
424
+ )
381
425
382
426
// if we don't have a match, fire a channel read to forward to the next handler
383
427
default : context. fireChannelRead ( data)
@@ -419,6 +463,10 @@ extension RedisPubSubHandler: ChannelOutboundHandler {
419
463
420
464
// MARK: Private Types
421
465
466
+ // keys used for the pendingUnsubscribes
467
+ private let kUnsubscribeAllChannelsKey = " __RS_ALL_CHS "
468
+ private let kUnsubscribeAllPatternsKey = " __RS_ALL_PNS "
469
+
422
470
fileprivate enum SubscriptionType {
423
471
case channel, pattern
424
472
@@ -433,7 +481,7 @@ fileprivate enum SubscriptionType {
433
481
extension RedisPubSubHandler {
434
482
private typealias PendingSubscriptionChangeQueue = [ String : EventLoopPromise < Int > ]
435
483
436
- private final class Subscription {
484
+ fileprivate final class Subscription {
437
485
let type : SubscriptionType
438
486
let onMessage : RedisSubscriptionMessageReceiver
439
487
var onSubscribe : RedisSubscriptionChangeHandler ? // will be set to nil after first call
@@ -460,6 +508,13 @@ extension RedisPubSubHandler {
460
508
// MARK: Subscription Management Helpers
461
509
462
510
extension RedisSubscriptionTarget {
511
+ fileprivate var unsubscribeAllKey : String {
512
+ switch self {
513
+ case . channels: return kUnsubscribeAllChannelsKey
514
+ case . patterns: return kUnsubscribeAllPatternsKey
515
+ }
516
+ }
517
+
463
518
fileprivate var subscriptionType : SubscriptionType {
464
519
switch self {
465
520
case . channels: return . channel
@@ -480,3 +535,12 @@ extension RedisSubscriptionTarget {
480
535
}
481
536
}
482
537
}
538
+
539
+ extension Dictionary where Key == String , Value == RedisPubSubHandler . Subscription {
540
+ func count( where isIncluded: ( Value ) -> Bool ) -> Int {
541
+ self . reduce ( into: 0 ) {
542
+ guard isIncluded ( $1. value) else { return }
543
+ $0 += 1
544
+ }
545
+ }
546
+ }
0 commit comments