diff --git a/Sources/Valkey/Connection/ValkeyChannelHandler+stateMachine.swift b/Sources/Valkey/Connection/ValkeyChannelHandler+stateMachine.swift index 9860219e..652c3ba4 100644 --- a/Sources/Valkey/Connection/ValkeyChannelHandler+stateMachine.swift +++ b/Sources/Valkey/Connection/ValkeyChannelHandler+stateMachine.swift @@ -330,14 +330,13 @@ extension ValkeyChannelHandler { } @usableFromInline - enum GracefulShutdownAction { - case waitForPendingCommands(Context) + enum TriggerGracefulShutdownAction { case closeConnection(Context) case doNothing } /// Want to gracefully shutdown the handler @usableFromInline - mutating func gracefulShutdown() -> GracefulShutdownAction { + mutating func triggerGracefulShutdown() -> TriggerGracefulShutdownAction { switch consume self.state { case .initialized: self = .closed(nil) @@ -346,11 +345,11 @@ extension ValkeyChannelHandler { var pendingCommands = state.pendingCommands pendingCommands.prepend(state.pendingHelloCommand) self = .closing(.init(context: state.context, pendingCommands: pendingCommands)) - return .waitForPendingCommands(state.context) + return .doNothing case .active(let state): if state.pendingCommands.count > 0 { self = .closing(.init(context: state.context, pendingCommands: state.pendingCommands)) - return .waitForPendingCommands(state.context) + return .doNothing } else { self = .closed(nil) return .closeConnection(state.context) diff --git a/Sources/Valkey/Connection/ValkeyChannelHandler.swift b/Sources/Valkey/Connection/ValkeyChannelHandler.swift index 1e2448da..8cbe787e 100644 --- a/Sources/Valkey/Connection/ValkeyChannelHandler.swift +++ b/Sources/Valkey/Connection/ValkeyChannelHandler.swift @@ -517,6 +517,15 @@ final class ValkeyChannelHandler: ChannelInboundHandler { break } } + + func triggerGracefulShutdown() { + switch self.stateMachine.triggerGracefulShutdown() { + case .closeConnection(let context): + context.close(mode: .all, promise: nil) + case .doNothing: + break + } + } } @available(valkeySwift 1.0, *) diff --git a/Sources/Valkey/Connection/ValkeyConnection.swift b/Sources/Valkey/Connection/ValkeyConnection.swift index d8b500e8..79b393cf 100644 --- a/Sources/Valkey/Connection/ValkeyConnection.swift +++ b/Sources/Valkey/Connection/ValkeyConnection.swift @@ -164,6 +164,14 @@ public final actor ValkeyConnection: ValkeyClientProtocol, Sendable { try await self.channelHandler.waitOnActive().get() } + /// Trigger graceful shutdown of connection + /// + /// The connection will wait until all pending commands have been processed before + /// closing the connection. + func triggerGracefulShutdown() { + self.channelHandler.triggerGracefulShutdown() + } + /// Send RESP command to Valkey connection /// - Parameter command: ValkeyCommand structure /// - Returns: The command response as defined in the ValkeyCommand diff --git a/Tests/ValkeyTests/ValkeyChannelHandlerStateMachineTests.swift b/Tests/ValkeyTests/ValkeyChannelHandlerStateMachineTests.swift index 7aca82a8..26628259 100644 --- a/Tests/ValkeyTests/ValkeyChannelHandlerStateMachineTests.swift +++ b/Tests/ValkeyTests/ValkeyChannelHandlerStateMachineTests.swift @@ -146,7 +146,7 @@ struct ValkeyChannelHandlerStateMachineTests { var stateMachine = ValkeyChannelHandler.StateMachine() stateMachine.setConnected(context: "testGracefulShutdown") stateMachine.receiveHelloResponse() - switch stateMachine.gracefulShutdown() { + switch stateMachine.triggerGracefulShutdown() { case .closeConnection(let context): #expect(context == "testGracefulShutdown") default: @@ -168,10 +168,10 @@ struct ValkeyChannelHandlerStateMachineTests { case .throwError: Issue.record("Invalid sendCommand action") } - switch stateMachine.gracefulShutdown() { - case .waitForPendingCommands(let context): - #expect(context == "testGracefulShutdown") - default: + switch stateMachine.triggerGracefulShutdown() { + case .doNothing: + break + case .closeConnection: Issue.record("Invalid waitForPendingCommands action") } expect( @@ -207,10 +207,10 @@ struct ValkeyChannelHandlerStateMachineTests { case .throwError: Issue.record("Invalid sendCommand action") } - switch stateMachine.gracefulShutdown() { - case .waitForPendingCommands(let context): - #expect(context == "testClosedClosingState") - default: + switch stateMachine.triggerGracefulShutdown() { + case .doNothing: + break + case .closeConnection: Issue.record("Invalid waitForPendingCommands action") } expect( @@ -333,7 +333,7 @@ struct ValkeyChannelHandlerStateMachineTests { case .throwError: Issue.record("Invalid sendCommand action") } - _ = stateMachine.gracefulShutdown() + _ = stateMachine.triggerGracefulShutdown() switch stateMachine.cancel(requestID: 23) { case .failPendingCommandsAndClose(let context, let cancel, let closeConnectionDueToCancel): #expect(context == "testCancelGracefulShutdown") diff --git a/Tests/ValkeyTests/ValkeyConnectionTests.swift b/Tests/ValkeyTests/ValkeyConnectionTests.swift index 939aa1a7..f4d964ff 100644 --- a/Tests/ValkeyTests/ValkeyConnectionTests.swift +++ b/Tests/ValkeyTests/ValkeyConnectionTests.swift @@ -491,6 +491,28 @@ struct ConnectionTests { try await channel.close() } + @Test + @available(valkeySwift 1.0, *) + func testTriggerGracefulShutdown() async throws { + let channel = NIOAsyncTestingChannel() + let logger = Logger(label: "test") + let connection = try await ValkeyConnection.setupChannelAndConnect(channel, configuration: .init(), logger: logger) + try await channel.processHello() + + async let fooResult = connection.get("foo").map { String(buffer: $0) } + + let outbound = try await channel.waitForOutboundWrite(as: ByteBuffer.self) + #expect(outbound == RESPToken(.command(["GET", "foo"])).base) + + await connection.triggerGracefulShutdown() + #expect(channel.isActive) + + try await channel.writeInbound(RESPToken(.bulkString("Bar")).base) + #expect(try await fooResult == "Bar") + + try await channel.closeFuture.get() + } + #if DistributedTracingSupport @Suite struct DistributedTracingTests {