Skip to content

Commit c85f857

Browse files
committed
Add support for graceful shutdown to the RedisCommandHandler
1 parent b5c0bc8 commit c85f857

File tree

2 files changed

+135
-3
lines changed

2 files changed

+135
-3
lines changed

Sources/RediStack/ChannelHandlers/RedisCommandHandler.swift

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,9 @@ public final class RedisCommandHandler {
5353
}
5454

5555
private enum State {
56-
case `default`, error(Error)
56+
case `default`
57+
case draining(EventLoopPromise<Void>?)
58+
case error(Error)
5759
}
5860
}
5961

@@ -79,6 +81,7 @@ extension RedisCommandHandler: ChannelInboundHandler {
7981
/// See `NIO.ChannelInboundHandler.channelInactive(context:)`
8082
/// - Note: `RedisMetrics.commandFailureCount` is **not** incremented from this method.
8183
public func channelInactive(context: ChannelHandlerContext) {
84+
self.state = .error(RedisClientError.connectionClosed)
8285
self._failCommandQueue(because: RedisClientError.connectionClosed)
8386
}
8487

@@ -109,6 +112,16 @@ extension RedisCommandHandler: ChannelInboundHandler {
109112
leadPromise.succeed(value)
110113
RedisMetrics.commandSuccessCount.increment()
111114
}
115+
116+
switch self.state {
117+
case .draining(let promise):
118+
if self.commandResponseQueue.isEmpty {
119+
context.close(mode: .all, promise: promise)
120+
}
121+
122+
case .error, .`default`:
123+
break
124+
}
112125
}
113126
}
114127

@@ -130,7 +143,11 @@ extension RedisCommandHandler: ChannelOutboundHandler {
130143
let commandContext = self.unwrapOutboundIn(data)
131144

132145
switch self.state {
133-
case let .error(e): commandContext.responsePromise.fail(e)
146+
case let .error(e):
147+
commandContext.responsePromise.fail(e)
148+
149+
case .draining:
150+
commandContext.responsePromise.fail(RedisClientError.connectionClosed)
134151

135152
case .default:
136153
self.commandResponseQueue.append(commandContext.responsePromise)
@@ -140,4 +157,39 @@ extension RedisCommandHandler: ChannelOutboundHandler {
140157
)
141158
}
142159
}
160+
161+
/// Listens for ``RedisGracefulConnectionCloseEvent``. If such an event is received the handler will wait
162+
/// until all currently running commands have returned. Once all requests are fulfilled the handler will close the channel.
163+
///
164+
/// If a command is sent on the channel, after the ``RedisGracefulConnectionCloseEvent`` was scheduled,
165+
/// the command will be failed with a ``RedisClientError/connectionClosed``.
166+
///
167+
/// See `NIO.ChannelOutboundHandler.triggerUserOutboundEvent(context:event:promise:)`
168+
public func triggerUserOutboundEvent(context: ChannelHandlerContext, event: Any, promise: EventLoopPromise<Void>?) {
169+
switch event {
170+
case is RedisGracefulConnectionCloseEvent:
171+
switch self.state {
172+
case .default:
173+
if self.commandResponseQueue.isEmpty {
174+
self.state = .error(RedisClientError.connectionClosed)
175+
context.close(mode: .all, promise: promise)
176+
} else {
177+
self.state = .draining(promise)
178+
}
179+
180+
case .error, .draining:
181+
promise?.succeed(())
182+
break
183+
}
184+
185+
default:
186+
context.triggerUserOutboundEvent(event, promise: promise)
187+
}
188+
}
189+
}
190+
191+
/// A channel event that informs the ``RedisCommandHandler`` that it should close the channel gracefully
192+
public struct RedisGracefulConnectionCloseEvent {
193+
/// Creates a ``RedisGracefulConnectionCloseEvent``
194+
public init() {}
143195
}

Tests/RediStackTests/ChannelHandlers/RedisCommandHandlerTests.swift

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@
1212
//
1313
//===----------------------------------------------------------------------===//
1414

15-
import NIO
15+
import NIOCore
16+
import NIOPosix
17+
import NIOEmbedded
18+
import Atomics
1619
@testable import RediStack
1720
import XCTest
1821

