Skip to content

Commit 9a02d74

Browse files
authored
Move PostgresFrontendMessage to tests (#399)
1 parent 12584c6 commit 9a02d74

File tree

4 files changed

+71
-38
lines changed

4 files changed

+71
-38
lines changed

Sources/PostgresNIO/New/Extensions/ByteBuffer+PSQL.swift

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,6 @@ import NIOCore
22

33
internal extension ByteBuffer {
44

5-
mutating func psqlWriteBackendMessageID(_ messageID: PostgresBackendMessage.ID) {
6-
self.writeInteger(messageID.rawValue)
7-
}
8-
9-
mutating func psqlWriteFrontendMessageID(_ messageID: PostgresFrontendMessage.ID) {
10-
self.writeInteger(messageID.rawValue)
11-
}
12-
135
@usableFromInline
146
mutating func psqlReadFloat() -> Float? {
157
return self.readInteger(as: UInt32.self).map { Float(bitPattern: $0) }

Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift

Lines changed: 66 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,18 @@
11
import NIOCore
22

33
struct PostgresFrontendMessageEncoder {
4+
5+
/// The SSL request code. The value is chosen to contain 1234 in the most significant 16 bits,
6+
/// and 5679 in the least significant 16 bits.
7+
static let sslRequestCode: Int32 = 80877103
8+
9+
/// The cancel request code. The value is chosen to contain 1234 in the most significant 16 bits,
10+
/// and 5678 in the least significant 16 bits. (To avoid confusion, this code must not be the same
11+
/// as any protocol version number.)
12+
static let cancelRequestCode: Int32 = 80877102
13+
14+
static let startupVersionThree: Int32 = 0x00_03_00_00
15+
416
private enum State {
517
case flushed
618
case writable
@@ -15,8 +27,8 @@ struct PostgresFrontendMessageEncoder {
1527

1628
mutating func startup(user: String, database: String?) {
1729
self.clearIfNeeded()
18-
self.encodeLengthPrefixed { buffer in
19-
buffer.writeInteger(PostgresFrontendMessage.Startup.versionThree)
30+
self.buffer.psqlLengthPrefixed { buffer in
31+
buffer.writeInteger(Self.startupVersionThree)
2032
buffer.writeNullTerminatedString("user")
2133
buffer.writeNullTerminatedString(user)
2234

@@ -31,8 +43,7 @@ struct PostgresFrontendMessageEncoder {
3143

3244
mutating func bind(portalName: String, preparedStatementName: String, bind: PostgresBindings) {
3345
self.clearIfNeeded()
34-
self.buffer.psqlWriteFrontendMessageID(.bind)
35-
self.encodeLengthPrefixed { buffer in
46+
self.buffer.psqlLengthPrefixed(id: .bind) { buffer in
3647
buffer.writeNullTerminatedString(portalName)
3748
buffer.writeNullTerminatedString(preparedStatementName)
3849

@@ -65,45 +76,45 @@ struct PostgresFrontendMessageEncoder {
6576

6677
mutating func cancel(processID: Int32, secretKey: Int32) {
6778
self.clearIfNeeded()
68-
self.buffer.writeMultipleIntegers(UInt32(16), PostgresFrontendMessage.Cancel.requestCode, processID, secretKey)
79+
self.buffer.writeMultipleIntegers(UInt32(16), Self.cancelRequestCode, processID, secretKey)
6980
}
7081

7182
mutating func closePreparedStatement(_ preparedStatement: String) {
7283
self.clearIfNeeded()
73-
self.buffer.writeMultipleIntegers(PostgresFrontendMessage.ID.close.rawValue, UInt32(6 + preparedStatement.utf8.count), UInt8(ascii: "S"))
84+
self.buffer.psqlWriteMultipleIntegers(id: .close, length: UInt32(2 + preparedStatement.utf8.count), UInt8(ascii: "S"))
7485
self.buffer.writeNullTerminatedString(preparedStatement)
7586
}
7687

7788
mutating func closePortal(_ portal: String) {
7889
self.clearIfNeeded()
79-
self.buffer.writeMultipleIntegers(PostgresFrontendMessage.ID.close.rawValue, UInt32(6 + portal.utf8.count), UInt8(ascii: "P"))
90+
self.buffer.psqlWriteMultipleIntegers(id: .close, length: UInt32(2 + portal.utf8.count), UInt8(ascii: "P"))
8091
self.buffer.writeNullTerminatedString(portal)
8192
}
8293

8394
mutating func describePreparedStatement(_ preparedStatement: String) {
8495
self.clearIfNeeded()
85-
self.buffer.writeMultipleIntegers(PostgresFrontendMessage.ID.describe.rawValue, UInt32(6 + preparedStatement.utf8.count), UInt8(ascii: "S"))
96+
self.buffer.psqlWriteMultipleIntegers(id: .describe, length: UInt32(2 + preparedStatement.utf8.count), UInt8(ascii: "S"))
8697
self.buffer.writeNullTerminatedString(preparedStatement)
8798
}
8899

89100
mutating func describePortal(_ portal: String) {
90101
self.clearIfNeeded()
91-
self.buffer.writeMultipleIntegers(PostgresFrontendMessage.ID.describe.rawValue, UInt32(6 + portal.utf8.count), UInt8(ascii: "P"))
102+
self.buffer.psqlWriteMultipleIntegers(id: .describe, length: UInt32(2 + portal.utf8.count), UInt8(ascii: "P"))
92103
self.buffer.writeNullTerminatedString(portal)
93104
}
94105

95106
mutating func execute(portalName: String, maxNumberOfRows: Int32 = 0) {
96107
self.clearIfNeeded()
97-
self.buffer.writeMultipleIntegers(PostgresFrontendMessage.ID.execute.rawValue, UInt32(9 + portalName.utf8.count))
108+
self.buffer.psqlWriteMultipleIntegers(id: .execute, length: UInt32(5 + portalName.utf8.count))
98109
self.buffer.writeNullTerminatedString(portalName)
99110
self.buffer.writeInteger(maxNumberOfRows)
100111
}
101112

102113
mutating func parse<Parameters: Collection>(preparedStatementName: String, query: String, parameters: Parameters) where Parameters.Element == PostgresDataType {
103114
self.clearIfNeeded()
104-
self.buffer.writeMultipleIntegers(
105-
PostgresFrontendMessage.ID.parse.rawValue,
106-
UInt32(4 + preparedStatementName.utf8.count + 1 + query.utf8.count + 1 + 2 + MemoryLayout<PostgresDataType>.size * parameters.count)
115+
self.buffer.psqlWriteMultipleIntegers(
116+
id: .parse,
117+
length: UInt32(preparedStatementName.utf8.count + 1 + query.utf8.count + 1 + 2 + MemoryLayout<PostgresDataType>.size * parameters.count)
107118
)
108119
self.buffer.writeNullTerminatedString(preparedStatementName)
109120
self.buffer.writeNullTerminatedString(query)
@@ -116,28 +127,25 @@ struct PostgresFrontendMessageEncoder {
116127

117128
mutating func password<Bytes: Collection>(_ bytes: Bytes) where Bytes.Element == UInt8 {
118129
self.clearIfNeeded()
119-
self.buffer.writeMultipleIntegers(PostgresFrontendMessage.ID.password.rawValue, UInt32(5 + bytes.count))
130+
self.buffer.psqlWriteMultipleIntegers(id: .password, length: UInt32(bytes.count) + 1)
120131
self.buffer.writeBytes(bytes)
121132
self.buffer.writeInteger(UInt8(0))
122133
}
123134

124135
mutating func flush() {
125136
self.clearIfNeeded()
126-
self.buffer.writeMultipleIntegers(PostgresFrontendMessage.ID.flush.rawValue, UInt32(4))
137+
self.buffer.psqlWriteMultipleIntegers(id: .flush, length: 0)
127138
}
128139

129140
mutating func saslResponse<Bytes: Collection>(_ bytes: Bytes) where Bytes.Element == UInt8 {
130141
self.clearIfNeeded()
131-
self.buffer.writeMultipleIntegers(PostgresFrontendMessage.ID.saslResponse.rawValue, UInt32(4 + bytes.count))
142+
self.buffer.psqlWriteMultipleIntegers(id: .password, length: UInt32(bytes.count))
132143
self.buffer.writeBytes(bytes)
133144
}
134145

135146
mutating func saslInitialResponse<Bytes: Collection>(mechanism: String, bytes: Bytes) where Bytes.Element == UInt8 {
136147
self.clearIfNeeded()
137-
self.buffer.writeMultipleIntegers(
138-
PostgresFrontendMessage.ID.saslInitialResponse.rawValue,
139-
UInt32(4 + mechanism.utf8.count + 1 + 4 + bytes.count)
140-
)
148+
self.buffer.psqlWriteMultipleIntegers(id: .password, length: UInt32(mechanism.utf8.count + 1 + 4 + bytes.count))
141149
self.buffer.writeNullTerminatedString(mechanism)
142150
if bytes.count > 0 {
143151
self.buffer.writeInteger(Int32(bytes.count))
@@ -149,17 +157,17 @@ struct PostgresFrontendMessageEncoder {
149157

150158
mutating func ssl() {
151159
self.clearIfNeeded()
152-
self.buffer.writeMultipleIntegers(UInt32(8), PostgresFrontendMessage.SSLRequest.requestCode)
160+
self.buffer.writeMultipleIntegers(UInt32(8), Self.sslRequestCode)
153161
}
154162

155163
mutating func sync() {
156164
self.clearIfNeeded()
157-
self.buffer.writeMultipleIntegers(PostgresFrontendMessage.ID.sync.rawValue, UInt32(4))
165+
self.buffer.psqlWriteMultipleIntegers(id: .sync, length: 0)
158166
}
159167

160168
mutating func terminate() {
161169
self.clearIfNeeded()
162-
self.buffer.writeMultipleIntegers(PostgresFrontendMessage.ID.terminate.rawValue, UInt32(4))
170+
self.buffer.psqlWriteMultipleIntegers(id: .terminate, length: 0)
163171
}
164172

165173
mutating func flushBuffer() -> ByteBuffer {
@@ -177,13 +185,42 @@ struct PostgresFrontendMessageEncoder {
177185
break
178186
}
179187
}
188+
}
180189

181-
private mutating func encodeLengthPrefixed(_ encode: (inout ByteBuffer) -> ()) {
182-
let startIndex = self.buffer.writerIndex
183-
self.buffer.writeInteger(UInt32(0)) // placeholder for length
184-
encode(&self.buffer)
185-
let length = UInt32(self.buffer.writerIndex - startIndex)
186-
self.buffer.setInteger(length, at: startIndex)
190+
private enum FrontendMessageID: UInt8, Hashable, Sendable {
191+
case bind = 66 // B
192+
case close = 67 // C
193+
case describe = 68 // D
194+
case execute = 69 // E
195+
case flush = 72 // H
196+
case parse = 80 // P
197+
case password = 112 // p - also both sasl values
198+
case sync = 83 // S
199+
case terminate = 88 // X
200+
}
201+
202+
extension ByteBuffer {
203+
mutating fileprivate func psqlWriteMultipleIntegers(id: FrontendMessageID, length: UInt32) {
204+
self.writeMultipleIntegers(id.rawValue, 4 + length)
205+
}
206+
207+
mutating fileprivate func psqlWriteMultipleIntegers<T1: FixedWidthInteger>(id: FrontendMessageID, length: UInt32, _ t1: T1) {
208+
self.writeMultipleIntegers(id.rawValue, 4 + length, t1)
187209
}
188210

211+
mutating fileprivate func psqlLengthPrefixed(id: FrontendMessageID, _ encode: (inout ByteBuffer) -> ()) {
212+
let lengthIndex = self.writerIndex + 1
213+
self.psqlWriteMultipleIntegers(id: id, length: 0)
214+
encode(&self)
215+
let length = UInt32(self.writerIndex - lengthIndex)
216+
self.setInteger(length, at: lengthIndex)
217+
}
218+
219+
mutating fileprivate func psqlLengthPrefixed(_ encode: (inout ByteBuffer) -> ()) {
220+
let lengthIndex = self.writerIndex
221+
self.writeInteger(UInt32(0)) // placeholder
222+
encode(&self)
223+
let length = UInt32(self.writerIndex - lengthIndex)
224+
self.setInteger(length, at: lengthIndex)
225+
}
189226
}

Tests/PostgresNIOTests/New/Extensions/ByteBuffer+Utils.swift

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@ import NIOCore
22
@testable import PostgresNIO
33

44
extension ByteBuffer {
5-
5+
mutating func psqlWriteBackendMessageID(_ messageID: PostgresBackendMessage.ID) {
6+
self.writeInteger(messageID.rawValue)
7+
}
8+
69
static func backendMessage(id: PostgresBackendMessage.ID, _ payload: (inout ByteBuffer) throws -> ()) rethrows -> ByteBuffer {
710
var byteBuffer = ByteBuffer()
811
try byteBuffer.writeBackendMessage(id: id, payload)

Sources/PostgresNIO/New/PostgresFrontendMessage.swift renamed to Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import NIOCore
2+
import PostgresNIO
23

34
/// A wire message that is created by a Postgres client to be consumed by Postgres server.
45
///

0 commit comments

Comments
 (0)