Skip to content

Commit 61cc879

Browse files
committed
Change RedisConnection to end subscriptions when not allowed
Motivation: When `RedisConnection.allowSubscriptions` is set to `false`, the connection could still be in a subscription state leaving further commands to fail slowly from a full roundtrip to Redis, rather than succeeding as expected. This changes the implementation so that it triggers a full unsubscribe from patterns and channels when set to `false`. Modifications: - Change: `RedisConnection.allowSubscriptions` to call `unsubscribe()` and `punsubscribe()` when set to `false` - Change: `RedisPubSubHandler` to prefix storage of all dictionary keys to avoid name clashes between pattern and channel subscriptions Result: Developers should now have more deterministic and unsurprising behavior with PubSub in regards to subscription management and connection state.
1 parent e0d47f7 commit 61cc879

File tree

3 files changed

+102
-50
lines changed

3 files changed

+102
-50
lines changed

Sources/RediStack/ChannelHandlers/RedisPubSubHandler.swift

Lines changed: 72 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,8 @@ public enum RedisSubscriptionTarget: Equatable, CustomDebugStringConvertible {
101101
public final class RedisPubSubHandler {
102102
private var state: State = .default
103103

104+
// each key in the following maps _must_ be prefixed as there can be clashes between patterns and channel names
105+
104106
/// A map of channel names or patterns and their respective event registration.
105107
private var subscriptions: [String: Subscription]
106108
/// A queue of subscribe changes awaiting notification of completion.
@@ -135,30 +137,35 @@ public final class RedisPubSubHandler {
135137
extension RedisPubSubHandler {
136138
private func handleSubscribeMessage(
137139
withSubscriptionKey subscriptionKey: String,
138-
reportedSubscriptionCount subscriptionCount: Int
140+
reportedSubscriptionCount subscriptionCount: Int,
141+
keyPrefix: String
139142
) {
140-
defer { self.pendingSubscribes.removeValue(forKey: subscriptionKey)?.succeed(subscriptionCount) }
143+
let prefixedKey = self.prefixKey(subscriptionKey, with: keyPrefix)
144+
145+
defer { self.pendingSubscribes.removeValue(forKey: prefixedKey)?.succeed(subscriptionCount) }
141146

142-
guard let subscription = self.subscriptions[subscriptionKey] else { return }
147+
guard let subscription = self.subscriptions[prefixedKey] else { return }
143148

144149
subscription.onSubscribe?(subscriptionKey, subscriptionCount)
145150
subscription.onSubscribe = nil // nil to free memory
146-
self.subscriptions[subscriptionKey] = subscription
151+
self.subscriptions[prefixedKey] = subscription
147152

148153
subscription.type.gauge.increment()
149154
}
150155

151156
private func handleUnsubscribeMessage(
152157
withSubscriptionKey subscriptionKey: String,
153158
reportedSubscriptionCount subscriptionCount: Int,
154-
unsubscribeFromAllKey: String
159+
unsubscribeFromAllKey: String,
160+
keyPrefix: String
155161
) {
156-
guard let subscription = self.subscriptions.removeValue(forKey: subscriptionKey) else { return }
162+
let prefixedKey = self.prefixKey(subscriptionKey, with: keyPrefix)
163+
guard let subscription = self.subscriptions.removeValue(forKey: prefixedKey) else { return }
157164

158165
subscription.onUnsubscribe?(subscriptionKey, subscriptionCount)
159166
subscription.type.gauge.decrement()
160167

161-
switch self.pendingUnsubscribes.removeValue(forKey: subscriptionKey) {
168+
switch self.pendingUnsubscribes.removeValue(forKey: prefixedKey) {
162169
// we found a specific pattern/channel was being removed, so just fulfill the notification
163170
case let .some(promise):
164171
promise.succeed(subscriptionCount)
@@ -178,9 +185,10 @@ extension RedisPubSubHandler {
178185
private func handleMessage(
179186
_ message: RESPValue,
180187
from channel: RedisChannelName,
181-
withSubscriptionKey subscriptionKey: String
188+
withSubscriptionKey subscriptionKey: String,
189+
keyPrefix: String
182190
) {
183-
guard let subscription = self.subscriptions[subscriptionKey] else { return }
191+
guard let subscription = self.subscriptions[self.prefixKey(subscriptionKey, with: keyPrefix)] else { return }
184192
subscription.onMessage(channel, message)
185193
RedisMetrics.subscriptionMessagesReceivedCount.increment()
186194
}
@@ -232,7 +240,8 @@ extension RedisPubSubHandler {
232240
subscribeHandler: subscribeHandler,
233241
unsubscribeHandler: unsubscribeHandler
234242
)
235-
guard self.subscriptions.updateValue(subscription, forKey: targetKey) == nil else { return nil }
243+
let prefixedKey = self.prefixKey(targetKey, with: target.keyPrefix)
244+
guard self.subscriptions.updateValue(subscription, forKey: prefixedKey) == nil else { return nil }
236245
return targetKey
237246
}
238247

@@ -245,7 +254,8 @@ extension RedisPubSubHandler {
245254
return self.sendSubscriptionChange(
246255
subscriptionChangeKeyword: target.subscribeKeyword,
247256
subscriptionTargets: newSubscriptionTargets,
248-
queue: \.pendingSubscribes
257+
queue: \.pendingSubscribes,
258+
keyPrefix: target.keyPrefix
249259
)
250260
}
251261
}
@@ -274,14 +284,16 @@ extension RedisPubSubHandler {
274284
return self.sendSubscriptionChange(
275285
subscriptionChangeKeyword: target.unsubscribeKeyword,
276286
subscriptionTargets: target.values,
277-
queue: \.pendingUnsubscribes
287+
queue: \.pendingUnsubscribes,
288+
keyPrefix: target.keyPrefix
278289
)
279290
}
280291

281292
private func sendSubscriptionChange(
282293
subscriptionChangeKeyword keyword: String,
283294
subscriptionTargets targets: [String],
284-
queue pendingQueue: ReferenceWritableKeyPath<RedisPubSubHandler, PendingSubscriptionChangeQueue>
295+
queue pendingQueue: ReferenceWritableKeyPath<RedisPubSubHandler, PendingSubscriptionChangeQueue>,
296+
keyPrefix: String
285297
) -> EventLoopFuture<Int> {
286298
self.eventLoop.assertInEventLoop()
287299

@@ -298,7 +310,7 @@ extension RedisPubSubHandler {
298310

299311
// create them
300312
let pendingSubscriptions: [(String, EventLoopPromise<Int>)] = targets.map {
301-
return ($0, self.eventLoop.makePromise())
313+
return (self.prefixKey($0, with: keyPrefix), self.eventLoop.makePromise())
302314
}
303315
// add the subscription change handler to the appropriate queue for each individual subscription target
304316
pendingSubscriptions.forEach { self[keyPath: pendingQueue].updateValue($1, forKey: $0) }
@@ -399,28 +411,53 @@ extension RedisPubSubHandler: ChannelInboundHandler {
399411
// if we have a match, we're definitely in a pubsub message and we should handle it
400412

401413
switch messageKeyword {
402-
case "message": self.handleMessage(message, from: .init(channelOrPattern), withSubscriptionKey: channelOrPattern)
414+
case "message":
415+
self.handleMessage(
416+
message,
417+
from: .init(channelOrPattern),
418+
withSubscriptionKey: channelOrPattern,
419+
keyPrefix: kSubscriptionKeyPrefixChannel
420+
)
403421

404-
// the channel name is stored as the 3rd element in the array in 'pmessage' streams
405-
case "pmessage": self.handleMessage(message, from: .init(array[2].string!), withSubscriptionKey: channelOrPattern)
422+
423+
case "pmessage":
424+
self.handleMessage(
425+
message,
426+
from: .init(array[2].string!), // the channel name is stored as the 3rd element in the array in 'pmessage' streams
427+
withSubscriptionKey: channelOrPattern,
428+
keyPrefix: kSubscriptionKeyPrefixPattern
429+
)
406430

407431
// if the message keyword is for subscribing or unsubscribing,
408432
// the message is guaranteed to be the count of subscriptions the connection still has
409-
case "subscribe", "psubscribe":
410-
self.handleSubscribeMessage(withSubscriptionKey: channelOrPattern, reportedSubscriptionCount: message.int!)
433+
case "subscribe":
434+
self.handleSubscribeMessage(
435+
withSubscriptionKey: channelOrPattern,
436+
reportedSubscriptionCount: message.int!,
437+
keyPrefix: kSubscriptionKeyPrefixChannel
438+
)
439+
440+
case "psubscribe":
441+
self.handleSubscribeMessage(
442+
withSubscriptionKey: channelOrPattern,
443+
reportedSubscriptionCount: message.int!,
444+
keyPrefix: kSubscriptionKeyPrefixPattern
445+
)
411446

412447
case "unsubscribe":
413448
self.handleUnsubscribeMessage(
414449
withSubscriptionKey: channelOrPattern,
415450
reportedSubscriptionCount: message.int!,
416-
unsubscribeFromAllKey: kUnsubscribeAllChannelsKey
451+
unsubscribeFromAllKey: kUnsubscribeAllChannelsKey,
452+
keyPrefix: kSubscriptionKeyPrefixChannel
417453
)
418454

419455
case "punsubscribe":
420456
self.handleUnsubscribeMessage(
421457
withSubscriptionKey: channelOrPattern,
422458
reportedSubscriptionCount: message.int!,
423-
unsubscribeFromAllKey: kUnsubscribeAllPatternsKey
459+
unsubscribeFromAllKey: kUnsubscribeAllPatternsKey,
460+
keyPrefix: kSubscriptionKeyPrefixPattern
424461
)
425462

426463
// if we don't have a match, fire a channel read to forward to the next handler
@@ -507,6 +544,13 @@ extension RedisPubSubHandler {
507544

508545
// MARK: Subscription Management Helpers
509546

547+
private let kSubscriptionKeyPrefixChannel = "__RS_CS"
548+
private let kSubscriptionKeyPrefixPattern = "__RS_PS"
549+
550+
extension RedisPubSubHandler {
551+
private func prefixKey(_ key: String, with prefix: String) -> String { "\(prefix)_\(key)" }
552+
}
553+
510554
extension RedisSubscriptionTarget {
511555
fileprivate var unsubscribeAllKey: String {
512556
switch self {
@@ -515,6 +559,13 @@ extension RedisSubscriptionTarget {
515559
}
516560
}
517561

562+
fileprivate var keyPrefix: String {
563+
switch self {
564+
case .channels: return kSubscriptionKeyPrefixChannel
565+
case .patterns: return kSubscriptionKeyPrefixPattern
566+
}
567+
}
568+
518569
fileprivate var subscriptionType: SubscriptionType {
519570
switch self {
520571
case .channels: return .channel

Sources/RediStack/RedisConnection.swift

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -133,12 +133,12 @@ public final class RedisConnection: RedisClient, RedisClientWithUserContext {
133133
get { self.allowPubSub.load() }
134134
set(newValue) {
135135
self.allowPubSub.store(newValue)
136-
// TODO: Re-enable after [p]unsubscribe from all is fixed
137-
// guard self.isConnected else { return }
138-
// _ = EventLoopFuture<Void>.whenAllComplete([
139-
// self.unsubscribe(),
140-
// self.punsubscribe()
141-
// ], on: self.eventLoop)
136+
// if we're subscribed, and we're not allowed to be in pubsub, end our subscriptions
137+
guard self.isSubscribed && !self.allowPubSub.load() else { return }
138+
_ = EventLoopFuture<Void>.whenAllComplete([
139+
self.unsubscribe(),
140+
self.punsubscribe()
141+
], on: self.eventLoop)
142142
}
143143
}
144144

Tests/RediStackIntegrationTests/RedisConnectionTests.swift

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -56,27 +56,28 @@ extension RedisConnectionTests {
5656
XCTAssertEqual(error, .pubsubNotAllowed)
5757
}
5858
}
59-
60-
// TODO - fix [p]unsubscribe from all and re-enable this unit test
61-
// func test_subscriptionPermissionsChanged_endsSubscriptions() throws {
62-
// let connection = try self.makeNewConnection()
63-
//
64-
// let channelSubClosedExpectation = self.expectation(description: "channel subscription was closed")
65-
// let patternSubClosedExpectation = self.expectation(description: "pattern subscription was closed")
66-
//
67-
// _ = connection.subscribe(
68-
// to: #function,
69-
// messageReceiver: { (_, _) in },
70-
// onUnsubscribe: { (_, _) in channelSubClosedExpectation.fulfill() }
71-
// )
72-
// _ = connection.psubscribe(
73-
// to: #function,
74-
// messageReceiver: { (_, _) in },
75-
// onUnsubscribe: { (_, _) in patternSubClosedExpectation.fulfill() }
76-
// )
77-
//
78-
// connection.allowSubscriptions = false
79-
//
80-
// self.waitForExpectations(timeout: 2)
81-
// }
59+
60+
func test_subscriptionPermissionsChanged_endsSubscriptions() throws {
61+
let connection = try self.makeNewConnection()
62+
63+
let subscriptionClosedExpectation = self.expectation(description: "subscription was closed")
64+
subscriptionClosedExpectation.expectedFulfillmentCount = 2
65+
66+
_ = try connection.subscribe(
67+
to: #function,
68+
messageReceiver: { _, _ in },
69+
onSubscribe: nil,
70+
onUnsubscribe: { _, _ in subscriptionClosedExpectation.fulfill() }
71+
).wait()
72+
_ = try connection.psubscribe(
73+
to: #function,
74+
messageReceiver: { _, _ in },
75+
onSubscribe: nil,
76+
onUnsubscribe: { _, _ in subscriptionClosedExpectation.fulfill() }
77+
).wait()
78+
79+
connection.allowSubscriptions = false
80+
81+
self.waitForExpectations(timeout: 1)
82+
}
8283
}

0 commit comments

Comments
 (0)