@@ -43,6 +46,83 @@ final class RedisCommandHandlerTests: XCTestCase {
4346
XCTAssertEqual(error, .connectionClosed)
4447
}
4548
}
49+
50+
func testCloseIsTriggeredOnceCommandQueueIsEmpty() {
51+
let loop = EmbeddedEventLoop()
52+
let channel = EmbeddedChannel(handler: RedisCommandHandler(), loop: loop)
53+
54+
XCTAssertNoThrow(try channel.connect(to: .init(unixDomainSocketPath: "/foo")).wait())
55+
XCTAssertTrue(channel.isActive)
56+
57+
let getFoo = RESPValue.array([.bulkString(.init(string: "GET")), .bulkString(.init(string: "foo"))])
58+
let promiseFoo = loop.makePromise(of: RESPValue.self)
59+
let commandFoo = (message: getFoo, responsePromise: promiseFoo)
60+
XCTAssertNoThrow(try channel.writeOutbound(commandFoo))
61+
XCTAssertEqual(try channel.readOutbound(as: RESPValue.self), getFoo)
62+
63+
let getBar = RESPValue.array([.bulkString(.init(string: "GET")), .bulkString(.init(string: "bar"))])
64+
let promiseBar = loop.makePromise(of: RESPValue.self)
65+
let commandBar = (message: getBar, responsePromise: promiseBar)
66+
XCTAssertNoThrow(try channel.writeOutbound(commandBar))
67+
XCTAssertEqual(try channel.readOutbound(as: RESPValue.self), getBar)
68+
69+
let getBaz = RESPValue.array([.bulkString(.init(string: "GET")), .bulkString(.init(string: "baz"))])
70+
let promiseBaz = loop.makePromise(of: RESPValue.self)
71+
let commandBaz = (message: getBaz, responsePromise: promiseBaz)
72+
XCTAssertNoThrow(try channel.writeOutbound(commandBaz))
73+
XCTAssertEqual(try channel.readOutbound(as: RESPValue.self), getBaz)
74+
75+
let gracefulClosePromise = loop.makePromise(of: Void.self)
76+
let channelCloseHitCounter = ManagedAtomic<Int>(0)
77+
gracefulClosePromise.futureResult.whenComplete { _ in
78+
channelCloseHitCounter.wrappingIncrement(ordering: .relaxed)
79+
}
80+
channel.triggerUserOutboundEvent(RedisGracefulConnectionCloseEvent(), promise: gracefulClosePromise)
81+
XCTAssertEqual(channelCloseHitCounter.load(ordering: .relaxed), 0)
82+
83+
let fooResponse = RESPValue.simpleString(.init(string: "fooresult"))
84+
XCTAssertNoThrow(try channel.writeInbound(fooResponse))
85+
XCTAssertTrue(channel.isActive)
86+
XCTAssertEqual(channelCloseHitCounter.load(ordering: .relaxed), 0)
87+
XCTAssertEqual(try promiseFoo.futureResult.wait(), fooResponse)
88+
89+
let barResponse = RESPValue.simpleString(.init(string: "barresult"))
90+
XCTAssertNoThrow(try channel.writeInbound(barResponse))
91+
XCTAssertTrue(channel.isActive)
92+
XCTAssertEqual(channelCloseHitCounter.load(ordering: .relaxed), 0)
93+
XCTAssertEqual(try promiseBar.futureResult.wait(), barResponse)
94+
95+
let bazResponse = RESPValue.simpleString(.init(string: "bazresult"))
96+
XCTAssertNoThrow(try channel.writeInbound(bazResponse))
97+
XCTAssertEqual(try promiseBaz.futureResult.wait(), bazResponse)
98+
XCTAssertFalse(channel.isActive)
99+
XCTAssertEqual(channelCloseHitCounter.load(ordering: .relaxed), 1)
100+
XCTAssertNoThrow(try gracefulClosePromise.futureResult.wait())
101+
}
102+
103+
func testCloseIsTriggeredRightAwayIfCommandQueueIsEmpty() {
104+
let loop = EmbeddedEventLoop()
105+
let channel = EmbeddedChannel(handler: RedisCommandHandler(), loop: loop)
106+
XCTAssertNoThrow(try channel.connect(to: .init(unixDomainSocketPath: "/foo")).wait())
107+
XCTAssertTrue(channel.isActive)
108+
109+
let gracefulClosePromise = loop.makePromise(of: Void.self)
110+
let gracefulCloseHitCounter = ManagedAtomic<Int>(0)
111+
gracefulClosePromise.futureResult.whenComplete { _ in
112+
gracefulCloseHitCounter.wrappingIncrement(ordering: .relaxed)
113+
}
114+
channel.triggerUserOutboundEvent(RedisGracefulConnectionCloseEvent(), promise: gracefulClosePromise)
115+
XCTAssertFalse(channel.isActive)
116+
XCTAssertEqual(gracefulCloseHitCounter.load(ordering: .relaxed), 1)
117+
118+
let getBar = RESPValue.array([.bulkString(.init(string: "GET")), .bulkString(.init(string: "bar"))])
119+
let promiseBar = loop.makePromise(of: RESPValue.self)
120+
let commandBar = (message: getBar, responsePromise: promiseBar)
121+
channel.write(commandBar, promise: nil)
122+
XCTAssertThrowsError(try promiseBar.futureResult.wait()) {
123+
XCTAssertEqual($0 as? RedisClientError, .connectionClosed)
124+
}
125+
}
46126
}
47127

48128
private final class RemoteCloseHandler: ChannelInboundHandler {

0 commit comments

Comments
 (0)