Skip to content

Commit 71b93f1

Browse files
committed
Address review comments
1 parent 7281da0 commit 71b93f1

File tree

3 files changed

+44
-57
lines changed

3 files changed

+44
-57
lines changed

Sources/PostgresNIO/Connection/PostgresConnection+CopyFrom.swift

Lines changed: 33 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,5 @@
11
/// Handle to send data for a `COPY ... FROM STDIN` query to the backend.
22
public struct PostgresCopyFromWriter: Sendable {
3-
/// The backend failed the copy data transfer, which means that no more data sent by the frontend would be processed.
4-
///
5-
/// The `PostgresCopyFromWriter` should cancel the data transfer.
6-
public struct CopyCancellationError: Error {
7-
/// The error that the backend sent us which cancelled the data transfer.
8-
///
9-
/// Note that this error is related to previous `write` calls since a `CopyCancellationError` is thrown before
10-
/// new data is written by `write`.
11-
public let underlyingError: PSQLError
12-
}
13-
143
private let channelHandler: NIOLoopBound<PostgresChannelHandler>
154
private let eventLoop: any EventLoop
165

@@ -42,9 +31,9 @@ public struct PostgresCopyFromWriter: Sendable {
4231

4332
/// Send data for a `COPY ... FROM STDIN` operation to the backend.
4433
///
45-
/// If the backend encountered an error during the data transfer and thus cannot process any more data, this throws
46-
/// a `CopyCancellationError`.
47-
public func write(_ byteBuffer: ByteBuffer) async throws {
34+
/// - Throws: If an error occurs during the write of if the backend sent an `ErrorResponse` during the copy
35+
/// operation, eg. to indicate that a **previous** `write` call had an invalid format.
36+
public func write(_ byteBuffer: ByteBuffer, isolation: isolated (any Actor)? = #isolation) async throws {
4837
// Check for cancellation. This is cheap and makes sure that we regularly check for cancellation in the
4938
// `writeData` closure. It is likely that the user would forget to do so.
5039
try Task.checkCancellation()
@@ -82,7 +71,7 @@ public struct PostgresCopyFromWriter: Sendable {
8271

8372
/// Finalize the data transfer, putting the state machine out of the copy mode and sending a `CopyDone` message to
8473
/// the backend.
85-
func done() async throws {
74+
func done(isolation: isolated (any Actor)? = #isolation) async throws {
8675
try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Void, any Error>) in
8776
if eventLoop.inEventLoop {
8877
self.channelHandler.value.sendCopyDone(continuation: continuation)
@@ -96,37 +85,43 @@ public struct PostgresCopyFromWriter: Sendable {
9685

9786
/// Finalize the data transfer, putting the state machine out of the copy mode and sending a `CopyFail` message to
9887
/// the backend.
99-
func failed(error: any Error) async throws {
88+
func failed(error: any Error, isolation: isolated (any Actor)? = #isolation) async throws {
10089
try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Void, any Error>) in
101-
// TODO: Is it OK to use string interpolation to construct an error description to be sent to the backend
102-
// here? We could also use a generic description, it doesn't really matter since we throw the user's error
103-
// in `copyFrom`.
10490
if eventLoop.inEventLoop {
105-
self.channelHandler.value.sendCopyFail(message: "\(error)", continuation: continuation)
91+
self.channelHandler.value.sendCopyFail(message: "Client failed copy", continuation: continuation)
10692
} else {
10793
eventLoop.execute {
108-
self.channelHandler.value.sendCopyFail(message: "\(error)", continuation: continuation)
94+
self.channelHandler.value.sendCopyFail(message: "Client failed copy", continuation: continuation)
10995
}
11096
}
11197
}
11298
}
11399
}
114100

115101
/// Specifies the format in which data is transferred to the backend in a COPY operation.
116-
public enum PostgresCopyFromFormat: Sendable {
102+
///
103+
/// See the Postgres documentation at https://www.postgresql.org/docs/current/sql-copy.html for the option's meanings
104+
/// and their default values.
105+
public struct PostgresCopyFromFormat: Sendable {
117106
/// Options that can be used to modify the `text` format of a COPY operation.
118107
public struct TextOptions: Sendable {
119108
/// The delimiter that separates columns in the data.
120109
///
121110
/// See the `DELIMITER` option in Postgres's `COPY` command.
122-
///
123-
/// Uses the default delimiter of the format
124111
public var delimiter: UnicodeScalar? = nil
125112

126113
public init() {}
127114
}
128115

129-
case text(TextOptions)
116+
enum Format {
117+
case text(TextOptions)
118+
}
119+
120+
var format: Format
121+
122+
public static func text(_ options: TextOptions) -> PostgresCopyFromFormat {
123+
return PostgresCopyFromFormat(format: .text(options))
124+
}
130125
}
131126

132127
/// Create a `COPY ... FROM STDIN` query based on the given parameters.
@@ -138,14 +133,17 @@ private func buildCopyFromQuery(
138133
columns: [StaticString] = [],
139134
format: PostgresCopyFromFormat
140135
) -> PostgresQuery {
141-
// TODO: Should we put the table and column names in quotes to make them case-sensitive?
142-
var query = "COPY \(table)"
136+
var query = """
137+
COPY "\(table)"
138+
"""
143139
if !columns.isEmpty {
144-
query += "(" + columns.map(\.description).joined(separator: ",") + ")"
140+
query += "("
141+
query += columns.map { #"""# + $0.description + #"""# }.joined(separator: ",")
142+
query += ")"
145143
}
146144
query += " FROM STDIN"
147145
var queryOptions: [String] = []
148-
switch format {
146+
switch format.format {
149147
case .text(let options):
150148
queryOptions.append("FORMAT text")
151149
if let delimiter = options.delimiter {
@@ -179,6 +177,7 @@ extension PostgresConnection {
179177
columns: [StaticString] = [],
180178
format: PostgresCopyFromFormat = .text(.init()),
181179
logger: Logger,
180+
isolation: isolated (any Actor)? = #isolation,
182181
file: String = #fileID,
183182
line: Int = #line,
184183
writeData: @escaping @Sendable (PostgresCopyFromWriter) async throws -> Void
@@ -205,22 +204,13 @@ extension PostgresConnection {
205204
// threw instead of the one that got relayed back, so it's better to ignore the error here.
206205
// - The backend sent us an `ErrorResponse` during the copy, eg. because of an invalid format. This puts
207206
// the `ExtendedQueryStateMachine` in the error state. Trying to send a `CopyFail` will throw but trigger
208-
// a `Sync` that takes the backend out of copy mode. If `writeData` threw the `CopyCancellationError`
209-
// from the `PostgresCopyFromWriter.write` call, `writer.failed` will throw with the same error, so it
210-
// doesn't matter that we ignore the error here. If the user threw some other error, it's better to honor
211-
// the user's error.
207+
// a `Sync` that takes the backend out of copy mode. If `writeData` threw the error from from the
208+
// `PostgresCopyFromWriter.write` call, `writer.failed` will throw with the same error, so it doesn't
209+
// matter that we ignore the error here. If the user threw some other error, it's better to honor the
210+
// user's error.
212211
try? await writer.failed(error: error)
213212

214-
if let error = error as? PostgresCopyFromWriter.CopyCancellationError {
215-
// If we receive a `CopyCancellationError` that is with almost certain likelihood because
216-
// `PostgresCopyFromWriter.write` threw it - otherwise the user must have saved a previous
217-
// `PostgresCopyFromWriter` error, which is very unlikely.
218-
// Throw the underlying error because that contains the error message that was sent by the backend and
219-
// is most actionable by the user.
220-
throw error.underlyingError
221-
} else {
222-
throw error
223-
}
213+
throw error
224214
}
225215

226216
// `writer.done` may fail, eg. because the backend sends an error response after receiving `CopyDone` or during
@@ -230,5 +220,4 @@ extension PostgresConnection {
230220
// above.
231221
try await writer.done()
232222
}
233-
234223
}

Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -427,7 +427,7 @@ struct ExtendedQueryStateMachine {
427427
if case .error(let error) = self.state {
428428
// The backend sent us an ErrorResponse during the copy operation. Indicate to the client that it should
429429
// abort the data transfer.
430-
promise.fail(PostgresCopyFromWriter.CopyCancellationError(underlyingError: error))
430+
promise.fail(error)
431431
return
432432
}
433433
guard case .copyingData(.readyToSend) = self.state else {

Tests/PostgresNIOTests/New/PostgresConnectionTests.swift

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -660,7 +660,7 @@ import Synchronization
660660
try await expectCopyFrom { writer in
661661
try await writer.write(ByteBuffer(staticString: "1\tAlice\n"))
662662
} validateCopyRequest: { copyRequest in
663-
#expect(copyRequest.parse.query == "COPY copy_table(id,name) FROM STDIN WITH (FORMAT text)")
663+
#expect(copyRequest.parse.query == #"COPY "copy_table"("id","name") FROM STDIN WITH (FORMAT text)"#)
664664
#expect(copyRequest.bind.parameters == [])
665665
} mockBackend: { channel, _ in
666666
let data = try await channel.waitForCopyData()
@@ -677,7 +677,7 @@ import Synchronization
677677
try await expectCopyFrom(format: .text(options)) { writer in
678678
try await writer.write(ByteBuffer(staticString: "1,Alice\n"))
679679
} validateCopyRequest: { copyRequest in
680-
#expect(copyRequest.parse.query == #"COPY copy_table(id,name) FROM STDIN WITH (FORMAT text,DELIMITER U&'\002c')"#)
680+
#expect(copyRequest.parse.query == #"COPY "copy_table"("id","name") FROM STDIN WITH (FORMAT text,DELIMITER U&'\002c')"#)
681681
#expect(copyRequest.bind.parameters == [])
682682
} mockBackend: { channel, _ in
683683
let data = try await channel.waitForCopyData()
@@ -688,19 +688,17 @@ import Synchronization
688688
}
689689

690690
@Test func testCopyFromWriterFails() async throws {
691-
struct MyError: Error, CustomStringConvertible {
692-
var description: String { "My error" }
693-
}
691+
struct MyError: Error {}
694692

695693
try await expectCopyFrom { writer in
696694
throw MyError()
697695
} validateCopyFromError: { error in
698696
#expect(error is MyError, "Expected error of type MyError, got \(error)")
699697
} mockBackend: { channel, _ in
700698
let data = try await channel.waitForCopyData()
701-
#expect(data.result == .failed(message: "My error"))
699+
#expect(data.result == .failed(message: "Client failed copy"))
702700
try await channel.writeInbound(PostgresBackendMessage.error(.init(fields: [
703-
.message: "COPY from stdin failed: My error",
701+
.message: "COPY from stdin failed: Client failed copy",
704702
.sqlState : "57014" // query_canceled
705703
])))
706704
}
@@ -783,7 +781,7 @@ import Synchronization
783781
try await writer.write(ByteBuffer(staticString: "2\tBob\n"))
784782
Issue.record("Expected error to be thrown")
785783
} catch {
786-
#expect(error is PostgresCopyFromWriter.CopyCancellationError, "Received unexpected error: \(error)")
784+
#expect((error as? PSQLError)?.serverInfo?[.sqlState] == "22P02")
787785
throw error
788786
}
789787
} validateCopyFromError: { error in
@@ -800,7 +798,7 @@ import Synchronization
800798
}
801799
}
802800

803-
@Test func testCopyFromCallerDoesNotRethrowCopyCancellationError() async throws {
801+
@Test func testCopyFromCallerDoesNotRethrowFromWriteCall() async throws {
804802
struct MyError: Error, CustomStringConvertible {
805803
var description: String { "My error" }
806804
}
@@ -816,7 +814,7 @@ import Synchronization
816814
try await writer.write(ByteBuffer(staticString: "2\tBob\n"))
817815
Issue.record("Expected error to be thrown")
818816
} catch {
819-
#expect(error is PostgresCopyFromWriter.CopyCancellationError, "Received unexpected error: \(error)")
817+
#expect((error as? PSQLError)?.serverInfo?[.sqlState] == "22P02")
820818
throw MyError()
821819
}
822820
} validateCopyFromError: { error in
@@ -903,10 +901,10 @@ import Synchronization
903901
cancelCopy()
904902

905903
let data = try await channel.waitForCopyData()
906-
#expect(data.result == .failed(message: "CancellationError()"))
904+
#expect(data.result == .failed(message: "Client failed copy"))
907905

908906
try await channel.writeInbound(PostgresBackendMessage.error(.init(fields: [
909-
.message: "COPY from stdin failed: CancellationError()",
907+
.message: "COPY from stdin failed: Client failed copy",
910908
.sqlState : "57014" // query_canceled
911909
])))
912910
}

0 commit comments

Comments
 (0)