11import NIOCore
22
33struct PostgresFrontendMessageEncoder {
4+
5+ /// The SSL request code. The value is chosen to contain 1234 in the most significant 16 bits,
6+ /// and 5679 in the least significant 16 bits.
7+ static let sslRequestCode : Int32 = 80877103
8+
9+ /// The cancel request code. The value is chosen to contain 1234 in the most significant 16 bits,
10+ /// and 5678 in the least significant 16 bits. (To avoid confusion, this code must not be the same
11+ /// as any protocol version number.)
12+ static let cancelRequestCode : Int32 = 80877102
13+
14+ static let startupVersionThree : Int32 = 0x00_03_00_00
15+
416 private enum State {
517 case flushed
618 case writable
@@ -15,8 +27,8 @@ struct PostgresFrontendMessageEncoder {
1527
1628 mutating func startup( user: String , database: String ? ) {
1729 self . clearIfNeeded ( )
18- self . encodeLengthPrefixed { buffer in
19- buffer. writeInteger ( PostgresFrontendMessage . Startup . versionThree )
30+ self . buffer . psqlLengthPrefixed { buffer in
31+ buffer. writeInteger ( Self . startupVersionThree )
2032 buffer. writeNullTerminatedString ( " user " )
2133 buffer. writeNullTerminatedString ( user)
2234
@@ -31,8 +43,7 @@ struct PostgresFrontendMessageEncoder {
3143
3244 mutating func bind( portalName: String , preparedStatementName: String , bind: PostgresBindings ) {
3345 self . clearIfNeeded ( )
34- self . buffer. psqlWriteFrontendMessageID ( . bind)
35- self . encodeLengthPrefixed { buffer in
46+ self . buffer. psqlLengthPrefixed ( id: . bind) { buffer in
3647 buffer. writeNullTerminatedString ( portalName)
3748 buffer. writeNullTerminatedString ( preparedStatementName)
3849
@@ -65,45 +76,45 @@ struct PostgresFrontendMessageEncoder {
6576
6677 mutating func cancel( processID: Int32 , secretKey: Int32 ) {
6778 self . clearIfNeeded ( )
68- self . buffer. writeMultipleIntegers ( UInt32 ( 16 ) , PostgresFrontendMessage . Cancel . requestCode , processID, secretKey)
79+ self . buffer. writeMultipleIntegers ( UInt32 ( 16 ) , Self . cancelRequestCode , processID, secretKey)
6980 }
7081
7182 mutating func closePreparedStatement( _ preparedStatement: String ) {
7283 self . clearIfNeeded ( )
73- self . buffer. writeMultipleIntegers ( PostgresFrontendMessage . ID . close. rawValue , UInt32 ( 6 + preparedStatement. utf8. count) , UInt8 ( ascii: " S " ) )
84+ self . buffer. psqlWriteMultipleIntegers ( id : . close, length : UInt32 ( 2 + preparedStatement. utf8. count) , UInt8 ( ascii: " S " ) )
7485 self . buffer. writeNullTerminatedString ( preparedStatement)
7586 }
7687
7788 mutating func closePortal( _ portal: String ) {
7889 self . clearIfNeeded ( )
79- self . buffer. writeMultipleIntegers ( PostgresFrontendMessage . ID . close. rawValue , UInt32 ( 6 + portal. utf8. count) , UInt8 ( ascii: " P " ) )
90+ self . buffer. psqlWriteMultipleIntegers ( id : . close, length : UInt32 ( 2 + portal. utf8. count) , UInt8 ( ascii: " P " ) )
8091 self . buffer. writeNullTerminatedString ( portal)
8192 }
8293
8394 mutating func describePreparedStatement( _ preparedStatement: String ) {
8495 self . clearIfNeeded ( )
85- self . buffer. writeMultipleIntegers ( PostgresFrontendMessage . ID . describe. rawValue , UInt32 ( 6 + preparedStatement. utf8. count) , UInt8 ( ascii: " S " ) )
96+ self . buffer. psqlWriteMultipleIntegers ( id : . describe, length : UInt32 ( 2 + preparedStatement. utf8. count) , UInt8 ( ascii: " S " ) )
8697 self . buffer. writeNullTerminatedString ( preparedStatement)
8798 }
8899
89100 mutating func describePortal( _ portal: String ) {
90101 self . clearIfNeeded ( )
91- self . buffer. writeMultipleIntegers ( PostgresFrontendMessage . ID . describe. rawValue , UInt32 ( 6 + portal. utf8. count) , UInt8 ( ascii: " P " ) )
102+ self . buffer. psqlWriteMultipleIntegers ( id : . describe, length : UInt32 ( 2 + portal. utf8. count) , UInt8 ( ascii: " P " ) )
92103 self . buffer. writeNullTerminatedString ( portal)
93104 }
94105
95106 mutating func execute( portalName: String , maxNumberOfRows: Int32 = 0 ) {
96107 self . clearIfNeeded ( )
97- self . buffer. writeMultipleIntegers ( PostgresFrontendMessage . ID . execute. rawValue , UInt32 ( 9 + portalName. utf8. count) )
108+ self . buffer. psqlWriteMultipleIntegers ( id : . execute, length : UInt32 ( 5 + portalName. utf8. count) )
98109 self . buffer. writeNullTerminatedString ( portalName)
99110 self . buffer. writeInteger ( maxNumberOfRows)
100111 }
101112
102113 mutating func parse< Parameters: Collection > ( preparedStatementName: String , query: String , parameters: Parameters ) where Parameters. Element == PostgresDataType {
103114 self . clearIfNeeded ( )
104- self . buffer. writeMultipleIntegers (
105- PostgresFrontendMessage . ID . parse. rawValue ,
106- UInt32 ( 4 + preparedStatementName. utf8. count + 1 + query. utf8. count + 1 + 2 + MemoryLayout < PostgresDataType > . size * parameters. count)
115+ self . buffer. psqlWriteMultipleIntegers (
116+ id : . parse,
117+ length : UInt32 ( preparedStatementName. utf8. count + 1 + query. utf8. count + 1 + 2 + MemoryLayout < PostgresDataType > . size * parameters. count)
107118 )
108119 self . buffer. writeNullTerminatedString ( preparedStatementName)
109120 self . buffer. writeNullTerminatedString ( query)
@@ -116,28 +127,25 @@ struct PostgresFrontendMessageEncoder {
116127
117128 mutating func password< Bytes: Collection > ( _ bytes: Bytes ) where Bytes. Element == UInt8 {
118129 self . clearIfNeeded ( )
119- self . buffer. writeMultipleIntegers ( PostgresFrontendMessage . ID . password. rawValue , UInt32 ( 5 + bytes. count) )
130+ self . buffer. psqlWriteMultipleIntegers ( id : . password, length : UInt32 ( bytes. count) + 1 )
120131 self . buffer. writeBytes ( bytes)
121132 self . buffer. writeInteger ( UInt8 ( 0 ) )
122133 }
123134
124135 mutating func flush( ) {
125136 self . clearIfNeeded ( )
126- self . buffer. writeMultipleIntegers ( PostgresFrontendMessage . ID . flush. rawValue , UInt32 ( 4 ) )
137+ self . buffer. psqlWriteMultipleIntegers ( id : . flush, length : 0 )
127138 }
128139
129140 mutating func saslResponse< Bytes: Collection > ( _ bytes: Bytes ) where Bytes. Element == UInt8 {
130141 self . clearIfNeeded ( )
131- self . buffer. writeMultipleIntegers ( PostgresFrontendMessage . ID . saslResponse . rawValue , UInt32 ( 4 + bytes. count) )
142+ self . buffer. psqlWriteMultipleIntegers ( id : . password , length : UInt32 ( bytes. count) )
132143 self . buffer. writeBytes ( bytes)
133144 }
134145
135146 mutating func saslInitialResponse< Bytes: Collection > ( mechanism: String , bytes: Bytes ) where Bytes. Element == UInt8 {
136147 self . clearIfNeeded ( )
137- self . buffer. writeMultipleIntegers (
138- PostgresFrontendMessage . ID. saslInitialResponse. rawValue,
139- UInt32 ( 4 + mechanism. utf8. count + 1 + 4 + bytes. count)
140- )
148+ self . buffer. psqlWriteMultipleIntegers ( id: . password, length: UInt32 ( mechanism. utf8. count + 1 + 4 + bytes. count) )
141149 self . buffer. writeNullTerminatedString ( mechanism)
142150 if bytes. count > 0 {
143151 self . buffer. writeInteger ( Int32 ( bytes. count) )
@@ -149,17 +157,17 @@ struct PostgresFrontendMessageEncoder {
149157
150158 mutating func ssl( ) {
151159 self . clearIfNeeded ( )
152- self . buffer. writeMultipleIntegers ( UInt32 ( 8 ) , PostgresFrontendMessage . SSLRequest . requestCode )
160+ self . buffer. writeMultipleIntegers ( UInt32 ( 8 ) , Self . sslRequestCode )
153161 }
154162
155163 mutating func sync( ) {
156164 self . clearIfNeeded ( )
157- self . buffer. writeMultipleIntegers ( PostgresFrontendMessage . ID . sync. rawValue , UInt32 ( 4 ) )
165+ self . buffer. psqlWriteMultipleIntegers ( id : . sync, length : 0 )
158166 }
159167
160168 mutating func terminate( ) {
161169 self . clearIfNeeded ( )
162- self . buffer. writeMultipleIntegers ( PostgresFrontendMessage . ID . terminate. rawValue , UInt32 ( 4 ) )
170+ self . buffer. psqlWriteMultipleIntegers ( id : . terminate, length : 0 )
163171 }
164172
165173 mutating func flushBuffer( ) -> ByteBuffer {
@@ -177,13 +185,42 @@ struct PostgresFrontendMessageEncoder {
177185 break
178186 }
179187 }
188+ }
180189
181- private mutating func encodeLengthPrefixed( _ encode: ( inout ByteBuffer ) -> ( ) ) {
182- let startIndex = self . buffer. writerIndex
183- self . buffer. writeInteger ( UInt32 ( 0 ) ) // placeholder for length
184- encode ( & self . buffer)
185- let length = UInt32 ( self . buffer. writerIndex - startIndex)
186- self . buffer. setInteger ( length, at: startIndex)
190+ private enum FrontendMessageID : UInt8 , Hashable , Sendable {
191+ case bind = 66 // B
192+ case close = 67 // C
193+ case describe = 68 // D
194+ case execute = 69 // E
195+ case flush = 72 // H
196+ case parse = 80 // P
197+ case password = 112 // p - also both sasl values
198+ case sync = 83 // S
199+ case terminate = 88 // X
200+ }
201+
202+ extension ByteBuffer {
203+ mutating fileprivate func psqlWriteMultipleIntegers( id: FrontendMessageID , length: UInt32 ) {
204+ self . writeMultipleIntegers ( id. rawValue, 4 + length)
205+ }
206+
207+ mutating fileprivate func psqlWriteMultipleIntegers< T1: FixedWidthInteger > ( id: FrontendMessageID , length: UInt32 , _ t1: T1 ) {
208+ self . writeMultipleIntegers ( id. rawValue, 4 + length, t1)
187209 }
188210
211+ mutating fileprivate func psqlLengthPrefixed( id: FrontendMessageID , _ encode: ( inout ByteBuffer ) -> ( ) ) {
212+ let lengthIndex = self . writerIndex + 1
213+ self . psqlWriteMultipleIntegers ( id: id, length: 0 )
214+ encode ( & self )
215+ let length = UInt32 ( self . writerIndex - lengthIndex)
216+ self . setInteger ( length, at: lengthIndex)
217+ }
218+
219+ mutating fileprivate func psqlLengthPrefixed( _ encode: ( inout ByteBuffer ) -> ( ) ) {
220+ let lengthIndex = self . writerIndex
221+ self . writeInteger ( UInt32 ( 0 ) ) // placeholder
222+ encode ( & self )
223+ let length = UInt32 ( self . writerIndex - lengthIndex)
224+ self . setInteger ( length, at: lengthIndex)
225+ }
189226}
0 commit comments