Skip to content

Commit 89a29d4

Browse files
authored
Add support for usernames (#72)
Redis 6.0 adds the ability to specify a username when sending an `AUTH` command. This patch adds this capability to RediStack.
1 parent 3001e41 commit 89a29d4

File tree

5 files changed

+176
-30
lines changed

5 files changed

+176
-30
lines changed

Sources/RediStack/Commands/BasicCommands.swift

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,19 @@ extension RedisClient {
8888
.map { _ in return () }
8989
}
9090

91+
/// Requests the client to authenticate with Redis to allow other commands to be executed.
92+
/// - Parameters:
93+
/// - username: The username to authenticate with.
94+
/// - password: The password to authenticate with.
95+
/// Warning: This function should only be used if you are running against Redis 6 or higher.
96+
public func authorize(
97+
username: String,
98+
password: String
99+
) -> EventLoopFuture<Void> {
100+
let args = [RESPValue(from: username), RESPValue(from: password)]
101+
return self.send(command: "AUTH", with: args).map { _ in return () }
102+
}
103+
91104
/// Removes the specified keys. A key is ignored if it does not exist.
92105
///
93106
/// [https://redis.io/commands/del](https://redis.io/commands/del)

Sources/RediStack/Configuration.swift

Lines changed: 65 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,9 @@ extension RedisConnection {
8080
}
8181
/// The port of the connection address. If the address is a Unix socket, then it will be `nil`.
8282
public var port: Int? { self.address.port }
83+
/// The user name used to authenticate connections with.
84+
/// - Warning: This property should only be provided if you are running against Redis 6 or higher.
85+
public let username: String?
8386
/// The password used to authenticate the connection.
8487
public let password: String?
8588
/// The initial database index that the connection should use.
@@ -91,15 +94,17 @@ extension RedisConnection {
9194

9295
/// Creates a new connection configuration with the provided details.
9396
/// - Parameters:
94-
/// - address: The socket address information to use for creating the Redis connection.
95-
/// - password: The optional password to authenticate the connection with. The default is `nil`.
96-
/// - initialDatabase: The optional database index to initially connect to. The default is `nil`.
97-
/// Redis by default opens connections against index `0`, so only set this value if the desired default is not `0`.
98-
/// - defaultLogger: The optional prototype logger to use as the default logger instance when generating logs from the connection.
97+
/// - `address`: The socket address information to use for creating the Redis connection.
98+
/// - `username`: The optional username to authenticate the connection with. The default is `nil`.
99+
/// - `password`: The optional password to authenticate the connection with. The default is `nil`.
100+
/// - `initialDatabase`: The optional database index to initially connect to. The default is `nil`.
101+
/// Redis by default opens connections against index `0`, so only set this value if the desired default is not `0`.
102+
/// - `defaultLogger`: The optional prototype logger to use as the default logger instance when generating logs from the connection.
99103
/// If one is not provided, one will be generated. See `RedisLogging.baseConnectionLogger`.
100104
/// - Throws: `RedisConnection.Configuration.ValidationError` if invalid arguments are provided.
101105
public init(
102106
address: SocketAddress,
107+
username: String? = nil,
103108
password: String? = nil,
104109
initialDatabase: Int? = nil,
105110
defaultLogger: Logger? = nil
@@ -109,11 +114,36 @@ extension RedisConnection {
109114
}
110115

111116
self.address = address
117+
self.username = username
112118
self.password = password
113119
self.initialDatabase = initialDatabase
114120
self.defaultLogger = defaultLogger ?? Configuration.defaultLogger
115121
}
116122

123+
/// Creates a new connection configuration with the provided details.
124+
/// - Parameters:
125+
/// - `address`: The socket address information to use for creating the Redis connection.
126+
/// - `password`: The optional password to authenticate the connection with. The default is `nil`.
127+
/// - `initialDatabase`: The optional database index to initially connect to. The default is `nil`.
128+
/// Redis by default opens connections against index `0`, so only set this value if the desired default is not `0`.
129+
/// - `defaultLogger`: The optional prototype logger to use as the default logger instance when generating logs from the connection.
130+
/// If one is not provided, one will be generated. See `RedisLogging.baseConnectionLogger`.
131+
/// - Throws: `RedisConnection.Configuration.ValidationError` if invalid arguments are provided.
132+
public init(
133+
address: SocketAddress,
134+
password: String? = nil,
135+
initialDatabase: Int? = nil,
136+
defaultLogger: Logger? = nil
137+
) throws {
138+
try self.init(
139+
address: address,
140+
username: nil,
141+
password: password,
142+
initialDatabase: initialDatabase,
143+
defaultLogger: defaultLogger
144+
)
145+
}
146+
117147
/// Creates a new connection configuration with exact details.
118148
/// - Parameters:
119149
/// - hostname: The remote hostname to connect to.
@@ -214,6 +244,9 @@ extension RedisConnectionPool {
214244
// this needs to be var so it can be updated by the pool with the pool id
215245
/// The logger prototype that will be used by connections by default when generating logs.
216246
public internal(set) var connectionDefaultLogger: Logger
247+
/// The username used to authenticate connections.
248+
/// - Warning: This property should only be provided if you are running against Redis 6 or higher.
249+
public let connectionUsername: String?
217250
/// The password used to authenticate connections.
218251
public let connectionPassword: String?
219252
/// The initial database index that connections should use.
@@ -224,7 +257,7 @@ extension RedisConnectionPool {
224257
/// Creates a new connection factory configuration with the provided options.
225258
/// - Parameters:
226259
/// - connectionInitialDatabase: The optional database index to initially connect to. The default is `nil`.
227-
/// Redis by default opens connections against index `0`, so only set this value if the desired default is not `0`.
260+
/// Redis by default opens connections against index `0`, so only set this value if the desired default is not `0`.
228261
/// - connectionPassword: The optional password to authenticate connections with. The default is `nil`.
229262
/// - connectionDefaultLogger: The optional prototype logger to use as the default logger instance when generating logs from connections.
230263
/// If one is not provided, one will be generated. See `RedisLogging.baseConnectionLogger`.
@@ -234,8 +267,34 @@ extension RedisConnectionPool {
234267
connectionPassword: String? = nil,
235268
connectionDefaultLogger: Logger? = nil,
236269
tcpClient: ClientBootstrap? = nil
270+
) {
271+
self.init(
272+
connectionInitialDatabase: connectionInitialDatabase,
273+
connectionUsername: nil,
274+
connectionPassword: connectionPassword,
275+
connectionDefaultLogger: connectionDefaultLogger,
276+
tcpClient: tcpClient
277+
)
278+
}
279+
280+
/// Creates a new connection factory configuration with the provided options.
281+
/// - Parameters:
282+
/// - connectionInitialDatabase: The optional database index to initially connect to. The default is `nil`.
283+
/// Redis by default opens connections against index `0`, so only set this value if the desired default is not `0`.
284+
/// - connectionUsername: The optional username to authenticate connections with. The default is `nil`. Works only with Redis 6 and greater.
285+
/// - connectionPassword: The optional password to authenticate connections with. The default is `nil`.
286+
/// - connectionDefaultLogger: The optional prototype logger to use as the default logger instance when generating logs from connections.
287+
/// If one is not provided, one will be generated. See `RedisLogging.baseConnectionLogger`.
288+
/// - tcpClient: If you have chosen to configure a `NIO.ClientBootstrap` yourself, this will be used instead of the `.makeRedisTCPClient` factory instance.
289+
public init(
290+
connectionInitialDatabase: Int? = nil,
291+
connectionUsername: String? = nil,
292+
connectionPassword: String? = nil,
293+
connectionDefaultLogger: Logger? = nil,
294+
tcpClient: ClientBootstrap? = nil
237295
) {
238296
self.connectionInitialDatabase = connectionInitialDatabase
297+
self.connectionUsername = connectionUsername
239298
self.connectionPassword = connectionPassword
240299
self.connectionDefaultLogger = connectionDefaultLogger ?? RedisConnection.Configuration.defaultLogger
241300
self.tcpClient = tcpClient

Sources/RediStack/RedisConnection.swift

Lines changed: 41 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -57,25 +57,12 @@ extension RedisConnection {
5757
) -> EventLoopFuture<RedisConnection> {
5858
let client = client ?? .makeRedisTCPClient(group: eventLoop)
5959

60-
var future = client
60+
return client
6161
.connect(to: config.address)
62-
.map { return RedisConnection(configuredRESPChannel: $0, backgroundLogger: config.defaultLogger) }
63-
64-
// if a password is specified, use it to authenticate before further operations happen
65-
if let password = config.password {
66-
future = future.flatMap { connection in
67-
return connection.authorize(with: password).map { connection }
68-
}
69-
}
70-
71-
// if a database index is specified, use it to switch the selected database before further operations happen
72-
if let database = config.initialDatabase {
73-
future = future.flatMap { connection in
74-
return connection.select(database: database).map { connection }
62+
.flatMap {
63+
let connection = RedisConnection(configuredRESPChannel: $0, backgroundLogger: config.defaultLogger)
64+
return connection.start(configuration: config).map({ _ in connection })
7565
}
76-
}
77-
78-
return future
7966
}
8067
}
8168

@@ -194,6 +181,28 @@ public final class RedisConnection: RedisClient, RedisClientWithUserContext {
194181
self.logger.trace("connection created")
195182
}
196183

184+
func start(configuration: Configuration) -> EventLoopFuture<Void> {
185+
let future: EventLoopFuture<Void>
186+
187+
// if a password is specified, use it to authenticate before further operations happen
188+
if let password = configuration.password {
189+
if let username = configuration.username {
190+
future = self.authorize(username: username, password: password)
191+
} else {
192+
future = self.authorize(with: password)
193+
}
194+
} else {
195+
future = self.eventLoop.makeSucceededVoidFuture()
196+
}
197+
198+
// if a database index is specified, use it to switch the selected database before further operations happen
199+
if let database = configuration.initialDatabase {
200+
return future.flatMap { self.select(database: database) }
201+
}
202+
203+
return future
204+
}
205+
197206
internal enum ConnectionState {
198207
case open
199208
case pubsub(RedisPubSubHandler)
@@ -223,15 +232,27 @@ extension RedisConnection {
223232
/// - Returns: A `NIO.EventLoopFuture` that resolves with the command's result stored in a `RESPValue`.
224233
/// If a `RedisError` is returned, the future will be failed instead.
225234
public func send(command: String, with arguments: [RESPValue]) -> EventLoopFuture<RESPValue> {
226-
self.eventLoop.flatSubmit {
227-
return self.send(command: command, with: arguments, logger: nil)
228-
}
235+
return self.send(command: command, with: arguments, logger: nil)
229236
}
230237

231238
internal func send(
232239
command: String,
233240
with arguments: [RESPValue],
234241
logger: Logger?
242+
) -> EventLoopFuture<RESPValue> {
243+
if self.eventLoop.inEventLoop {
244+
return self.send0(command: command, with: arguments, logger: logger)
245+
}
246+
247+
return self.eventLoop.flatSubmit {
248+
self.send0(command: command, with: arguments, logger: logger)
249+
}
250+
}
251+
252+
private func send0(
253+
command: String,
254+
with arguments: [RESPValue],
255+
logger: Logger?
235256
) -> EventLoopFuture<RESPValue> {
236257
self.eventLoop.preconditionInEventLoop()
237258

Sources/RediStack/RedisConnectionPool.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,7 @@ extension RedisConnectionPool {
220220
do {
221221
connectionConfig = try .init(
222222
address: nextTarget,
223+
username: factoryConfig.connectionUsername,
223224
password: factoryConfig.connectionPassword,
224225
initialDatabase: factoryConfig.connectionInitialDatabase,
225226
defaultLogger: factoryConfig.connectionDefaultLogger

Tests/RediStackTests/RedisConnectionTests.swift

Lines changed: 56 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,16 @@ import XCTest
2020

2121
final class RedisConnectionTests: XCTestCase {
2222

23-
}
23+
var logger: Logger {
24+
Logger(label: "RedisConnectionTests")
25+
}
2426

25-
// MARK: Unexpected Closures
26-
extension RedisConnectionTests {
2727
func test_connectionUnexpectedlyCloses_invokesCallback() throws {
2828
let loop = EmbeddedEventLoop()
2929

3030
let expectedClosureConnection = RedisConnection(
3131
configuredRESPChannel: EmbeddedChannel(loop: loop),
32-
backgroundLogger: Logger(label: "")
32+
backgroundLogger: self.logger
3333
)
3434
let expectedClosureExpectation = self.expectation(description: "this should not be fulfilled")
3535
expectedClosureExpectation.isInverted = true
@@ -49,4 +49,56 @@ extension RedisConnectionTests {
4949

5050
self.waitForExpectations(timeout: 0.5)
5151
}
52+
53+
func testAuthorizationWithUsername() {
54+
var maybeSocketAddress: SocketAddress?
55+
XCTAssertNoThrow(maybeSocketAddress = try SocketAddress.makeAddressResolvingHost("localhost", port: 0))
56+
guard let socketAddress = maybeSocketAddress else { return XCTFail("Expected a socketAddress") }
57+
var maybeConfiguration: RedisConnection.Configuration?
58+
XCTAssertNoThrow(maybeConfiguration = try .init(address: socketAddress, username: "username", password: "password"))
59+
guard let configuration = maybeConfiguration else { return XCTFail("Expected a configuration") }
60+
61+
let channel = EmbeddedChannel(handlers: [RedisCommandHandler()])
62+
XCTAssertNoThrow(try channel.connect(to: socketAddress).wait())
63+
64+
let connection = RedisConnection(configuredRESPChannel: channel, backgroundLogger: self.logger)
65+
let future = connection.start(configuration: configuration)
66+
67+
var outgoing: RESPValue?
68+
XCTAssertNoThrow(outgoing = try channel.readOutbound(as: RESPValue.self))
69+
XCTAssertEqual(outgoing, .array([.bulkString("AUTH"), .bulkString("username"), .bulkString("password")]))
70+
XCTAssertNoThrow(try channel.writeInbound(RESPValue.simpleString("OK")))
71+
XCTAssertNoThrow(try future.wait())
72+
}
73+
74+
func testAuthorizationWithoutUsername() {
75+
var maybeSocketAddress: SocketAddress?
76+
XCTAssertNoThrow(maybeSocketAddress = try SocketAddress.makeAddressResolvingHost("localhost", port: 0))
77+
guard let socketAddress = maybeSocketAddress else { return XCTFail("Expected a socketAddress") }
78+
var maybeConfiguration: RedisConnection.Configuration?
79+
XCTAssertNoThrow(maybeConfiguration = try .init(address: socketAddress, password: "password"))
80+
guard let configuration = maybeConfiguration else { return XCTFail("Expected a configuration") }
81+
82+
let channel = EmbeddedChannel(handlers: [RedisCommandHandler()])
83+
XCTAssertNoThrow(try channel.connect(to: socketAddress).wait())
84+
85+
let connection = RedisConnection(configuredRESPChannel: channel, backgroundLogger: self.logger)
86+
let future = connection.start(configuration: configuration)
87+
88+
var outgoing: RESPValue?
89+
XCTAssertNoThrow(outgoing = try channel.readOutbound(as: RESPValue.self))
90+
XCTAssertEqual(outgoing, .array([.bulkString("AUTH"), .bulkString("password")]))
91+
XCTAssertNoThrow(try channel.writeInbound(RESPValue.simpleString("OK")))
92+
XCTAssertNoThrow(try future.wait())
93+
}
94+
}
95+
96+
extension RESPValue {
97+
static func bulkString(_ string: String) -> Self {
98+
.bulkString(ByteBuffer(string: string))
99+
}
100+
101+
static func simpleString(_ string: String) -> Self {
102+
.simpleString(ByteBuffer(string: string))
103+
}
52104
}

0 commit comments

Comments
 (0)