Skip to content

Commit 8b13752

Browse files
authored
Add PSQLBackendMessageEncoder (#175)
### Motivation To test a `PSQLChannelHandler`, that uses an internal `NIOSingleStepByteToMessageDecoder`, in an `EmbeddedChannel` we need to writeInbound bytes. To make this easier, this PR introduces a `PSQLBackendMessageEncoder` as a test util ### Changes - Add `PSQLBackendMessageEncoder` - Use `PSQLBackendMessageEncoder` in authentication tests
1 parent ce57b02 commit 8b13752

File tree

2 files changed

+280
-26
lines changed

2 files changed

+280
-26
lines changed
Lines changed: 272 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,272 @@
1+
import NIOCore
2+
@testable import PostgresNIO
3+
4+
struct PSQLBackendMessageEncoder: MessageToByteEncoder {
5+
typealias OutboundIn = PSQLBackendMessage
6+
7+
/// Called once there is data to encode.
8+
///
9+
/// - parameters:
10+
/// - data: The data to encode into a `ByteBuffer`.
11+
/// - out: The `ByteBuffer` into which we want to encode.
12+
func encode(data message: PSQLBackendMessage, out buffer: inout ByteBuffer) throws {
13+
switch message {
14+
case .authentication(let authentication):
15+
self.encode(messageID: message.id, payload: authentication, into: &buffer)
16+
17+
case .backendKeyData(let keyData):
18+
self.encode(messageID: message.id, payload: keyData, into: &buffer)
19+
20+
case .bindComplete,
21+
.closeComplete,
22+
.emptyQueryResponse,
23+
.noData,
24+
.parseComplete,
25+
.portalSuspended:
26+
self.encode(messageID: message.id, payload: EmptyPayload(), into: &buffer)
27+
28+
case .commandComplete(let string):
29+
self.encode(messageID: message.id, payload: StringPayload(string), into: &buffer)
30+
31+
case .dataRow(let row):
32+
self.encode(messageID: message.id, payload: row, into: &buffer)
33+
34+
case .error(let errorResponse):
35+
self.encode(messageID: message.id, payload: errorResponse, into: &buffer)
36+
37+
case .notice(let noticeResponse):
38+
self.encode(messageID: message.id, payload: noticeResponse, into: &buffer)
39+
40+
case .notification(let notificationResponse):
41+
self.encode(messageID: message.id, payload: notificationResponse, into: &buffer)
42+
43+
case .parameterDescription(let description):
44+
self.encode(messageID: message.id, payload: description, into: &buffer)
45+
46+
case .parameterStatus(let status):
47+
self.encode(messageID: message.id, payload: status, into: &buffer)
48+
49+
case .readyForQuery(let transactionState):
50+
self.encode(messageID: message.id, payload: transactionState, into: &buffer)
51+
52+
case .rowDescription(let description):
53+
self.encode(messageID: message.id, payload: description, into: &buffer)
54+
55+
case .sslSupported:
56+
buffer.writeInteger(UInt8(ascii: "S"))
57+
58+
case .sslUnsupported:
59+
buffer.writeInteger(UInt8(ascii: "N"))
60+
}
61+
}
62+
63+
private struct EmptyPayload: PSQLMessagePayloadEncodable {
64+
func encode(into buffer: inout ByteBuffer) {}
65+
}
66+
67+
private struct StringPayload: PSQLMessagePayloadEncodable {
68+
var string: String
69+
init(_ string: String) { self.string = string }
70+
func encode(into buffer: inout ByteBuffer) {
71+
buffer.writeNullTerminatedString(self.string)
72+
}
73+
}
74+
75+
private func encode<Payload: PSQLMessagePayloadEncodable>(
76+
messageID: PSQLBackendMessage.ID,
77+
payload: Payload,
78+
into buffer: inout ByteBuffer)
79+
{
80+
buffer.writeBackendMessageID(messageID)
81+
let startIndex = buffer.writerIndex
82+
buffer.writeInteger(Int32(0)) // placeholder for length
83+
payload.encode(into: &buffer)
84+
let length = Int32(buffer.writerIndex - startIndex)
85+
buffer.setInteger(length, at: startIndex)
86+
}
87+
}
88+
89+
extension PSQLBackendMessage {
90+
var id: ID {
91+
switch self {
92+
case .authentication:
93+
return .authentication
94+
case .backendKeyData:
95+
return .backendKeyData
96+
case .bindComplete:
97+
return .bindComplete
98+
case .closeComplete:
99+
return .closeComplete
100+
case .commandComplete:
101+
return .commandComplete
102+
case .dataRow:
103+
return .dataRow
104+
case .emptyQueryResponse:
105+
return .emptyQueryResponse
106+
case .error:
107+
return .error
108+
case .noData:
109+
return .noData
110+
case .notice:
111+
return .noticeResponse
112+
case .notification:
113+
return .notificationResponse
114+
case .parameterDescription:
115+
return .parameterDescription
116+
case .parameterStatus:
117+
return .parameterStatus
118+
case .parseComplete:
119+
return .parseComplete
120+
case .portalSuspended:
121+
return .portalSuspended
122+
case .readyForQuery:
123+
return .readyForQuery
124+
case .rowDescription:
125+
return .rowDescription
126+
case .sslSupported,
127+
.sslUnsupported:
128+
preconditionFailure("Message has no id.")
129+
}
130+
}
131+
}
132+
133+
extension PSQLBackendMessage.Authentication: PSQLMessagePayloadEncodable {
134+
135+
public func encode(into buffer: inout ByteBuffer) {
136+
switch self {
137+
case .ok:
138+
buffer.writeInteger(Int32(0))
139+
140+
case .kerberosV5:
141+
buffer.writeInteger(Int32(2))
142+
143+
case .plaintext:
144+
buffer.writeInteger(Int32(3))
145+
146+
case .md5(salt: let salt):
147+
buffer.writeInteger(Int32(5))
148+
buffer.writeInteger(salt.0)
149+
buffer.writeInteger(salt.1)
150+
buffer.writeInteger(salt.2)
151+
buffer.writeInteger(salt.3)
152+
153+
case .scmCredential:
154+
buffer.writeInteger(Int32(6))
155+
156+
case .gss:
157+
buffer.writeInteger(Int32(7))
158+
159+
case .gssContinue(var data):
160+
buffer.writeInteger(Int32(8))
161+
buffer.writeBuffer(&data)
162+
163+
case .sspi:
164+
buffer.writeInteger(Int32(9))
165+
166+
case .sasl(names: let names):
167+
buffer.writeInteger(Int32(10))
168+
for name in names {
169+
buffer.writeNullTerminatedString(name)
170+
}
171+
172+
case .saslContinue(data: var data):
173+
buffer.writeInteger(Int32(11))
174+
buffer.writeBuffer(&data)
175+
176+
case .saslFinal(data: var data):
177+
buffer.writeInteger(Int32(12))
178+
buffer.writeBuffer(&data)
179+
}
180+
}
181+
182+
}
183+
184+
extension PSQLBackendMessage.BackendKeyData: PSQLMessagePayloadEncodable {
185+
public func encode(into buffer: inout ByteBuffer) {
186+
buffer.writeInteger(self.processID)
187+
buffer.writeInteger(self.secretKey)
188+
}
189+
}
190+
191+
extension PSQLBackendMessage.DataRow: PSQLMessagePayloadEncodable {
192+
public func encode(into buffer: inout ByteBuffer) {
193+
buffer.writeInteger(Int16(self.columns.count))
194+
195+
for column in self.columns {
196+
switch column {
197+
case .none:
198+
buffer.writeInteger(-1, as: Int32.self)
199+
case .some(var writable):
200+
buffer.writeInteger(Int32(writable.readableBytes))
201+
buffer.writeBuffer(&writable)
202+
}
203+
}
204+
}
205+
}
206+
207+
extension PSQLBackendMessage.ErrorResponse: PSQLMessagePayloadEncodable {
208+
public func encode(into buffer: inout ByteBuffer) {
209+
for (key, value) in self.fields {
210+
buffer.writeInteger(key.rawValue, as: UInt8.self)
211+
buffer.writeNullTerminatedString(value)
212+
}
213+
buffer.writeInteger(0, as: UInt8.self) // signal done
214+
}
215+
}
216+
217+
extension PSQLBackendMessage.NoticeResponse: PSQLMessagePayloadEncodable {
218+
public func encode(into buffer: inout ByteBuffer) {
219+
for (key, value) in self.fields {
220+
buffer.writeInteger(key.rawValue, as: UInt8.self)
221+
buffer.writeNullTerminatedString(value)
222+
}
223+
buffer.writeInteger(0, as: UInt8.self) // signal done
224+
}
225+
}
226+
227+
extension PSQLBackendMessage.NotificationResponse: PSQLMessagePayloadEncodable {
228+
public func encode(into buffer: inout ByteBuffer) {
229+
buffer.writeInteger(self.backendPID)
230+
buffer.writeNullTerminatedString(self.channel)
231+
buffer.writeNullTerminatedString(self.payload)
232+
}
233+
}
234+
235+
extension PSQLBackendMessage.ParameterDescription: PSQLMessagePayloadEncodable {
236+
public func encode(into buffer: inout ByteBuffer) {
237+
buffer.writeInteger(Int16(self.dataTypes.count))
238+
239+
for dataType in self.dataTypes {
240+
buffer.writeInteger(dataType.rawValue)
241+
}
242+
}
243+
}
244+
245+
extension PSQLBackendMessage.ParameterStatus: PSQLMessagePayloadEncodable {
246+
public func encode(into buffer: inout ByteBuffer) {
247+
buffer.writeNullTerminatedString(self.parameter)
248+
buffer.writeNullTerminatedString(self.value)
249+
}
250+
}
251+
252+
extension PSQLBackendMessage.TransactionState: PSQLMessagePayloadEncodable {
253+
public func encode(into buffer: inout ByteBuffer) {
254+
buffer.writeInteger(self.rawValue)
255+
}
256+
}
257+
258+
extension PSQLBackendMessage.RowDescription: PSQLMessagePayloadEncodable {
259+
public func encode(into buffer: inout ByteBuffer) {
260+
buffer.writeInteger(Int16(self.columns.count))
261+
262+
for column in self.columns {
263+
buffer.writeNullTerminatedString(column.name)
264+
buffer.writeInteger(column.tableOID)
265+
buffer.writeInteger(column.columnAttributeNumber)
266+
buffer.writeInteger(column.dataType.rawValue)
267+
buffer.writeInteger(column.dataTypeSize)
268+
buffer.writeInteger(column.dataTypeModifier)
269+
buffer.writeInteger(column.format.rawValue)
270+
}
271+
}
272+
}

Tests/PostgresNIOTests/New/Messages/AuthenticationTests.swift

Lines changed: 8 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -8,54 +8,36 @@ class AuthenticationTests: XCTestCase {
88
func testDecodeAuthentication() {
99
var expected = [PSQLBackendMessage]()
1010
var buffer = ByteBuffer()
11+
let encoder = PSQLBackendMessageEncoder()
1112

1213
// add ok
13-
buffer.writeBackendMessage(id: .authentication) { buffer in
14-
buffer.writeInteger(Int32(0))
15-
}
14+
XCTAssertNoThrow(try encoder.encode(data: .authentication(.ok), out: &buffer))
1615
expected.append(.authentication(.ok))
1716

1817
// add kerberos
19-
buffer.writeBackendMessage(id: .authentication) { buffer in
20-
buffer.writeInteger(Int32(2))
21-
}
18+
XCTAssertNoThrow(try encoder.encode(data: .authentication(.kerberosV5), out: &buffer))
2219
expected.append(.authentication(.kerberosV5))
2320

2421
// add plaintext
25-
buffer.writeBackendMessage(id: .authentication) { buffer in
26-
buffer.writeInteger(Int32(3))
27-
}
22+
XCTAssertNoThrow(try encoder.encode(data: .authentication(.plaintext), out: &buffer))
2823
expected.append(.authentication(.plaintext))
2924

3025
// add md5
31-
buffer.writeBackendMessage(id: .authentication) { buffer in
32-
buffer.writeInteger(Int32(5))
33-
buffer.writeInteger(UInt8(1))
34-
buffer.writeInteger(UInt8(2))
35-
buffer.writeInteger(UInt8(3))
36-
buffer.writeInteger(UInt8(4))
37-
}
26+
XCTAssertNoThrow(try encoder.encode(data: .authentication(.md5(salt: (1, 2, 3, 4))), out: &buffer))
3827
expected.append(.authentication(.md5(salt: (1, 2, 3, 4))))
3928

4029
// add scm credential
41-
buffer.writeBackendMessage(id: .authentication) { buffer in
42-
buffer.writeInteger(Int32(6))
43-
}
30+
XCTAssertNoThrow(try encoder.encode(data: .authentication(.scmCredential), out: &buffer))
4431
expected.append(.authentication(.scmCredential))
4532

4633
// add gss
47-
buffer.writeBackendMessage(id: .authentication) { buffer in
48-
buffer.writeInteger(Int32(7))
49-
}
34+
XCTAssertNoThrow(try encoder.encode(data: .authentication(.gss), out: &buffer))
5035
expected.append(.authentication(.gss))
5136

5237
// add sspi
53-
buffer.writeBackendMessage(id: .authentication) { buffer in
54-
buffer.writeInteger(Int32(9))
55-
}
38+
XCTAssertNoThrow(try encoder.encode(data: .authentication(.sspi), out: &buffer))
5639
expected.append(.authentication(.sspi))
5740

58-
5941
XCTAssertNoThrow(try ByteToMessageDecoderVerifier.verifyDecoder(
6042
inputOutputPairs: [(buffer, expected)],
6143
decoderFactory: { PSQLBackendMessageDecoder(hasAlreadyReceivedBytes: false) }))

0 commit comments

Comments
 (0)