Skip to content

Commit 4be5cb3

Browse files
committed
Add support for sending additional command in handshake
- Add ignore case to ValkeyPromise for commands whose results we don't care about - Add deque of pending commands to ConnectedState - Push remaining pending commands to active state once we receive the hello Signed-off-by: Adam Fowler <[email protected]>
1 parent 6058030 commit 4be5cb3

File tree

5 files changed

+91
-41
lines changed

5 files changed

+91
-41
lines changed

Sources/Valkey/Connection/ValkeyChannelHandler+stateMachine.swift

Lines changed: 36 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,11 @@ extension ValkeyChannelHandler {
6565
@usableFromInline
6666
struct ConnectedState {
6767
let context: Context
68-
var pendingHelloCommand: PendingCommand
68+
var pendingCommands: Deque<PendingCommand>
6969

70-
func cancel(requestID: Int) -> PendingCommand? {
71-
if pendingHelloCommand.requestID == requestID {
72-
return pendingHelloCommand
70+
func cancel(requestID: Int) -> Deque<PendingCommand>? {
71+
if pendingCommands.first?.requestID == requestID {
72+
return pendingCommands
7373
}
7474
return nil
7575
}
@@ -85,11 +85,11 @@ extension ValkeyChannelHandler {
8585

8686
/// handler has become active
8787
@usableFromInline
88-
mutating func setConnected(context: Context, pendingHelloCommand: PendingCommand) {
88+
mutating func setConnected(context: Context, pendingCommands: Deque<PendingCommand>) {
8989
switch consume self.state {
9090
case .initialized:
9191
self = .connected(
92-
.init(context: context, pendingHelloCommand: pendingHelloCommand)
92+
.init(context: context, pendingCommands: pendingCommands)
9393
)
9494
case .connected:
9595
preconditionFailure("Cannot set connected state when state is connected")
@@ -155,15 +155,18 @@ extension ValkeyChannelHandler {
155155
switch consume self.state {
156156
case .initialized:
157157
preconditionFailure("Cannot send command when initialized")
158-
case .connected(let state):
158+
case .connected(var state):
159+
guard let helloCommand = state.pendingCommands.popFirst() else {
160+
preconditionFailure("Cannot be in connected state with no pending commands")
161+
}
159162
switch token.value {
160163
case .bulkError(let message), .simpleError(let message):
161164
let error = ValkeyClientError(.commandError, message: String(buffer: message))
162165
self = .closed(error)
163-
return .respondAndClose(state.pendingHelloCommand, error)
166+
return .respondAndClose(helloCommand, error)
164167
default:
165-
self = .active(.init(context: state.context, pendingCommands: .init()))
166-
return .respond(state.pendingHelloCommand, .cancel)
168+
self = .active(.init(context: state.context, pendingCommands: state.pendingCommands))
169+
return .respond(helloCommand, .cancel)
167170
}
168171
case .active(var state):
169172
guard let command = state.pendingCommands.popFirst() else {
@@ -204,8 +207,12 @@ extension ValkeyChannelHandler {
204207
self = .closed(nil)
205208
return .respondAndClose(command, nil)
206209
}
207-
case .closed:
208-
preconditionFailure("Cannot receive command on closed connection")
210+
case .closed(let error):
211+
guard let error else {
212+
preconditionFailure("Cannot receive command on closed connection with no error")
213+
}
214+
self = .closed(error)
215+
return .closeWithError(error)
209216
}
210217
}
211218

@@ -221,19 +228,22 @@ extension ValkeyChannelHandler {
221228
case .initialized:
222229
preconditionFailure("Cannot wait until connection has succeeded")
223230
case .connected(let state):
224-
switch state.pendingHelloCommand.promise {
231+
guard let helloCommand = state.pendingCommands.first else {
232+
preconditionFailure("Cannot be in connected state with no pending commands")
233+
}
234+
switch helloCommand.promise {
225235
case .nio(let promise):
226236
self = .connected(state)
227237
return .waitForPromise(promise)
228-
case .swift:
238+
case .swift, .ignore:
229239
preconditionFailure("Connected state cannot be setup with a Swift continuation")
230240
}
231241
case .active(let state):
232242
self = .active(state)
233243
return .done
234244
case .closing(let state):
235245
self = .closing(state)
236-
return .done
246+
return .reportedClosed(nil)
237247
case .closed(let error):
238248
self = .closed(error)
239249
return .reportedClosed(error)
@@ -253,12 +263,15 @@ extension ValkeyChannelHandler {
253263
case .initialized:
254264
preconditionFailure("Cannot cancel when initialized")
255265
case .connected(let state):
256-
if state.pendingHelloCommand.deadline <= now {
266+
guard let helloCommand = state.pendingCommands.first else {
267+
preconditionFailure("Cannot be in connected state with no pending commands")
268+
}
269+
if helloCommand.deadline <= now {
257270
self = .closed(ValkeyClientError(.timeout))
258-
return .failPendingCommandsAndClose(state.context, [state.pendingHelloCommand])
271+
return .failPendingCommandsAndClose(state.context, state.pendingCommands)
259272
} else {
260273
self = .connected(state)
261-
return .reschedule(state.pendingHelloCommand.deadline)
274+
return .reschedule(helloCommand.deadline)
262275
}
263276
case .active(let state):
264277
if let firstCommand = state.pendingCommands.first {
@@ -303,11 +316,11 @@ extension ValkeyChannelHandler {
303316
case .initialized:
304317
preconditionFailure("Cannot cancel when initialized")
305318
case .connected(let state):
306-
if let command = state.cancel(requestID: requestID) {
319+
if let commands = state.cancel(requestID: requestID) {
307320
self = .closed(CancellationError())
308321
return .failPendingCommandsAndClose(
309322
state.context,
310-
cancel: [command],
323+
cancel: .init(commands),
311324
closeConnectionDueToCancel: []
312325
)
313326
} else {
@@ -360,7 +373,7 @@ extension ValkeyChannelHandler {
360373
self = .closed(nil)
361374
return .doNothing
362375
case .connected(let state):
363-
self = .closing(.init(context: state.context, pendingCommands: [state.pendingHelloCommand]))
376+
self = .closing(.init(context: state.context, pendingCommands: state.pendingCommands))
364377
return .waitForPendingCommands(state.context)
365378
case .active(let state):
366379
if state.pendingCommands.count > 0 {
@@ -393,7 +406,7 @@ extension ValkeyChannelHandler {
393406
return .doNothing
394407
case .connected(let state):
395408
self = .closed(nil)
396-
return .failPendingCommandsAndClose(state.context, [state.pendingHelloCommand])
409+
return .failPendingCommandsAndClose(state.context, state.pendingCommands)
397410
case .active(let state):
398411
self = .closed(nil)
399412
return .failPendingCommandsAndClose(state.context, state.pendingCommands)
@@ -421,7 +434,7 @@ extension ValkeyChannelHandler {
421434
return .doNothing
422435
case .connected(let state):
423436
self = .closed(nil)
424-
return .failPendingCommandsAndSubscriptions([state.pendingHelloCommand])
437+
return .failPendingCommandsAndSubscriptions(state.pendingCommands)
425438
case .active(let state):
426439
self = .closed(nil)
427440
return .failPendingCommandsAndSubscriptions(state.pendingCommands)

Sources/Valkey/Connection/ValkeyChannelHandler.swift

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,16 @@ import NIOCore
2020
enum ValkeyPromise<T: Sendable>: Sendable {
2121
case nio(EventLoopPromise<T>)
2222
case swift(CheckedContinuation<T, any Error>)
23+
case ignore
2324

2425
func succeed(_ t: T) {
2526
switch self {
2627
case .nio(let eventLoopPromise):
2728
eventLoopPromise.succeed(t)
2829
case .swift(let checkedContinuation):
2930
checkedContinuation.resume(returning: t)
31+
case .ignore:
32+
break
3033
}
3134
}
3235

@@ -36,6 +39,8 @@ enum ValkeyPromise<T: Sendable>: Sendable {
3639
eventLoopPromise.fail(e)
3740
case .swift(let checkedContinuation):
3841
checkedContinuation.resume(throwing: e)
42+
case .ignore:
43+
break
3944
}
4045
}
4146
}
@@ -275,25 +280,35 @@ final class ValkeyChannelHandler: ChannelInboundHandler {
275280
@usableFromInline
276281
func setConnected(context: ChannelHandlerContext) {
277282
// Send initial HELLO command
278-
let command = HELLO(
283+
let helloCommand = HELLO(
279284
arguments: .init(
280285
protover: 3,
281286
auth: configuration.authentication.map { .init(username: $0.username, password: $0.password) },
282287
clientname: configuration.clientName
283288
)
284289
)
290+
// set client info
291+
let clientInfoLibName = CLIENT.SETINFO(attr: .libname("valkey-swift"))
292+
let clientInfoLibVersion = CLIENT.SETINFO(attr: .libver("0.1.0"))
293+
285294
self.encoder.reset()
286-
command.encode(into: &self.encoder)
287-
let buffer = self.encoder.buffer
295+
helloCommand.encode(into: &self.encoder)
296+
clientInfoLibName.encode(into: &self.encoder)
297+
clientInfoLibVersion.encode(into: &self.encoder)
288298

289299
let promise = eventLoop.makePromise(of: RESPToken.self)
300+
290301
let deadline = .now() + self.configuration.commandTimeout
291-
context.writeAndFlush(self.wrapOutboundOut(buffer), promise: nil)
302+
context.writeAndFlush(self.wrapOutboundOut(self.encoder.buffer), promise: nil)
292303
scheduleDeadlineCallback(deadline: deadline)
293304

294305
self.stateMachine.setConnected(
295306
context: context,
296-
pendingHelloCommand: .init(promise: .nio(promise), requestID: 0, deadline: deadline)
307+
pendingCommands: [
308+
.init(promise: .nio(promise), requestID: 0, deadline: deadline), // HELLO
309+
.init(promise: .ignore, requestID: 0, deadline: deadline), // CLIENT.SETINFO
310+
.init(promise: .ignore, requestID: 0, deadline: deadline), // CLIENT.SETINFO
311+
]
297312
)
298313
}
299314

Tests/ValkeyTests/Utils/NIOAsyncTestingChannel+hello.swift

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,15 @@ import Testing
2121
extension NIOAsyncTestingChannel {
2222
func processHello() async throws {
2323
let hello = try await self.waitForOutboundWrite(as: ByteBuffer.self)
24-
#expect(hello == RESPToken(.array([.bulkString("HELLO"), .bulkString("3")])).base)
24+
var expectedBuffer = ByteBuffer()
25+
expectedBuffer.writeImmutableBuffer(RESPToken(.array([.bulkString("HELLO"), .bulkString("3")])).base)
26+
expectedBuffer.writeImmutableBuffer(
27+
RESPToken(.array([.bulkString("CLIENT"), .bulkString("SETINFO"), .bulkString("lib-name"), .bulkString("valkey-swift")])).base
28+
)
29+
expectedBuffer.writeImmutableBuffer(
30+
RESPToken(.array([.bulkString("CLIENT"), .bulkString("SETINFO"), .bulkString("lib-ver"), .bulkString("0.1.0")])).base
31+
)
32+
#expect(hello == expectedBuffer)
2533
try await self.writeInbound(
2634
RESPToken(
2735
.map([
@@ -35,5 +43,7 @@ extension NIOAsyncTestingChannel {
3543
])
3644
).base
3745
)
46+
try await self.writeInbound(RESPToken.ok.base)
47+
try await self.writeInbound(RESPToken.ok.base)
3848
}
3949
}

Tests/ValkeyTests/ValkeyChannelHandlerStateMachineTests.swift

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,12 @@ struct ValkeyChannelHandlerStateMachineTests {
183183
}
184184
expect(
185185
stateMachine.state
186-
== .closing(.init(context: "testGracefulShutdown", pendingCommands: [.init(promise: .nio(promise), requestID: 23, deadline: .now())]))
186+
== .closing(
187+
.init(
188+
context: "testGracefulShutdown",
189+
pendingCommands: [.init(promise: .nio(promise), requestID: 23, deadline: .now())]
190+
)
191+
)
187192
)
188193
switch stateMachine.receivedResponse(token: .ok) {
189194
case .respondAndClose(let command, let error):
@@ -218,7 +223,10 @@ struct ValkeyChannelHandlerStateMachineTests {
218223
expect(
219224
stateMachine.state
220225
== .closing(
221-
.init(context: "testClosedClosingState", pendingCommands: [.init(promise: .nio(promise), requestID: 17, deadline: .now())])
226+
.init(
227+
context: "testClosedClosingState",
228+
pendingCommands: [.init(promise: .nio(promise), requestID: 17, deadline: .now())]
229+
)
222230
)
223231
)
224232
switch stateMachine.setClosed() {
@@ -460,7 +468,7 @@ extension ValkeyChannelHandler.StateMachine<String>.State {
460468
case .connected(let lhs):
461469
switch rhs {
462470
case .connected(let rhs):
463-
return lhs.context == rhs.context && lhs.pendingHelloCommand.requestID == rhs.pendingHelloCommand.requestID
471+
return lhs.context == rhs.context && lhs.pendingCommands.map { $0.requestID } == rhs.pendingCommands.map { $0.requestID }
464472
default:
465473
return false
466474
}
@@ -535,7 +543,7 @@ extension ValkeyChannelHandler.StateMachine {
535543
let promise = EmbeddedEventLoop().makePromise(of: RESPToken.self)
536544
self.setConnected(
537545
context: context,
538-
pendingHelloCommand: .init(promise: .nio(promise), requestID: 0, deadline: .now() + .seconds(30))
546+
pendingCommands: [.init(promise: .nio(promise), requestID: 0, deadline: .now() + .seconds(30))]
539547
)
540548
}
541549

Tests/ValkeyTests/ValkeyConnectionTests.swift

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,9 @@ struct ConnectionTests {
4646
let logger = Logger(label: "test")
4747
_ = try await ValkeyConnection.setupChannelAndConnect(channel, configuration: .init(), logger: logger)
4848

49-
let outbound = try await channel.waitForOutboundWrite(as: ByteBuffer.self)
50-
#expect(outbound == RESPToken(.command(["HELLO", "3"])).base)
49+
var outbound = try await channel.waitForOutboundWrite(as: ByteBuffer.self)
50+
let hello3 = RESPToken(.command(["HELLO", "3"])).base
51+
#expect(outbound.readSlice(length: hello3.readableBytes) == hello3)
5152
}
5253

5354
@Test
@@ -57,8 +58,9 @@ struct ConnectionTests {
5758
let logger = Logger(label: "test")
5859
_ = try await ValkeyConnection.setupChannelAndConnect(channel, configuration: .init(), logger: logger)
5960

60-
let outbound = try await channel.waitForOutboundWrite(as: ByteBuffer.self)
61-
#expect(outbound == RESPToken(.command(["HELLO", "3"])).base)
61+
var outbound = try await channel.waitForOutboundWrite(as: ByteBuffer.self)
62+
let hello3 = RESPToken(.command(["HELLO", "3"])).base
63+
#expect(outbound.readSlice(length: hello3.readableBytes) == hello3)
6264
await #expect(throws: ValkeyClientError(.commandError, message: "Not supported")) {
6365
try await channel.writeInbound(RESPToken(.bulkError("Not supported")).base)
6466
}
@@ -79,8 +81,9 @@ struct ConnectionTests {
7981
logger: logger
8082
)
8183

82-
let outbound = try await channel.waitForOutboundWrite(as: ByteBuffer.self)
83-
#expect(outbound == RESPToken(.command(["HELLO", "3", "AUTH", "john", "smith"])).base)
84+
var outbound = try await channel.waitForOutboundWrite(as: ByteBuffer.self)
85+
let hello3 = RESPToken(.command(["HELLO", "3", "AUTH", "john", "smith"])).base
86+
#expect(outbound.readSlice(length: hello3.readableBytes) == hello3)
8487
}
8588

8689
@Test
@@ -95,8 +98,9 @@ struct ConnectionTests {
9598
logger: logger
9699
)
97100

98-
let outbound = try await channel.waitForOutboundWrite(as: ByteBuffer.self)
99-
#expect(outbound == RESPToken(.command(["HELLO", "3", "SETNAME", "Testing"])).base)
101+
var outbound = try await channel.waitForOutboundWrite(as: ByteBuffer.self)
102+
let hello3 = RESPToken(.command(["HELLO", "3", "SETNAME", "Testing"])).base
103+
#expect(outbound.readSlice(length: hello3.readableBytes) == hello3)
100104
}
101105

102106
@Test

0 commit comments

Comments
 (0)