Skip to content

Commit f736fa6

Browse files
authored
Cancellation (#75)
* Add send/pipeline cancallation * Add subscription cancellation * Add request id * Pass requestID down from ValkeyConnection * ValkeyConnection.requestIDGenerator * Add pending cancelled requests array * Close connection if cancelled * Remove optional promise, rename commands to pendingCommands * Remove cancelledRequests as they have been moved to state machine * Only cancel command and close connection if command is pending * Add already cancelled test * pipeline should not throw as it returns Results
1 parent 35f0d13 commit f736fa6

9 files changed

+447
-101
lines changed

Sources/Valkey/Connection/ValkeyChannelHandler+stateMachine.swift

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,29 @@ extension ValkeyChannelHandler {
7373
}
7474
}
7575

76+
@usableFromInline
77+
enum CancelAction {
78+
case cancelAndCloseConnection(Context)
79+
case doNothing
80+
}
81+
82+
/// handler wants to send a command
83+
@usableFromInline
84+
mutating func cancel() -> CancelAction {
85+
switch self.state {
86+
case .initializing:
87+
preconditionFailure("Cannot cancel when initializing")
88+
case .active(let state):
89+
self.state = .closed
90+
return .cancelAndCloseConnection(state.context)
91+
case .closing(let state):
92+
self.state = .closed
93+
return .cancelAndCloseConnection(state.context)
94+
case .closed:
95+
return .doNothing
96+
}
97+
}
98+
7699
@usableFromInline
77100
enum GracefulShutdownAction {
78101
case waitForPendingCommands(Context)

Sources/Valkey/Connection/ValkeyChannelHandler.swift

Lines changed: 82 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ enum ValkeyPromise<T: Sendable>: Sendable {
4242

4343
@usableFromInline
4444
enum ValkeyRequest: Sendable {
45-
case single(buffer: ByteBuffer, promise: ValkeyPromise<RESPToken>)
46-
case multiple(buffer: ByteBuffer, promises: [ValkeyPromise<RESPToken>])
45+
case single(buffer: ByteBuffer, promise: ValkeyPromise<RESPToken>, id: Int)
46+
case multiple(buffer: ByteBuffer, promises: [ValkeyPromise<RESPToken>], id: Int)
4747
}
4848

4949
@usableFromInline
@@ -53,6 +53,17 @@ final class ValkeyChannelHandler: ChannelInboundHandler {
5353
let clientName: String?
5454
}
5555
@usableFromInline
56+
struct PendingCommand {
57+
@usableFromInline
58+
internal init(promise: ValkeyPromise<RESPToken>, requestID: Int) {
59+
self.promise = promise
60+
self.requestID = requestID
61+
}
62+
63+
var promise: ValkeyPromise<RESPToken>
64+
let requestID: Int
65+
}
66+
@usableFromInline
5667
typealias OutboundOut = ByteBuffer
5768
@usableFromInline
5869
typealias InboundIn = ByteBuffer
@@ -61,7 +72,7 @@ final class ValkeyChannelHandler: ChannelInboundHandler {
6172
@usableFromInline
6273
/*private*/ let eventLoop: EventLoop
6374
@usableFromInline
64-
/*private*/ var commands: Deque<ValkeyPromise<RESPToken>>
75+
/*private*/ var pendingCommands: Deque<PendingCommand>
6576
@usableFromInline
6677
/*private*/ var encoder = ValkeyCommandEncoder()
6778
@usableFromInline
@@ -77,7 +88,7 @@ final class ValkeyChannelHandler: ChannelInboundHandler {
7788
init(configuration: Configuration, eventLoop: EventLoop, logger: Logger) {
7889
self.configuration = configuration
7990
self.eventLoop = eventLoop
80-
self.commands = .init()
91+
self.pendingCommands = .init()
8192
self.subscriptions = .init(logger: logger)
8293
self.decoder = NIOSingleStepByteToMessageProcessor(RESPTokenDecoder())
8394
self.stateMachine = .init()
@@ -89,15 +100,15 @@ final class ValkeyChannelHandler: ChannelInboundHandler {
89100
/// - request: Valkey command request
90101
/// - promise: Promise to fulfill when command is complete
91102
@inlinable
92-
func write<Command: ValkeyCommand>(command: Command, continuation: CheckedContinuation<RESPToken, any Error>) {
103+
func write<Command: ValkeyCommand>(command: Command, continuation: CheckedContinuation<RESPToken, any Error>, requestID: Int) {
93104
self.eventLoop.assertInEventLoop()
94105
switch self.stateMachine.sendCommand() {
95106
case .sendCommand(let context):
96107
self.encoder.reset()
97108
command.encode(into: &self.encoder)
98109
let buffer = self.encoder.buffer
99110

100-
self.commands.append(.swift(continuation))
111+
self.pendingCommands.append(.init(promise: .swift(continuation), requestID: requestID))
101112
context.writeAndFlush(self.wrapOutboundOut(buffer), promise: nil)
102113

103114
case .throwError(let error):
@@ -108,26 +119,24 @@ final class ValkeyChannelHandler: ChannelInboundHandler {
108119
@usableFromInline
109120
func write(request: ValkeyRequest) {
110121
self.eventLoop.assertInEventLoop()
111-
switch self.stateMachine.sendCommand() {
112-
case .sendCommand(let context):
113-
switch request {
114-
case .single(let buffer, let tokenPromise):
115-
self.commands.append(tokenPromise)
122+
switch request {
123+
case .single(let buffer, let tokenPromise, let requestID):
124+
switch self.stateMachine.sendCommand() {
125+
case .sendCommand(let context):
126+
self.pendingCommands.append(.init(promise: tokenPromise, requestID: requestID))
116127
context.writeAndFlush(self.wrapOutboundOut(buffer), promise: nil)
128+
case .throwError(let error):
129+
tokenPromise.fail(error)
130+
}
117131

118-
case .multiple(let buffer, let tokenPromises):
132+
case .multiple(let buffer, let tokenPromises, let requestID):
133+
switch self.stateMachine.sendCommand() {
134+
case .sendCommand(let context):
119135
for tokenPromise in tokenPromises {
120-
self.commands.append(tokenPromise)
136+
self.pendingCommands.append(.init(promise: tokenPromise, requestID: requestID))
121137
}
122138
context.writeAndFlush(self.wrapOutboundOut(buffer), promise: nil)
123-
}
124-
125-
case .throwError(let error):
126-
switch request {
127-
case .single(_, let tokenPromise):
128-
tokenPromise.fail(error)
129-
130-
case .multiple(_, let tokenPromises):
139+
case .throwError(let error):
131140
for promise in tokenPromises {
132141
promise.fail(error)
133142
}
@@ -140,7 +149,8 @@ final class ValkeyChannelHandler: ChannelInboundHandler {
140149
command: some ValkeyCommand,
141150
streamContinuation: ValkeySubscription.Continuation,
142151
filters: [ValkeySubscriptionFilter],
143-
promise: ValkeyPromise<Int>
152+
promise: ValkeyPromise<Int>,
153+
requestID: Int
144154
) {
145155
self.eventLoop.assertInEventLoop()
146156
switch self.subscriptions.addSubscription(continuation: streamContinuation, filters: filters) {
@@ -149,7 +159,7 @@ final class ValkeyChannelHandler: ChannelInboundHandler {
149159
// But it would be cool to build the subscribe command based on what filters we aren't subscribed to
150160
self.subscriptions.pushCommand(filters: subscription.filters)
151161
let subscriptionID = subscription.id
152-
return self._send(command: command).assumeIsolated().whenComplete { result in
162+
return self._send(command: command, requestID: requestID).assumeIsolated().whenComplete { result in
153163
switch result {
154164
case .success:
155165
promise.succeed(subscriptionID)
@@ -168,27 +178,31 @@ final class ValkeyChannelHandler: ChannelInboundHandler {
168178
/// Remove subscription and if required call UNSUBSCRIBE command
169179
func unsubscribe(
170180
id: Int,
171-
promise: ValkeyPromise<Void>
181+
promise: ValkeyPromise<Void>,
182+
requestID: Int
172183
) {
173184
self.eventLoop.assertInEventLoop()
174185
switch self.subscriptions.unsubscribe(id: id) {
175186
case .unsubscribe(let channels):
176187
self.performUnsubscribe(
177188
command: UNSUBSCRIBE(channel: channels),
178189
filters: channels.map { .channel($0) },
179-
promise: promise
190+
promise: promise,
191+
requestID: requestID
180192
)
181193
case .punsubscribe(let patterns):
182194
self.performUnsubscribe(
183195
command: PUNSUBSCRIBE(pattern: patterns),
184196
filters: patterns.map { .pattern($0) },
185-
promise: promise
197+
promise: promise,
198+
requestID: requestID
186199
)
187200
case .sunsubscribe(let shardChannels):
188201
self.performUnsubscribe(
189202
command: SUNSUBSCRIBE(shardchannel: shardChannels),
190203
filters: shardChannels.map { .shardChannel($0) },
191-
promise: promise
204+
promise: promise,
205+
requestID: requestID
192206
)
193207
case .doNothing:
194208
promise.succeed(())
@@ -198,10 +212,11 @@ final class ValkeyChannelHandler: ChannelInboundHandler {
198212
func performUnsubscribe(
199213
command: some ValkeyCommand,
200214
filters: [ValkeySubscriptionFilter],
201-
promise: ValkeyPromise<Void>
215+
promise: ValkeyPromise<Void>,
216+
requestID: Int
202217
) {
203218
self.subscriptions.pushCommand(filters: filters)
204-
self._send(command: command).assumeIsolated().whenComplete { result in
219+
self._send(command: command, requestID: requestID).assumeIsolated().whenComplete { result in
205220
switch result {
206221
case .success:
207222
promise.succeed(())
@@ -222,7 +237,8 @@ final class ValkeyChannelHandler: ChannelInboundHandler {
222237
auth: configuration.authentication.map { .init(username: $0.username, password: $0.password) },
223238
clientname: configuration.clientName
224239
)
225-
)
240+
),
241+
requestID: 0
226242
).assumeIsolated().whenComplete { result in
227243
switch result {
228244
case .failure(let error):
@@ -283,27 +299,51 @@ final class ValkeyChannelHandler: ChannelInboundHandler {
283299
}
284300
}
285301

302+
@usableFromInline
303+
func cancel(requestID: Int) {
304+
self.eventLoop.assertInEventLoop()
305+
// if pending commands include request then we are still waiting for its result.
306+
// We should cancel that command, cancel all the other pending commands with error
307+
// code `.connectionClosedDueToCancellation` and close the connection
308+
if self.pendingCommands.contains(where: { $0.requestID == requestID }) {
309+
switch self.stateMachine.cancel() {
310+
case .cancelAndCloseConnection(let context):
311+
while let command = self.pendingCommands.popFirst() {
312+
if command.requestID == requestID {
313+
command.promise.fail(ValkeyClientError(.cancelled))
314+
} else {
315+
command.promise.fail(ValkeyClientError(.connectionClosedDueToCancellation))
316+
}
317+
}
318+
context.close(promise: nil)
319+
320+
case .doNothing:
321+
break
322+
}
323+
}
324+
}
325+
286326
func handleToken(context: ChannelHandlerContext, token: RESPToken) {
287327
switch token.identifier {
288328
case .simpleError, .bulkError:
289-
guard let promise = commands.popFirst() else {
329+
guard let command = pendingCommands.popFirst() else {
290330
self.failPendingCommandsAndSubscriptionsAndCloseConnection(
291331
ValkeyClientError(.unsolicitedToken, message: "Received an error token without having sent a command"),
292332
context: context
293333
)
294334
return
295335
}
296-
promise.fail(ValkeyClientError(.commandError, message: token.errorString.map { String(buffer: $0) }))
336+
command.promise.fail(ValkeyClientError(.commandError, message: token.errorString.map { String(buffer: $0) }))
297337

298338
case .push:
299339
// If subscription notify throws an error then assume something has gone wrong
300340
// and close the channel with the error
301341
do {
302342
if try self.subscriptions.notify(token) == true {
303-
guard let promise = commands.popFirst() else {
343+
guard let command = pendingCommands.popFirst() else {
304344
preconditionFailure("Unexpected response")
305345
}
306-
promise.succeed(Self.simpleOk)
346+
command.promise.succeed(Self.simpleOk)
307347
}
308348
} catch {
309349
context.close(mode: .all, promise: nil)
@@ -321,32 +361,32 @@ final class ValkeyChannelHandler: ChannelInboundHandler {
321361
.map,
322362
.set,
323363
.attribute:
324-
guard let promise = commands.popFirst() else {
364+
guard let command = pendingCommands.popFirst() else {
325365
self.failPendingCommandsAndSubscriptionsAndCloseConnection(
326366
ValkeyClientError(.unsolicitedToken, message: "Received a token without having sent a command"),
327367
context: context
328368
)
329369
return
330370
}
331-
promise.succeed(token)
371+
command.promise.succeed(token)
332372
}
333373
}
334374

335375
func handleError(context: ChannelHandlerContext, error: Error) {
336376
self.logger.debug("ValkeyCommandHandler: ERROR", metadata: ["error": "\(error)"])
337-
guard let promise = commands.popFirst() else {
377+
guard let command = pendingCommands.popFirst() else {
338378
self.failPendingCommandsAndSubscriptionsAndCloseConnection(
339379
ValkeyClientError(.unsolicitedToken, message: "Received an error decoding a token without having sent a command"),
340380
context: context
341381
)
342382
return
343383
}
344-
promise.fail(error)
384+
command.promise.fail(error)
345385
}
346386

347387
private func failPendingCommandsAndSubscriptions(_ error: any Error) {
348-
while let promise = self.commands.popFirst() {
349-
promise.fail(error)
388+
while let command = self.pendingCommands.popFirst() {
389+
command.promise.fail(error)
350390
}
351391
self.subscriptions.close(error: error)
352392
}
@@ -358,14 +398,14 @@ final class ValkeyChannelHandler: ChannelInboundHandler {
358398
}
359399

360400
// Function used internally by subscribe
361-
func _send<Command: ValkeyCommand>(command: Command) -> EventLoopFuture<RESPToken> {
401+
func _send<Command: ValkeyCommand>(command: Command, requestID: Int) -> EventLoopFuture<RESPToken> {
362402
self.eventLoop.assertInEventLoop()
363403
self.encoder.reset()
364404
command.encode(into: &self.encoder)
365405
let buffer = self.encoder.buffer
366406

367407
let promise = eventLoop.makePromise(of: RESPToken.self)
368-
self.write(request: ValkeyRequest.single(buffer: buffer, promise: .nio(promise)))
408+
self.write(request: ValkeyRequest.single(buffer: buffer, promise: .nio(promise), id: requestID))
369409
return promise.futureResult
370410
}
371411

0 commit comments

Comments
 (0)