Skip to content

Commit 4ea66c4

Browse files
fabianfettMordil
authored andcommitted
Add support for graceful shutdown to the RedisCommandHandler
1 parent b88fac0 commit 4ea66c4

File tree

2 files changed

+133
-2
lines changed

2 files changed

+133
-2
lines changed

Sources/RediStack/ChannelHandlers/RedisCommandHandler.swift

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,9 @@ public final class RedisCommandHandler {
4444
}
4545

4646
private enum State {
47-
case `default`, error(Error)
47+
case `default`
48+
case draining(EventLoopPromise<Void>?)
49+
case error(Error)
4850
}
4951
}
5052

@@ -70,6 +72,7 @@ extension RedisCommandHandler: ChannelInboundHandler {
7072
/// See `NIO.ChannelInboundHandler.channelInactive(context:)`
7173
/// - Note: `RedisMetrics.commandFailureCount` is **not** incremented from this method.
7274
public func channelInactive(context: ChannelHandlerContext) {
75+
self.state = .error(RedisClientError.connectionClosed)
7376
self._failCommandQueue(because: RedisClientError.connectionClosed)
7477
}
7578

@@ -100,6 +103,16 @@ extension RedisCommandHandler: ChannelInboundHandler {
100103
leadPromise.succeed(value)
101104
RedisMetrics.commandSuccessCount.increment()
102105
}
106+
107+
switch self.state {
108+
case .draining(let promise):
109+
if self.commandResponseQueue.isEmpty {
110+
context.close(mode: .all, promise: promise)
111+
}
112+
113+
case .error, .`default`:
114+
break
115+
}
103116
}
104117
}
105118

@@ -119,7 +132,11 @@ extension RedisCommandHandler: ChannelOutboundHandler {
119132
let commandPayload = self.unwrapOutboundIn(data)
120133

121134
switch self.state {
122-
case let .error(e): commandPayload.responsePromise.fail(e)
135+
case let .error(e):
136+
commandPayload.responsePromise.fail(e)
137+
138+
case .draining:
139+
commandPayload.responsePromise.fail(RedisClientError.connectionClosed)
123140

124141
case .default:
125142
self.commandResponseQueue.append(commandPayload.responsePromise)
@@ -129,4 +146,39 @@ extension RedisCommandHandler: ChannelOutboundHandler {
129146
)
130147
}
131148
}
149+
150+
/// Listens for ``RedisGracefulConnectionCloseEvent``. If such an event is received the handler will wait
151+
/// until all currently running commands have returned. Once all requests are fulfilled the handler will close the channel.
152+
///
153+
/// If a command is sent on the channel, after the ``RedisGracefulConnectionCloseEvent`` was scheduled,
154+
/// the command will be failed with a ``RedisClientError/connectionClosed``.
155+
///
156+
/// See `NIO.ChannelOutboundHandler.triggerUserOutboundEvent(context:event:promise:)`
157+
public func triggerUserOutboundEvent(context: ChannelHandlerContext, event: Any, promise: EventLoopPromise<Void>?) {
158+
switch event {
159+
case is RedisGracefulConnectionCloseEvent:
160+
switch self.state {
161+
case .default:
162+
if self.commandResponseQueue.isEmpty {
163+
self.state = .error(RedisClientError.connectionClosed)
164+
context.close(mode: .all, promise: promise)
165+
} else {
166+
self.state = .draining(promise)
167+
}
168+
169+
case .error, .draining:
170+
promise?.succeed(())
171+
break
172+
}
173+
174+
default:
175+
context.triggerUserOutboundEvent(event, promise: promise)
176+
}
177+
}
178+
}
179+
180+
/// A channel event that informs the ``RedisCommandHandler`` that it should close the channel gracefully
181+
public struct RedisGracefulConnectionCloseEvent {
182+
/// Creates a ``RedisGracefulConnectionCloseEvent``
183+
public init() {}
132184
}

Tests/RediStackTests/ChannelHandlers/RedisCommandHandlerTests.swift

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
import NIOCore
1616
import NIOPosix
17+
import NIOEmbedded
18+
import Atomics
1719
@testable import RediStack
1820
import XCTest
1921

