Skip to content

Commit 5dade1c

Browse files
authored
Explicit TLS config. (#237)
1 parent 63e7d57 commit 5dade1c

File tree

10 files changed

+161
-54
lines changed

10 files changed

+161
-54
lines changed

Sources/PostgresNIO/Connection/PostgresConnection.swift

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -54,17 +54,29 @@ extension PostgresConnection {
5454
logger: Logger = .init(label: "codes.vapor.postgres"),
5555
on eventLoop: EventLoop
5656
) -> EventLoopFuture<PostgresConnection> {
57-
let configuration = PSQLConnection.Configuration(
58-
connection: .resolved(address: socketAddress, serverName: serverHostname),
59-
authentication: nil,
60-
tlsConfiguration: tlsConfiguration
61-
)
62-
63-
return PSQLConnection.connect(
64-
configuration: configuration,
65-
logger: logger,
66-
on: eventLoop
67-
).map { connection in
57+
var tlsFuture: EventLoopFuture<PSQLConnection.Configuration.TLS>
58+
59+
if let tlsConfiguration = tlsConfiguration {
60+
tlsFuture = eventLoop.makeSucceededVoidFuture().flatMapBlocking(onto: .global(qos: .default)) {
61+
try PSQLConnection.Configuration.TLS.require(.init(configuration: tlsConfiguration))
62+
}
63+
} else {
64+
tlsFuture = eventLoop.makeSucceededFuture(.disable)
65+
}
66+
67+
return tlsFuture.flatMap { tls in
68+
let configuration = PSQLConnection.Configuration(
69+
connection: .resolved(address: socketAddress, serverName: serverHostname),
70+
authentication: nil,
71+
tls: tls
72+
)
73+
74+
return PSQLConnection.connect(
75+
configuration: configuration,
76+
logger: logger,
77+
on: eventLoop
78+
)
79+
}.map { connection in
6880
PostgresConnection(underlying: connection, logger: logger)
6981
}.flatMapErrorThrowing { error in
7082
throw error.asAppropriatePostgresError

Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,13 @@ struct ConnectionStateMachine {
1818
}
1919

2020
enum State {
21+
enum TLSConfiguration {
22+
case prefer
23+
case require
24+
}
25+
2126
case initialized
22-
case sslRequestSent
27+
case sslRequestSent(TLSConfiguration)
2328
case sslNegotiated
2429
case sslHandlerAdded
2530
case waitingToStartAuthentication
@@ -114,26 +119,38 @@ struct ConnectionStateMachine {
114119
init() {
115120
self.state = .initialized
116121
}
117-
122+
118123
#if DEBUG
119124
/// for testing purposes only
120125
init(_ state: State) {
121126
self.state = state
122127
}
123128
#endif
129+
130+
enum TLSConfiguration {
131+
case disable
132+
case prefer
133+
case require
134+
}
124135

125-
mutating func connected(requireTLS: Bool) -> ConnectionAction {
136+
mutating func connected(tls: TLSConfiguration) -> ConnectionAction {
126137
guard case .initialized = self.state else {
127138
preconditionFailure("Unexpected state")
128139
}
129140

130-
if requireTLS {
131-
self.state = .sslRequestSent
141+
switch tls {
142+
case .disable:
143+
self.state = .waitingToStartAuthentication
144+
return .provideAuthenticationContext
145+
146+
case .prefer:
147+
self.state = .sslRequestSent(.prefer)
148+
return .sendSSLRequest
149+
150+
case .require:
151+
self.state = .sslRequestSent(.require)
132152
return .sendSSLRequest
133153
}
134-
135-
self.state = .waitingToStartAuthentication
136-
return .provideAuthenticationContext
137154
}
138155

139156
mutating func provideAuthenticationContext(_ authContext: AuthContext) -> ConnectionAction {
@@ -223,8 +240,12 @@ struct ConnectionStateMachine {
223240

224241
mutating func sslUnsupportedReceived() -> ConnectionAction {
225242
switch self.state {
226-
case .sslRequestSent:
243+
case .sslRequestSent(.require):
227244
return self.closeConnectionAndCleanup(.sslUnsupported)
245+
246+
case .sslRequestSent(.prefer):
247+
self.state = .waitingToStartAuthentication
248+
return .provideAuthenticationContext
228249

229250
case .initialized,
230251
.sslNegotiated,

Sources/PostgresNIO/New/PSQLChannelHandler.swift

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ final class PSQLChannelHandler: ChannelDuplexHandler {
5757
self.decoder = NIOSingleStepByteToMessageProcessor(PSQLBackendMessageDecoder())
5858
}
5959
#endif
60-
60+
6161
// MARK: Handler lifecycle
6262

6363
func handlerAdded(context: ChannelHandlerContext) {
@@ -331,7 +331,8 @@ final class PSQLChannelHandler: ChannelDuplexHandler {
331331
// MARK: - Private Methods -
332332

333333
private func connected(context: ChannelHandlerContext) {
334-
let action = self.state.connected(requireTLS: self.configureSSLCallback != nil)
334+
335+
let action = self.state.connected(tls: .init(self.configuration.tls))
335336

336337
self.run(action, with: context)
337338
}
@@ -572,3 +573,29 @@ private extension Insecure.MD5.Digest {
572573
return String(decoding: result, as: Unicode.UTF8.self)
573574
}
574575
}
576+
577+
extension ConnectionStateMachine.TLSConfiguration {
578+
fileprivate init(_ connection: PSQLConnection.Configuration.TLS) {
579+
switch connection.base {
580+
case .disable:
581+
self = .disable
582+
case .require:
583+
self = .require
584+
case .prefer:
585+
self = .prefer
586+
}
587+
}
588+
}
589+
590+
extension PSQLChannelHandler {
591+
convenience init(
592+
configuration: PSQLConnection.Configuration,
593+
configureSSLCallback: ((Channel) throws -> Void)?)
594+
{
595+
self.init(
596+
configuration: configuration,
597+
logger: .psqlNoOpLogger,
598+
configureSSLCallback: configureSSLCallback
599+
)
600+
}
601+
}

Sources/PostgresNIO/New/PSQLConnection.swift

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,30 @@ final class PSQLConnection {
2323
self.password = password
2424
}
2525
}
26+
27+
struct TLS {
28+
enum Base {
29+
case disable
30+
case prefer(NIOSSLContext)
31+
case require(NIOSSLContext)
32+
}
33+
34+
var base: Base
35+
36+
private init(_ base: Base) {
37+
self.base = base
38+
}
39+
40+
static var disable: Self = Self.init(.disable)
41+
42+
static func prefer(_ sslContext: NIOSSLContext) -> Self {
43+
self.init(.prefer(sslContext))
44+
}
45+
46+
static func require(_ sslContext: NIOSSLContext) -> Self {
47+
self.init(.require(sslContext))
48+
}
49+
}
2650

2751
enum Connection {
2852
case unresolved(host: String, port: Int)
@@ -34,27 +58,27 @@ final class PSQLConnection {
3458
/// The authentication properties to send to the Postgres server during startup auth handshake
3559
var authentication: Authentication?
3660

37-
var tlsConfiguration: TLSConfiguration?
61+
var tls: TLS
3862

3963
init(host: String,
4064
port: Int = 5432,
4165
username: String,
4266
database: String? = nil,
4367
password: String? = nil,
44-
tlsConfiguration: TLSConfiguration? = nil
68+
tls: TLS = .disable
4569
) {
4670
self.connection = .unresolved(host: host, port: port)
4771
self.authentication = Authentication(username: username, password: password, database: database)
48-
self.tlsConfiguration = tlsConfiguration
72+
self.tls = tls
4973
}
5074

5175
init(connection: Connection,
5276
authentication: Authentication?,
53-
tlsConfiguration: TLSConfiguration?
77+
tls: TLS
5478
) {
5579
self.connection = connection
5680
self.authentication = authentication
57-
self.tlsConfiguration = tlsConfiguration
81+
self.tls = tls
5882
}
5983
}
6084

@@ -185,14 +209,19 @@ final class PSQLConnection {
185209
let bootstrap = ClientBootstrap(group: eventLoop)
186210
.channelInitializer { channel in
187211
var configureSSLCallback: ((Channel) throws -> ())? = nil
188-
if let tlsConfiguration = configuration.tlsConfiguration {
212+
213+
switch configuration.tls.base {
214+
case .disable:
215+
break
216+
217+
case .prefer(let sslContext), .require(let sslContext):
189218
configureSSLCallback = { channel in
190219
channel.eventLoop.assertInEventLoop()
191-
192-
let sslContext = try NIOSSLContext(configuration: tlsConfiguration)
220+
193221
let sslHandler = try NIOSSLClientHandler(
194222
context: sslContext,
195-
serverHostname: configuration.sslServerHostname)
223+
serverHostname: configuration.sslServerHostname
224+
)
196225
try channel.pipeline.syncOperations.addHandler(sslHandler, position: .first)
197226
}
198227
}

Tests/IntegrationTests/PSQLIntegrationTests.swift

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ final class IntegrationTests: XCTestCase {
2828
username: env("POSTGRES_USER") ?? "test_username",
2929
database: env("POSTGRES_DB") ?? "test_database",
3030
password: "wrong_password",
31-
tlsConfiguration: nil)
31+
tls: .disable)
3232

3333
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
3434
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
@@ -358,7 +358,8 @@ extension PSQLConnection {
358358
username: env("POSTGRES_USER") ?? "test_username",
359359
database: env("POSTGRES_DB") ?? "test_database",
360360
password: env("POSTGRES_PASSWORD") ?? "test_password",
361-
tlsConfiguration: nil)
361+
tls: .disable
362+
)
362363

363364
return PSQLConnection.connect(configuration: config, logger: logger, on: eventLoop)
364365
}

Tests/PostgresNIOTests/New/Connection State Machine/AuthenticationStateMachineTests.swift

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@ class AuthenticationStateMachineTests: XCTestCase {
66

77
func testAuthenticatePlaintext() {
88
let authContext = AuthContext(username: "test", password: "abc123", database: "test")
9-
var state = ConnectionStateMachine(.waitingToStartAuthentication)
9+
10+
var state = ConnectionStateMachine()
11+
XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext)
1012

1113
XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext))
1214
XCTAssertEqual(state.authenticationMessageReceived(.plaintext), .sendPasswordMessage(.cleartext, authContext))
@@ -15,7 +17,8 @@ class AuthenticationStateMachineTests: XCTestCase {
1517

1618
func testAuthenticateMD5() {
1719
let authContext = AuthContext(username: "test", password: "abc123", database: "test")
18-
var state = ConnectionStateMachine(.waitingToStartAuthentication)
20+
var state = ConnectionStateMachine()
21+
XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext)
1922
let salt: (UInt8, UInt8, UInt8, UInt8) = (0, 1, 2, 3)
2023

2124
XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext))
@@ -25,7 +28,8 @@ class AuthenticationStateMachineTests: XCTestCase {
2528

2629
func testAuthenticateMD5WithoutPassword() {
2730
let authContext = AuthContext(username: "test", password: nil, database: "test")
28-
var state = ConnectionStateMachine(.waitingToStartAuthentication)
31+
var state = ConnectionStateMachine()
32+
XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext)
2933
let salt: (UInt8, UInt8, UInt8, UInt8) = (0, 1, 2, 3)
3034

3135
XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext))
@@ -35,15 +39,16 @@ class AuthenticationStateMachineTests: XCTestCase {
3539

3640
func testAuthenticateOkAfterStartUpWithoutAuthChallenge() {
3741
let authContext = AuthContext(username: "test", password: "abc123", database: "test")
38-
var state = ConnectionStateMachine(.waitingToStartAuthentication)
39-
42+
var state = ConnectionStateMachine()
43+
XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext)
4044
XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext))
4145
XCTAssertEqual(state.authenticationMessageReceived(.ok), .wait)
4246
}
4347

4448
func testAuthenticationFailure() {
4549
let authContext = AuthContext(username: "test", password: "abc123", database: "test")
46-
var state = ConnectionStateMachine(.waitingToStartAuthentication)
50+
var state = ConnectionStateMachine()
51+
XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext)
4752
let salt: (UInt8, UInt8, UInt8, UInt8) = (0, 1, 2, 3)
4853

4954
XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext))
@@ -74,7 +79,8 @@ class AuthenticationStateMachineTests: XCTestCase {
7479

7580
for (message, mechanism) in unsupported {
7681
let authContext = AuthContext(username: "test", password: "abc123", database: "test")
77-
var state = ConnectionStateMachine(.waitingToStartAuthentication)
82+
var state = ConnectionStateMachine()
83+
XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext)
7884
XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext))
7985
XCTAssertEqual(state.authenticationMessageReceived(message),
8086
.closeConnectionAndCleanup(.init(action: .close, tasks: [], error: .unsupportedAuthMechanism(mechanism), closePromise: nil)))
@@ -92,7 +98,8 @@ class AuthenticationStateMachineTests: XCTestCase {
9298

9399
for message in unexpected {
94100
let authContext = AuthContext(username: "test", password: "abc123", database: "test")
95-
var state = ConnectionStateMachine(.waitingToStartAuthentication)
101+
var state = ConnectionStateMachine()
102+
XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext)
96103
XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext))
97104
XCTAssertEqual(state.authenticationMessageReceived(message),
98105
.closeConnectionAndCleanup(.init(action: .close, tasks: [], error: .unexpectedBackendMessage(.authentication(message)), closePromise: nil)))
@@ -118,7 +125,8 @@ class AuthenticationStateMachineTests: XCTestCase {
118125

119126
for message in unexpected {
120127
let authContext = AuthContext(username: "test", password: "abc123", database: "test")
121-
var state = ConnectionStateMachine(.waitingToStartAuthentication)
128+
var state = ConnectionStateMachine()
129+
XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext)
122130
XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext))
123131
XCTAssertEqual(state.authenticationMessageReceived(.md5(salt: salt)), .sendPasswordMessage(.md5(salt: salt), authContext))
124132
XCTAssertEqual(state.authenticationMessageReceived(message),

0 commit comments

Comments
 (0)