Skip to content

Commit 7c9fd99

Browse files
committed
Add ValkeyConnection.triggerGracefulShutdown
Signed-off-by: Adam Fowler <[email protected]>
1 parent 62a2854 commit 7c9fd99

File tree

5 files changed

+53
-15
lines changed

5 files changed

+53
-15
lines changed

Sources/Valkey/Connection/ValkeyChannelHandler+stateMachine.swift

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -330,14 +330,13 @@ extension ValkeyChannelHandler {
330330
}
331331

332332
@usableFromInline
333-
enum GracefulShutdownAction {
334-
case waitForPendingCommands(Context)
333+
enum TriggerGracefulShutdownAction {
335334
case closeConnection(Context)
336335
case doNothing
337336
}
338337
/// Want to gracefully shutdown the handler
339338
@usableFromInline
340-
mutating func gracefulShutdown() -> GracefulShutdownAction {
339+
mutating func triggerGracefulShutdown() -> TriggerGracefulShutdownAction {
341340
switch consume self.state {
342341
case .initialized:
343342
self = .closed(nil)
@@ -346,11 +345,11 @@ extension ValkeyChannelHandler {
346345
var pendingCommands = state.pendingCommands
347346
pendingCommands.prepend(state.pendingHelloCommand)
348347
self = .closing(.init(context: state.context, pendingCommands: pendingCommands))
349-
return .waitForPendingCommands(state.context)
348+
return .doNothing
350349
case .active(let state):
351350
if state.pendingCommands.count > 0 {
352351
self = .closing(.init(context: state.context, pendingCommands: state.pendingCommands))
353-
return .waitForPendingCommands(state.context)
352+
return .doNothing
354353
} else {
355354
self = .closed(nil)
356355
return .closeConnection(state.context)

Sources/Valkey/Connection/ValkeyChannelHandler.swift

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -517,6 +517,15 @@ final class ValkeyChannelHandler: ChannelInboundHandler {
517517
break
518518
}
519519
}
520+
521+
func triggerGracefulShutdown() {
522+
switch self.stateMachine.triggerGracefulShutdown() {
523+
case .closeConnection(let context):
524+
context.close(mode: .all, promise: nil)
525+
case .doNothing:
526+
break
527+
}
528+
}
520529
}
521530

522531
@available(valkeySwift 1.0, *)

Sources/Valkey/Connection/ValkeyConnection.swift

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,14 @@ public final actor ValkeyConnection: ValkeyClientProtocol, Sendable {
164164
try await self.channelHandler.waitOnActive().get()
165165
}
166166

167+
/// Trigger graceful shutdown of connection
168+
///
169+
/// The connection will wait until all pending commands have been processed before
170+
/// closing the connection.
171+
func triggerGracefulShutdown() {
172+
self.channelHandler.triggerGracefulShutdown()
173+
}
174+
167175
/// Send RESP command to Valkey connection
168176
/// - Parameter command: ValkeyCommand structure
169177
/// - Returns: The command response as defined in the ValkeyCommand

Tests/ValkeyTests/ValkeyChannelHandlerStateMachineTests.swift

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ struct ValkeyChannelHandlerStateMachineTests {
146146
var stateMachine = ValkeyChannelHandler.StateMachine<String>()
147147
stateMachine.setConnected(context: "testGracefulShutdown")
148148
stateMachine.receiveHelloResponse()
149-
switch stateMachine.gracefulShutdown() {
149+
switch stateMachine.triggerGracefulShutdown() {
150150
case .closeConnection(let context):
151151
#expect(context == "testGracefulShutdown")
152152
default:
@@ -168,10 +168,10 @@ struct ValkeyChannelHandlerStateMachineTests {
168168
case .throwError:
169169
Issue.record("Invalid sendCommand action")
170170
}
171-
switch stateMachine.gracefulShutdown() {
172-
case .waitForPendingCommands(let context):
173-
#expect(context == "testGracefulShutdown")
174-
default:
171+
switch stateMachine.triggerGracefulShutdown() {
172+
case .doNothing:
173+
break
174+
case .closeConnection:
175175
Issue.record("Invalid waitForPendingCommands action")
176176
}
177177
expect(
@@ -207,10 +207,10 @@ struct ValkeyChannelHandlerStateMachineTests {
207207
case .throwError:
208208
Issue.record("Invalid sendCommand action")
209209
}
210-
switch stateMachine.gracefulShutdown() {
211-
case .waitForPendingCommands(let context):
212-
#expect(context == "testClosedClosingState")
213-
default:
210+
switch stateMachine.triggerGracefulShutdown() {
211+
case .doNothing:
212+
break
213+
case .closeConnection:
214214
Issue.record("Invalid waitForPendingCommands action")
215215
}
216216
expect(
@@ -333,7 +333,7 @@ struct ValkeyChannelHandlerStateMachineTests {
333333
case .throwError:
334334
Issue.record("Invalid sendCommand action")
335335
}
336-
_ = stateMachine.gracefulShutdown()
336+
_ = stateMachine.triggerGracefulShutdown()
337337
switch stateMachine.cancel(requestID: 23) {
338338
case .failPendingCommandsAndClose(let context, let cancel, let closeConnectionDueToCancel):
339339
#expect(context == "testCancelGracefulShutdown")

Tests/ValkeyTests/ValkeyConnectionTests.swift

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,28 @@ struct ConnectionTests {
491491
try await channel.close()
492492
}
493493

494+
@Test
495+
@available(valkeySwift 1.0, *)
496+
func testTriggerGracefulShutdown() async throws {
497+
let channel = NIOAsyncTestingChannel()
498+
let logger = Logger(label: "test")
499+
let connection = try await ValkeyConnection.setupChannelAndConnect(channel, configuration: .init(), logger: logger)
500+
try await channel.processHello()
501+
502+
async let fooResult = connection.get("foo").map { String(buffer: $0) }
503+
504+
let outbound = try await channel.waitForOutboundWrite(as: ByteBuffer.self)
505+
#expect(outbound == RESPToken(.command(["GET", "foo"])).base)
506+
507+
await connection.triggerGracefulShutdown()
508+
#expect(channel.isActive)
509+
510+
try await channel.writeInbound(RESPToken(.bulkString("Bar")).base)
511+
#expect(try await fooResult == "Bar")
512+
513+
try await channel.closeFuture.get()
514+
}
515+
494516
#if DistributedTracingSupport
495517
@Suite
496518
struct DistributedTracingTests {

0 commit comments

Comments
 (0)