Skip to content

Commit 2998b1c

Browse files
authored
New PostgresConnection connect API (#245)
1 parent 3212037 commit 2998b1c

File tree

7 files changed

+218
-96
lines changed

7 files changed

+218
-96
lines changed

Sources/PostgresNIO/Connection/PostgresConnection.swift

Lines changed: 155 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -4,23 +4,38 @@ import NIOSSL
44
import Logging
55
import NIOPosix
66

7+
/// A Postgres connection. Use it to run queries against a Postgres server.
78
public final class PostgresConnection {
8-
typealias ID = Int
9-
10-
struct Configuration {
11-
struct Authentication {
12-
var username: String
13-
var database: String? = nil
14-
var password: String? = nil
15-
16-
init(username: String, password: String?, database: String?) {
9+
/// A Postgres connection ID
10+
public typealias ID = Int
11+
12+
/// A configuration object for a connection
13+
public struct Configuration {
14+
/// A structure to configure the connection's authentication properties
15+
public struct Authentication {
16+
/// The username to connect with.
17+
///
18+
/// - Default: postgres
19+
public var username: String
20+
21+
/// The database to open on the server
22+
///
23+
/// - Default: `nil`
24+
public var database: Optional<String>
25+
26+
/// The database user's password.
27+
///
28+
/// - Default: `nil`
29+
public var password: Optional<String>
30+
31+
public init(username: String, database: String?, password: String?) {
1732
self.username = username
1833
self.database = database
1934
self.password = password
2035
}
2136
}
2237

23-
struct TLS {
38+
public struct TLS {
2439
enum Base {
2540
case disable
2641
case prefer(NIOSSLContext)
@@ -33,44 +48,50 @@ public final class PostgresConnection {
3348
self.base = base
3449
}
3550

36-
static var disable: Self = Self.init(.disable)
51+
/// Do not try to create a TLS connection to the server.
52+
public static var disable: Self = Self.init(.disable)
3753

38-
static func prefer(_ sslContext: NIOSSLContext) -> Self {
54+
/// Try to create a TLS connection to the server. If the server supports TLS, create a TLS connection.
55+
/// If the server does not support TLS, create an insecure connection.
56+
public static func prefer(_ sslContext: NIOSSLContext) -> Self {
3957
self.init(.prefer(sslContext))
4058
}
4159

42-
static func require(_ sslContext: NIOSSLContext) -> Self {
60+
/// Try to create a TLS connection to the server. If the server supports TLS, create a TLS connection.
61+
/// If the server does not support TLS, fail the connection creation.
62+
public static func require(_ sslContext: NIOSSLContext) -> Self {
4363
self.init(.require(sslContext))
4464
}
4565
}
4666

47-
enum Connection {
48-
case unresolved(host: String, port: Int)
49-
case resolved(address: SocketAddress, serverName: String?)
67+
public struct Connection {
68+
/// The server to connect to
69+
///
70+
/// - Default: localhost
71+
public var host: String
72+
73+
/// The server port to connect to.
74+
///
75+
/// - Default: 5432
76+
public var port: Int
77+
78+
public init(host: String, port: Int = 5432) {
79+
self.host = host
80+
self.port = port
81+
}
5082
}
5183

52-
var connection: Connection
84+
public var connection: Connection
5385

5486
/// The authentication properties to send to the Postgres server during startup auth handshake
55-
var authentication: Authentication?
87+
public var authentication: Authentication
5688

57-
var tls: TLS
89+
public var tls: TLS
5890

59-
init(host: String,
60-
port: Int = 5432,
61-
username: String,
62-
database: String? = nil,
63-
password: String? = nil,
64-
tls: TLS = .disable
65-
) {
66-
self.connection = .unresolved(host: host, port: port)
67-
self.authentication = Authentication(username: username, password: password, database: database)
68-
self.tls = tls
69-
}
70-
71-
init(connection: Connection,
72-
authentication: Authentication?,
73-
tls: TLS
91+
public init(
92+
connection: Connection,
93+
authentication: Authentication,
94+
tls: TLS
7495
) {
7596
self.connection = connection
7697
self.authentication = authentication
@@ -129,7 +150,7 @@ public final class PostgresConnection {
129150
assert(self.isClosed, "PostgresConnection deinitialized before being closed.")
130151
}
131152

132-
func start(configuration: Configuration) -> EventLoopFuture<Void> {
153+
func start(configuration: InternalConfiguration) -> EventLoopFuture<Void> {
133154
// 1. configure handlers
134155

135156
var configureSSLCallback: ((Channel) throws -> ())? = nil
@@ -186,9 +207,32 @@ public final class PostgresConnection {
186207
}
187208
}
188209

210+
/// Create a new connection to a Postgres server
211+
///
212+
/// - Parameters:
213+
/// - eventLoop: The `EventLoop` the request shall be created on
214+
/// - configuration: A ``Configuration`` that shall be used for the connection
215+
/// - connectionID: An `Int` id, used for metadata logging
216+
/// - logger: A logger to log background events into
217+
/// - Returns: A SwiftNIO `EventLoopFuture` that will provide a ``PostgresConnection``
218+
/// at a later point in time.
219+
public static func connect(
220+
on eventLoop: EventLoop,
221+
configuration: PostgresConnection.Configuration,
222+
id connectionID: ID,
223+
logger: Logger
224+
) -> EventLoopFuture<PostgresConnection> {
225+
self.connect(
226+
connectionID: connectionID,
227+
configuration: .init(configuration),
228+
logger: logger,
229+
on: eventLoop
230+
)
231+
}
232+
189233
static func connect(
190234
connectionID: ID,
191-
configuration: PostgresConnection.Configuration,
235+
configuration: PostgresConnection.InternalConfiguration,
192236
logger: Logger,
193237
on eventLoop: EventLoop
194238
) -> EventLoopFuture<PostgresConnection> {
@@ -286,6 +330,9 @@ public final class PostgresConnection {
286330
}
287331

288332

333+
/// Closes the connection to the server.
334+
///
335+
/// - Returns: An EventLoopFuture that is succeeded once the connection is closed.
289336
public func close() -> EventLoopFuture<Void> {
290337
guard !self.isClosed else {
291338
return self.eventLoop.makeSucceededFuture(())
@@ -301,6 +348,10 @@ public final class PostgresConnection {
301348
extension PostgresConnection {
302349
static let idGenerator = NIOAtomic.makeAtomic(value: 0)
303350

351+
@available(*, deprecated,
352+
message: "Use the new connect method that allows you to connect and authenticate in a single step",
353+
renamed: "connect(on:configuration:id:logger:)"
354+
)
304355
public static func connect(
305356
to socketAddress: SocketAddress,
306357
tlsConfiguration: TLSConfiguration? = nil,
@@ -319,7 +370,7 @@ extension PostgresConnection {
319370
}
320371

321372
return tlsFuture.flatMap { tls in
322-
let configuration = PostgresConnection.Configuration(
373+
let configuration = PostgresConnection.InternalConfiguration(
323374
connection: .resolved(address: socketAddress, serverName: serverHostname),
324375
authentication: nil,
325376
tls: tls
@@ -336,6 +387,10 @@ extension PostgresConnection {
336387
}
337388
}
338389

390+
@available(*, deprecated,
391+
message: "Use the new connect method that allows you to connect and authenticate in a single step",
392+
renamed: "connect(on:configuration:id:logger:)"
393+
)
339394
public func authenticate(
340395
username: String,
341396
database: String? = nil,
@@ -359,28 +414,50 @@ extension PostgresConnection {
359414

360415
#if swift(>=5.5) && canImport(_Concurrency)
361416
extension PostgresConnection {
362-
func close() async throws {
417+
418+
/// Creates a new connection to a Postgres server.
419+
///
420+
/// - Parameters:
421+
/// - eventLoop: The `EventLoop` the request shall be created on
422+
/// - configuration: A ``Configuration`` that shall be used for the connection
423+
/// - connectionID: An `Int` id, used for metadata logging
424+
/// - logger: A logger to log background events into
425+
/// - Returns: An established ``PostgresConnection`` asynchronously that can be used to run queries.
426+
public static func connect(
427+
on eventLoop: EventLoop,
428+
configuration: PostgresConnection.Configuration,
429+
id connectionID: ID,
430+
logger: Logger
431+
) async throws -> PostgresConnection {
432+
try await self.connect(
433+
connectionID: connectionID,
434+
configuration: .init(configuration),
435+
logger: logger,
436+
on: eventLoop
437+
).get()
438+
}
439+
440+
/// Closes the connection to the server.
441+
public func close() async throws {
363442
try await self.close().get()
364443
}
365444

366445
func query(_ query: PostgresQuery, logger: Logger, file: String = #file, line: UInt = #line) async throws -> PostgresRowSequence {
367446
var logger = logger
368447
logger[postgresMetadataKey: .connectionID] = "\(self.id)"
369448

370-
do {
371-
guard query.binds.count <= Int(Int16.max) else {
372-
throw PSQLError.tooManyParameters
373-
}
374-
let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self)
375-
let context = ExtendedQueryContext(
376-
query: query,
377-
logger: logger,
378-
promise: promise)
449+
guard query.binds.count <= Int(Int16.max) else {
450+
throw PSQLError.tooManyParameters
451+
}
452+
let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self)
453+
let context = ExtendedQueryContext(
454+
query: query,
455+
logger: logger,
456+
promise: promise)
379457

380-
self.channel.write(PSQLTask.extendedQuery(context), promise: nil)
458+
self.channel.write(PSQLTask.extendedQuery(context), promise: nil)
381459

382-
return try await promise.futureResult.map({ $0.asyncSequence() }).get()
383-
}
460+
return try await promise.futureResult.map({ $0.asyncSequence() }).get()
384461
}
385462
}
386463
#endif
@@ -539,7 +616,7 @@ enum CloseTarget {
539616
case portal(String)
540617
}
541618

542-
extension PostgresConnection.Configuration {
619+
extension PostgresConnection.InternalConfiguration {
543620
var sslServerHostname: String? {
544621
switch self.connection {
545622
case .unresolved(let host, _):
@@ -567,6 +644,33 @@ private extension String {
567644
}
568645
}
569646

647+
extension PostgresConnection {
648+
/// A configuration object to bring the new ``PostgresConnection.Configuration`` together with
649+
/// the deprecated configuration.
650+
///
651+
/// TODO: Drop with next major release
652+
struct InternalConfiguration {
653+
enum Connection {
654+
case unresolved(host: String, port: Int)
655+
case resolved(address: SocketAddress, serverName: String?)
656+
}
657+
658+
var connection: Connection
659+
660+
var authentication: Configuration.Authentication?
661+
662+
var tls: Configuration.TLS
663+
}
664+
}
665+
666+
extension PostgresConnection.InternalConfiguration {
667+
init(_ config: PostgresConnection.Configuration) {
668+
self.authentication = config.authentication
669+
self.connection = .unresolved(host: config.connection.host, port: config.connection.port)
670+
self.tls = config.tls
671+
}
672+
}
673+
570674
#if swift(>=5.6)
571675
extension PostgresConnection: @unchecked Sendable {}
572676
#endif

Sources/PostgresNIO/New/PSQLChannelHandler.swift

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,13 @@ final class PSQLChannelHandler: ChannelDuplexHandler {
2626
private var rowStream: PSQLRowStream?
2727
private var decoder: NIOSingleStepByteToMessageProcessor<PSQLBackendMessageDecoder>
2828
private var encoder: BufferedMessageEncoder!
29-
private let configuration: PostgresConnection.Configuration
29+
private let configuration: PostgresConnection.InternalConfiguration
3030
private let configureSSLCallback: ((Channel) throws -> Void)?
3131

3232
/// this delegate should only be accessed on the connections `EventLoop`
3333
weak var notificationDelegate: PSQLChannelHandlerNotificationDelegate?
3434

35-
init(configuration: PostgresConnection.Configuration,
35+
init(configuration: PostgresConnection.InternalConfiguration,
3636
logger: Logger,
3737
configureSSLCallback: ((Channel) throws -> Void)?)
3838
{
@@ -45,7 +45,7 @@ final class PSQLChannelHandler: ChannelDuplexHandler {
4545

4646
#if DEBUG
4747
/// for testing purposes only
48-
init(configuration: PostgresConnection.Configuration,
48+
init(configuration: PostgresConnection.InternalConfiguration,
4949
state: ConnectionStateMachine = .init(.initialized),
5050
logger: Logger = .psqlNoOpLogger,
5151
configureSSLCallback: ((Channel) throws -> Void)?)
@@ -575,8 +575,8 @@ private extension Insecure.MD5.Digest {
575575
}
576576

577577
extension ConnectionStateMachine.TLSConfiguration {
578-
fileprivate init(_ connection: PostgresConnection.Configuration.TLS) {
579-
switch connection.base {
578+
fileprivate init(_ tls: PostgresConnection.Configuration.TLS) {
579+
switch tls.base {
580580
case .disable:
581581
self = .disable
582582
case .require:
@@ -589,7 +589,7 @@ extension ConnectionStateMachine.TLSConfiguration {
589589

590590
extension PSQLChannelHandler {
591591
convenience init(
592-
configuration: PostgresConnection.Configuration,
592+
configuration: PostgresConnection.InternalConfiguration,
593593
configureSSLCallback: ((Channel) throws -> Void)?)
594594
{
595595
self.init(

Tests/IntegrationTests/PSQLIntegrationTests.swift

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,17 @@ final class IntegrationTests: XCTestCase {
2323
try XCTSkipIf(env("POSTGRES_HOST_AUTH_METHOD") == "trust")
2424

2525
let config = PostgresConnection.Configuration(
26-
host: env("POSTGRES_HOSTNAME") ?? "localhost",
27-
port: 5432,
28-
username: env("POSTGRES_USER") ?? "test_username",
29-
database: env("POSTGRES_DB") ?? "test_database",
30-
password: "wrong_password",
31-
tls: .disable)
26+
connection: .init(
27+
host: env("POSTGRES_HOSTNAME") ?? "localhost",
28+
port: 5432
29+
),
30+
authentication: .init(
31+
username: env("POSTGRES_USER") ?? "test_username",
32+
database: env("POSTGRES_DB") ?? "test_database",
33+
password: "wrong_password"
34+
),
35+
tls: .disable
36+
)
3237

3338
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
3439
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
@@ -37,7 +42,7 @@ final class IntegrationTests: XCTestCase {
3742
logger.logLevel = .info
3843

3944
var connection: PostgresConnection?
40-
XCTAssertThrowsError(connection = try PostgresConnection.connect(connectionID: 1, configuration: config, logger: logger, on: eventLoopGroup.next()).wait()) {
45+
XCTAssertThrowsError(connection = try PostgresConnection.connect(on: eventLoopGroup.next(), configuration: config, id: 1, logger: logger).wait()) {
4146
XCTAssertTrue($0 is PSQLError)
4247
}
4348

0 commit comments

Comments
 (0)