@@ -44,6 +46,83 @@ final class RedisCommandHandlerTests: XCTestCase {
4446
XCTAssertEqual(error, .connectionClosed)
4547
}
4648
}
49+
50+
func testCloseIsTriggeredOnceCommandQueueIsEmpty() {
51+
let loop = EmbeddedEventLoop()
52+
let channel = EmbeddedChannel(handler: RedisCommandHandler(), loop: loop)
53+
54+
XCTAssertNoThrow(try channel.connect(to: .init(unixDomainSocketPath: "/foo")).wait())
55+
XCTAssertTrue(channel.isActive)
56+
57+
let getFoo = RESPValue.array([.bulkString(.init(string: "GET")), .bulkString(.init(string: "foo"))])
58+
let promiseFoo = loop.makePromise(of: RESPValue.self)
59+
let commandFoo = (message: getFoo, responsePromise: promiseFoo)
60+
XCTAssertNoThrow(try channel.writeOutbound(commandFoo))
61+
XCTAssertEqual(try channel.readOutbound(as: RESPValue.self), getFoo)
62+
63+
let getBar = RESPValue.array([.bulkString(.init(string: "GET")), .bulkString(.init(string: "bar"))])
64+
let promiseBar = loop.makePromise(of: RESPValue.self)
65+
let commandBar = (message: getBar, responsePromise: promiseBar)
66+
XCTAssertNoThrow(try channel.writeOutbound(commandBar))
67+
XCTAssertEqual(try channel.readOutbound(as: RESPValue.self), getBar)
68+
69+
let getBaz = RESPValue.array([.bulkString(.init(string: "GET")), .bulkString(.init(string: "baz"))])
70+
let promiseBaz = loop.makePromise(of: RESPValue.self)
71+
let commandBaz = (message: getBaz, responsePromise: promiseBaz)
72+
XCTAssertNoThrow(try channel.writeOutbound(commandBaz))
73+
XCTAssertEqual(try channel.readOutbound(as: RESPValue.self), getBaz)
74+
75+
let gracefulClosePromise = loop.makePromise(of: Void.self)
76+
let channelCloseHitCounter = ManagedAtomic<Int>(0)
77+
gracefulClosePromise.futureResult.whenComplete { _ in
78+
channelCloseHitCounter.wrappingIncrement(ordering: .relaxed)
79+
}
80+
channel.triggerUserOutboundEvent(RedisGracefulConnectionCloseEvent(), promise: gracefulClosePromise)
81+
XCTAssertEqual(channelCloseHitCounter.load(ordering: .relaxed), 0)
82+
83+
let fooResponse = RESPValue.simpleString(.init(string: "fooresult"))
84+
XCTAssertNoThrow(try channel.writeInbound(fooResponse))
85+
XCTAssertTrue(channel.isActive)
86+
XCTAssertEqual(channelCloseHitCounter.load(ordering: .relaxed), 0)
87+
XCTAssertEqual(try promiseFoo.futureResult.wait(), fooResponse)
88+
89+
let barResponse = RESPValue.simpleString(.init(string: "barresult"))
90+
XCTAssertNoThrow(try channel.writeInbound(barResponse))
91+
XCTAssertTrue(channel.isActive)
92+
XCTAssertEqual(channelCloseHitCounter.load(ordering: .relaxed), 0)
93+
XCTAssertEqual(try promiseBar.futureResult.wait(), barResponse)
94+
95+
let bazResponse = RESPValue.simpleString(.init(string: "bazresult"))
96+
XCTAssertNoThrow(try channel.writeInbound(bazResponse))
97+
XCTAssertEqual(try promiseBaz.futureResult.wait(), bazResponse)
98+
XCTAssertFalse(channel.isActive)
99+
XCTAssertEqual(channelCloseHitCounter.load(ordering: .relaxed), 1)
100+
XCTAssertNoThrow(try gracefulClosePromise.futureResult.wait())
101+
}
102+
103+
func testCloseIsTriggeredRightAwayIfCommandQueueIsEmpty() {
104+
let loop = EmbeddedEventLoop()
105+
let channel = EmbeddedChannel(handler: RedisCommandHandler(), loop: loop)
106+
XCTAssertNoThrow(try channel.connect(to: .init(unixDomainSocketPath: "/foo")).wait())
107+
XCTAssertTrue(channel.isActive)
108+
109+
let gracefulClosePromise = loop.makePromise(of: Void.self)
110+
let gracefulCloseHitCounter = ManagedAtomic<Int>(0)
111+
gracefulClosePromise.futureResult.whenComplete { _ in
112+
gracefulCloseHitCounter.wrappingIncrement(ordering: .relaxed)
113+
}
114+
channel.triggerUserOutboundEvent(RedisGracefulConnectionCloseEvent(), promise: gracefulClosePromise)
115+
XCTAssertFalse(channel.isActive)
116+
XCTAssertEqual(gracefulCloseHitCounter.load(ordering: .relaxed), 1)
117+
118+
let getBar = RESPValue.array([.bulkString(.init(string: "GET")), .bulkString(.init(string: "bar"))])
119+
let promiseBar = loop.makePromise(of: RESPValue.self)
120+
let commandBar = (message: getBar, responsePromise: promiseBar)
121+
channel.write(commandBar, promise: nil)
122+
XCTAssertThrowsError(try promiseBar.futureResult.wait()) {
123+
XCTAssertEqual($0 as? RedisClientError, .connectionClosed)
124+
}
125+
}
47126
}
48127

49128
private final class RemoteCloseHandler: ChannelInboundHandler {

0 commit comments

Comments
 (0)