Skip to content

Commit bfd17ae

Browse files
authored
Increase bind parameters limit (#298)
1 parent 382b0e1 commit bfd17ae

File tree

7 files changed

+65
-12
lines changed

7 files changed

+65
-12
lines changed

Sources/PostgresNIO/Connection/PostgresConnection.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ public final class PostgresConnection {
309309
private func queryStream(_ query: PostgresQuery, logger: Logger) -> EventLoopFuture<PSQLRowStream> {
310310
var logger = logger
311311
logger[postgresMetadataKey: .connectionID] = "\(self.id)"
312-
guard query.binds.count <= Int(Int16.max) else {
312+
guard query.binds.count <= Int(UInt16.max) else {
313313
return self.channel.eventLoop.makeFailedFuture(PSQLError.tooManyParameters)
314314
}
315315

@@ -341,7 +341,7 @@ public final class PostgresConnection {
341341
}
342342

343343
func execute(_ executeStatement: PSQLExecuteStatement, logger: Logger) -> EventLoopFuture<PSQLRowStream> {
344-
guard executeStatement.binds.count <= Int(Int16.max) else {
344+
guard executeStatement.binds.count <= Int(UInt16.max) else {
345345
return self.channel.eventLoop.makeFailedFuture(PSQLError.tooManyParameters)
346346
}
347347
let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self)
@@ -498,7 +498,7 @@ extension PostgresConnection {
498498
var logger = logger
499499
logger[postgresMetadataKey: .connectionID] = "\(self.id)"
500500

501-
guard query.binds.count <= Int(Int16.max) else {
501+
guard query.binds.count <= Int(UInt16.max) else {
502502
throw PSQLError.tooManyParameters
503503
}
504504
let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self)

Sources/PostgresNIO/New/Messages/Bind.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,14 @@ extension PostgresFrontendMessage {
2020
// zero to indicate that there are no parameters or that the parameters all use the
2121
// default format (text); or one, in which case the specified format code is applied
2222
// to all parameters; or it can equal the actual number of parameters.
23-
buffer.writeInteger(Int16(self.bind.count))
23+
buffer.writeInteger(UInt16(self.bind.count))
2424

2525
// The parameter format codes. Each must presently be zero (text) or one (binary).
2626
self.bind.metadata.forEach {
2727
buffer.writeInteger($0.format.rawValue)
2828
}
2929

30-
buffer.writeInteger(Int16(self.bind.count))
30+
buffer.writeInteger(UInt16(self.bind.count))
3131

3232
var parametersCopy = self.bind.bytes
3333
buffer.writeBuffer(&parametersCopy)

Sources/PostgresNIO/New/Messages/ParameterDescription.swift

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,7 @@ extension PostgresBackendMessage {
77
var dataTypes: [PostgresDataType]
88

99
static func decode(from buffer: inout ByteBuffer) throws -> Self {
10-
let parameterCount = try buffer.throwingReadInteger(as: Int16.self)
11-
guard parameterCount >= 0 else {
12-
throw PSQLPartialDecodingError.integerMustBePositiveOrNull(parameterCount)
13-
}
10+
let parameterCount = try buffer.throwingReadInteger(as: UInt16.self)
1411

1512
var result = [PostgresDataType]()
1613
result.reserveCapacity(Int(parameterCount))

Sources/PostgresNIO/New/Messages/Parse.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ extension PostgresFrontendMessage {
1515
func encode(into buffer: inout ByteBuffer) {
1616
buffer.writeNullTerminatedString(self.preparedStatementName)
1717
buffer.writeNullTerminatedString(self.query)
18-
buffer.writeInteger(Int16(self.parameters.count))
18+
buffer.writeInteger(UInt16(self.parameters.count))
1919

2020
self.parameters.forEach { dataType in
2121
buffer.writeInteger(dataType.rawValue)

Tests/IntegrationTests/PSQLIntegrationTests.swift

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,4 +329,60 @@ final class IntegrationTests: XCTestCase {
329329
XCTAssertEqual(obj?.bar, 2)
330330
}
331331
}
332+
333+
#if swift(>=5.5.2)
334+
func testBindMaximumParameters() async throws {
335+
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
336+
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
337+
let eventLoop = eventLoopGroup.next()
338+
339+
try await withTestConnection(on: eventLoop) { connection in
340+
// Max binds limit is UInt16.max which is 65535 which is 3 * 5 * 17 * 257
341+
// Max columns limit is 1664, so we will only make 5 * 257 columns which is less
342+
// Then we will insert 3 * 17 rows
343+
// In the insertion, there will be a total of 3 * 17 * 5 * 257 == UInt16.max bindings
344+
// If the test is successful, it means Postgres supports UInt16.max bindings
345+
let columnsCount = 5 * 257
346+
let rowsCount = 3 * 17
347+
348+
let createQuery = PostgresQuery(
349+
unsafeSQL: """
350+
CREATE TABLE table1 (
351+
\((0..<columnsCount).map({ #""int\#($0)" int NOT NULL"# }).joined(separator: ", "))
352+
);
353+
"""
354+
)
355+
try await connection.query(createQuery, logger: .psqlTest)
356+
357+
var binds = PostgresBindings(capacity: Int(UInt16.max))
358+
for _ in (0..<rowsCount) {
359+
for num in (0..<columnsCount) {
360+
try binds.append(num, context: .default)
361+
}
362+
}
363+
XCTAssertEqual(binds.count, Int(UInt16.max))
364+
365+
let insertionValues = (0..<rowsCount).map { rowIndex in
366+
let indices = (0..<columnsCount).map { columnIndex -> String in
367+
"$\(rowIndex * columnsCount + columnIndex + 1)"
368+
}
369+
return "(\(indices.joined(separator: ", ")))"
370+
}.joined(separator: ", ")
371+
let insertionQuery = PostgresQuery(
372+
unsafeSQL: "INSERT INTO table1 VALUES \(insertionValues)",
373+
binds: binds
374+
)
375+
try await connection.query(insertionQuery, logger: .psqlTest)
376+
377+
let countQuery = PostgresQuery(unsafeSQL: "SELECT COUNT(*) FROM table1")
378+
let countRows = try await connection.query(countQuery, logger: .psqlTest)
379+
var countIterator = countRows.makeAsyncIterator()
380+
let insertedRowsCount = try await countIterator.next()?.decode(Int.self, context: .default)
381+
XCTAssertEqual(rowsCount, insertedRowsCount)
382+
383+
let dropQuery = PostgresQuery(unsafeSQL: "DROP TABLE table1")
384+
try await connection.query(dropQuery, logger: .psqlTest)
385+
}
386+
}
387+
#endif
332388
}

Tests/IntegrationTests/PostgresNIOTests.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1061,7 +1061,7 @@ final class PostgresNIOTests: XCTestCase {
10611061
var conn: PostgresConnection?
10621062
XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait())
10631063
defer { XCTAssertNoThrow( try conn?.close().wait() ) }
1064-
let binds = [PostgresData].init(repeating: .null, count: Int(Int16.max) + 1)
1064+
let binds = [PostgresData].init(repeating: .null, count: Int(UInt16.max) + 1)
10651065
XCTAssertThrowsError(try conn?.query("SELECT version()", binds).wait()) { error in
10661066
guard case .tooManyParameters = (error as? PSQLError)?.base else {
10671067
return XCTFail("Unexpected error: \(error)")

Tests/PostgresNIOTests/New/Messages/ParseTests.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ class ParseTests: XCTestCase {
2626
XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), Int32(length - 1))
2727
XCTAssertEqual(byteBuffer.readNullTerminatedString(), parse.preparedStatementName)
2828
XCTAssertEqual(byteBuffer.readNullTerminatedString(), parse.query)
29-
XCTAssertEqual(byteBuffer.readInteger(as: Int16.self), Int16(parse.parameters.count))
29+
XCTAssertEqual(byteBuffer.readInteger(as: UInt16.self), UInt16(parse.parameters.count))
3030
XCTAssertEqual(byteBuffer.readInteger(as: UInt32.self), PostgresDataType.bool.rawValue)
3131
XCTAssertEqual(byteBuffer.readInteger(as: UInt32.self), PostgresDataType.int8.rawValue)
3232
XCTAssertEqual(byteBuffer.readInteger(as: UInt32.self), PostgresDataType.bytea.rawValue)

0 commit comments

Comments
 (0)