Skip to content

Move SETINFO calls to ChannelHandler so they can be written along with HELLO #172

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
43 changes: 28 additions & 15 deletions Sources/Valkey/Connection/ValkeyChannelHandler+stateMachine.swift
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,12 @@ extension ValkeyChannelHandler {
@usableFromInline
struct ConnectedState {
let context: Context
var pendingHelloCommand: PendingCommand
let pendingHelloCommand: PendingCommand
var pendingCommands: Deque<PendingCommand>

func cancel(requestID: Int) -> PendingCommand? {
func cancel(requestID: Int) -> [PendingCommand]? {
if pendingHelloCommand.requestID == requestID {
return pendingHelloCommand
return [pendingHelloCommand] + pendingCommands
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the cancel code here should be unreachable. HELLO and SETINFO are triggered from within the channel handler.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah it's unreachable, I can add a precondition instead.

return nil
}
Expand All @@ -85,11 +86,11 @@ extension ValkeyChannelHandler {

/// handler has become active
@usableFromInline
mutating func setConnected(context: Context, pendingHelloCommand: PendingCommand) {
mutating func setConnected(context: Context, pendingHelloCommand: PendingCommand, pendingCommands: Deque<PendingCommand>) {
switch consume self.state {
case .initialized:
self = .connected(
.init(context: context, pendingHelloCommand: pendingHelloCommand)
.init(context: context, pendingHelloCommand: pendingHelloCommand, pendingCommands: pendingCommands)
)
case .connected:
preconditionFailure("Cannot set connected state when state is connected")
Expand Down Expand Up @@ -162,7 +163,7 @@ extension ValkeyChannelHandler {
self = .closed(error)
return .respondAndClose(state.pendingHelloCommand, error)
default:
self = .active(.init(context: state.context, pendingCommands: .init()))
self = .active(.init(context: state.context, pendingCommands: state.pendingCommands))
return .respond(state.pendingHelloCommand, .cancel)
}
case .active(var state):
Expand Down Expand Up @@ -204,8 +205,12 @@ extension ValkeyChannelHandler {
self = .closed(nil)
return .respondAndClose(command, nil)
}
case .closed:
preconditionFailure("Cannot receive command on closed connection")
case .closed(let error):
guard let error else {
preconditionFailure("Cannot receive command on closed connection with no error")
}
self = .closed(error)
return .closeWithError(error)
}
}

Expand Down Expand Up @@ -233,7 +238,7 @@ extension ValkeyChannelHandler {
return .done
case .closing(let state):
self = .closing(state)
return .done
return .reportedClosed(nil)
case .closed(let error):
self = .closed(error)
return .reportedClosed(error)
Expand All @@ -255,7 +260,9 @@ extension ValkeyChannelHandler {
case .connected(let state):
if state.pendingHelloCommand.deadline <= now {
self = .closed(ValkeyClientError(.timeout))
return .failPendingCommandsAndClose(state.context, [state.pendingHelloCommand])
var pendingCommands = state.pendingCommands
pendingCommands.prepend(state.pendingHelloCommand)
return .failPendingCommandsAndClose(state.context, pendingCommands)
} else {
self = .connected(state)
return .reschedule(state.pendingHelloCommand.deadline)
Expand Down Expand Up @@ -303,11 +310,11 @@ extension ValkeyChannelHandler {
case .initialized:
preconditionFailure("Cannot cancel when initialized")
case .connected(let state):
if let command = state.cancel(requestID: requestID) {
if let commands = state.cancel(requestID: requestID) {
self = .closed(CancellationError())
return .failPendingCommandsAndClose(
state.context,
cancel: [command],
cancel: .init(commands),
closeConnectionDueToCancel: []
)
} else {
Expand Down Expand Up @@ -360,7 +367,9 @@ extension ValkeyChannelHandler {
self = .closed(nil)
return .doNothing
case .connected(let state):
self = .closing(.init(context: state.context, pendingCommands: [state.pendingHelloCommand]))
var pendingCommands = state.pendingCommands
pendingCommands.prepend(state.pendingHelloCommand)
self = .closing(.init(context: state.context, pendingCommands: pendingCommands))
return .waitForPendingCommands(state.context)
case .active(let state):
if state.pendingCommands.count > 0 {
Expand Down Expand Up @@ -393,7 +402,9 @@ extension ValkeyChannelHandler {
return .doNothing
case .connected(let state):
self = .closed(nil)
return .failPendingCommandsAndClose(state.context, [state.pendingHelloCommand])
var pendingCommands = state.pendingCommands
pendingCommands.prepend(state.pendingHelloCommand)
return .failPendingCommandsAndClose(state.context, state.pendingCommands)
case .active(let state):
self = .closed(nil)
return .failPendingCommandsAndClose(state.context, state.pendingCommands)
Expand Down Expand Up @@ -421,7 +432,9 @@ extension ValkeyChannelHandler {
return .doNothing
case .connected(let state):
self = .closed(nil)
return .failPendingCommandsAndSubscriptions([state.pendingHelloCommand])
var pendingCommands = state.pendingCommands
pendingCommands.prepend(state.pendingHelloCommand)
return .failPendingCommandsAndSubscriptions(state.pendingCommands)
case .active(let state):
self = .closed(nil)
return .failPendingCommandsAndSubscriptions(state.pendingCommands)
Expand Down
47 changes: 15 additions & 32 deletions Sources/Valkey/Connection/ValkeyChannelHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -159,33 +159,6 @@ final class ValkeyChannelHandler: ChannelInboundHandler {
}
}

/// Write valkey command/commands to channel
/// - Parameters:
/// - request: Valkey command request
/// - promise: Promise to fulfill when command is complete
@inlinable
func writeAndForget<Command: ValkeyCommand>(command: Command, requestID: Int) {
self.eventLoop.assertInEventLoop()
let pendingCommand = PendingCommand(
promise: .forget,
requestID: requestID,
deadline: .now() + self.configuration.commandTimeout
)
switch self.stateMachine.sendCommand(pendingCommand) {
case .sendCommand(let context):
self.encoder.reset()
command.encode(into: &self.encoder)
let buffer = self.encoder.buffer
context.writeAndFlush(self.wrapOutboundOut(buffer), promise: nil)
if self.deadlineCallback == nil {
self.scheduleDeadlineCallback(deadline: .now() + self.configuration.commandTimeout)
}

case .throwError:
break
}
}

@usableFromInline
func write(request: ValkeyRequest) {
self.eventLoop.assertInEventLoop()
Expand Down Expand Up @@ -307,25 +280,35 @@ final class ValkeyChannelHandler: ChannelInboundHandler {
@usableFromInline
func setConnected(context: ChannelHandlerContext) {
// Send initial HELLO command
let command = HELLO(
let helloCommand = HELLO(
arguments: .init(
protover: 3,
auth: configuration.authentication.map { .init(username: $0.username, password: $0.password) },
clientname: configuration.clientName
)
)
// set client info
let clientInfoLibName = CLIENT.SETINFO(attr: .libname(valkeySwiftLibraryName))
let clientInfoLibVersion = CLIENT.SETINFO(attr: .libver(valkeySwiftLibraryVersion))

self.encoder.reset()
command.encode(into: &self.encoder)
let buffer = self.encoder.buffer
helloCommand.encode(into: &self.encoder)
clientInfoLibName.encode(into: &self.encoder)
clientInfoLibVersion.encode(into: &self.encoder)

let promise = eventLoop.makePromise(of: RESPToken.self)

let deadline = .now() + self.configuration.commandTimeout
context.writeAndFlush(self.wrapOutboundOut(buffer), promise: nil)
context.writeAndFlush(self.wrapOutboundOut(self.encoder.buffer), promise: nil)
scheduleDeadlineCallback(deadline: deadline)

self.stateMachine.setConnected(
context: context,
pendingHelloCommand: .init(promise: .nio(promise), requestID: 0, deadline: deadline)
pendingHelloCommand: .init(promise: .nio(promise), requestID: 0, deadline: deadline),
pendingCommands: [
.init(promise: .forget, requestID: 0, deadline: deadline), // CLIENT.SETINFO libname
.init(promise: .forget, requestID: 0, deadline: deadline), // CLIENT.SETINFO libver
]
)
}

Expand Down
10 changes: 2 additions & 8 deletions Sources/Valkey/Connection/ValkeyConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ public final actor ValkeyConnection: ValkeyClientProtocol, Sendable {
}
}
let connection = try await future.get()
try await connection.initialHandshake()
try await connection.waitOnActive()
return connection
}

Expand All @@ -144,10 +144,8 @@ public final actor ValkeyConnection: ValkeyClientProtocol, Sendable {
self.channel.close(mode: .all, promise: nil)
}

func initialHandshake() async throws {
func waitOnActive() async throws {
try await self.channelHandler.waitOnActive().get()
self.executeAndForget(command: CLIENT.SETINFO(attr: .libname(valkeySwiftLibraryName)))
self.executeAndForget(command: CLIENT.SETINFO(attr: .libver(valkeySwiftLibraryVersion)))
}

/// Send RESP command to Valkey connection
Expand All @@ -174,10 +172,6 @@ public final actor ValkeyConnection: ValkeyClientProtocol, Sendable {
}
}

func executeAndForget<Command: ValkeyCommand>(command: Command) {
self.channelHandler.writeAndForget(command: command, requestID: Self.requestIDGenerator.next())
}

/// Pipeline a series of commands to Valkey connection
///
/// Once all the responses for the commands have been received the function returns
Expand Down
2 changes: 1 addition & 1 deletion Sources/Valkey/ValkeyConnectionFactory.swift
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ package final class ValkeyConnectionFactory: Sendable {
logger: logger
)
}.get()
try await connection.initialHandshake()
try await connection.waitOnActive()
return connection
}
}
Expand Down
2 changes: 2 additions & 0 deletions Sources/Valkey/Version.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,7 @@
//
//===----------------------------------------------------------------------===//

/// library name reported to server using CLIENT SETINFO
package let valkeySwiftLibraryName = "valkey-swift"
/// library version reported to server using CLIENT SETINFO
package let valkeySwiftLibraryVersion = "0.1.0"
12 changes: 11 additions & 1 deletion Tests/ValkeyTests/Utils/NIOAsyncTestingChannel+hello.swift
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,15 @@ import Testing
extension NIOAsyncTestingChannel {
func processHello() async throws {
let hello = try await self.waitForOutboundWrite(as: ByteBuffer.self)
#expect(hello == RESPToken(.array([.bulkString("HELLO"), .bulkString("3")])).base)
var expectedBuffer = ByteBuffer()
expectedBuffer.writeImmutableBuffer(RESPToken(.array([.bulkString("HELLO"), .bulkString("3")])).base)
expectedBuffer.writeImmutableBuffer(
RESPToken(.array([.bulkString("CLIENT"), .bulkString("SETINFO"), .bulkString("lib-name"), .bulkString(valkeySwiftLibraryName)])).base
)
expectedBuffer.writeImmutableBuffer(
RESPToken(.array([.bulkString("CLIENT"), .bulkString("SETINFO"), .bulkString("lib-ver"), .bulkString(valkeySwiftLibraryVersion)])).base
)
#expect(hello == expectedBuffer)
try await self.writeInbound(
RESPToken(
.map([
Expand All @@ -35,5 +43,7 @@ extension NIOAsyncTestingChannel {
])
).base
)
try await self.writeInbound(RESPToken.ok.base)
try await self.writeInbound(RESPToken.ok.base)
}
}
17 changes: 13 additions & 4 deletions Tests/ValkeyTests/ValkeyChannelHandlerStateMachineTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,12 @@ struct ValkeyChannelHandlerStateMachineTests {
}
expect(
stateMachine.state
== .closing(.init(context: "testGracefulShutdown", pendingCommands: [.init(promise: .nio(promise), requestID: 23, deadline: .now())]))
== .closing(
.init(
context: "testGracefulShutdown",
pendingCommands: [.init(promise: .nio(promise), requestID: 23, deadline: .now())]
)
)
)
switch stateMachine.receivedResponse(token: .ok) {
case .respondAndClose(let command, let error):
Expand Down Expand Up @@ -218,7 +223,10 @@ struct ValkeyChannelHandlerStateMachineTests {
expect(
stateMachine.state
== .closing(
.init(context: "testClosedClosingState", pendingCommands: [.init(promise: .nio(promise), requestID: 17, deadline: .now())])
.init(
context: "testClosedClosingState",
pendingCommands: [.init(promise: .nio(promise), requestID: 17, deadline: .now())]
)
)
)
switch stateMachine.setClosed() {
Expand Down Expand Up @@ -460,7 +468,7 @@ extension ValkeyChannelHandler.StateMachine<String>.State {
case .connected(let lhs):
switch rhs {
case .connected(let rhs):
return lhs.context == rhs.context && lhs.pendingHelloCommand.requestID == rhs.pendingHelloCommand.requestID
return lhs.context == rhs.context && lhs.pendingCommands.map { $0.requestID } == rhs.pendingCommands.map { $0.requestID }
default:
return false
}
Expand Down Expand Up @@ -535,7 +543,8 @@ extension ValkeyChannelHandler.StateMachine {
let promise = EmbeddedEventLoop().makePromise(of: RESPToken.self)
self.setConnected(
context: context,
pendingHelloCommand: .init(promise: .nio(promise), requestID: 0, deadline: .now() + .seconds(30))
pendingHelloCommand: .init(promise: .nio(promise), requestID: 0, deadline: .now() + .seconds(30)),
pendingCommands: []
)
}

Expand Down
20 changes: 12 additions & 8 deletions Tests/ValkeyTests/ValkeyConnectionTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,9 @@ struct ConnectionTests {
let logger = Logger(label: "test")
_ = try await ValkeyConnection.setupChannelAndConnect(channel, configuration: .init(), logger: logger)

let outbound = try await channel.waitForOutboundWrite(as: ByteBuffer.self)
#expect(outbound == RESPToken(.command(["HELLO", "3"])).base)
var outbound = try await channel.waitForOutboundWrite(as: ByteBuffer.self)
let hello3 = RESPToken(.command(["HELLO", "3"])).base
#expect(outbound.readSlice(length: hello3.readableBytes) == hello3)
}

@Test
Expand All @@ -57,8 +58,9 @@ struct ConnectionTests {
let logger = Logger(label: "test")
_ = try await ValkeyConnection.setupChannelAndConnect(channel, configuration: .init(), logger: logger)

let outbound = try await channel.waitForOutboundWrite(as: ByteBuffer.self)
#expect(outbound == RESPToken(.command(["HELLO", "3"])).base)
var outbound = try await channel.waitForOutboundWrite(as: ByteBuffer.self)
let hello3 = RESPToken(.command(["HELLO", "3"])).base
#expect(outbound.readSlice(length: hello3.readableBytes) == hello3)
await #expect(throws: ValkeyClientError(.commandError, message: "Not supported")) {
try await channel.writeInbound(RESPToken(.bulkError("Not supported")).base)
}
Expand All @@ -79,8 +81,9 @@ struct ConnectionTests {
logger: logger
)

let outbound = try await channel.waitForOutboundWrite(as: ByteBuffer.self)
#expect(outbound == RESPToken(.command(["HELLO", "3", "AUTH", "john", "smith"])).base)
var outbound = try await channel.waitForOutboundWrite(as: ByteBuffer.self)
let hello3 = RESPToken(.command(["HELLO", "3", "AUTH", "john", "smith"])).base
#expect(outbound.readSlice(length: hello3.readableBytes) == hello3)
}

@Test
Expand All @@ -95,8 +98,9 @@ struct ConnectionTests {
logger: logger
)

let outbound = try await channel.waitForOutboundWrite(as: ByteBuffer.self)
#expect(outbound == RESPToken(.command(["HELLO", "3", "SETNAME", "Testing"])).base)
var outbound = try await channel.waitForOutboundWrite(as: ByteBuffer.self)
let hello3 = RESPToken(.command(["HELLO", "3", "SETNAME", "Testing"])).base
#expect(outbound.readSlice(length: hello3.readableBytes) == hello3)
}

@Test
Expand Down