@@ -10,8 +10,10 @@ use rustls::ClientConfig;
10
10
use std:: collections:: HashMap ;
11
11
use std:: fmt;
12
12
use std:: io:: BufReader ;
13
+ use std:: num:: NonZeroI32 ;
13
14
use std:: sync:: Arc ;
14
- use tokio:: sync:: OnceCell ;
15
+
16
+ use tokio:: io:: { AsyncRead , AsyncWrite } ;
15
17
use tokio_postgres:: error:: SqlState ;
16
18
use tokio_postgres:: tls:: MakeTlsConnect ;
17
19
use tokio_postgres:: {
@@ -163,7 +165,8 @@ impl PgReplicationSlotTransaction {
163
165
/// and streaming changes from the database.
164
166
#[ derive( Debug , Clone ) ]
165
167
pub struct PgReplicationClient {
166
- client : Arc < ( Client , OnceCell < i32 > ) > ,
168
+ client : Arc < Client > ,
169
+ server_version : Option < NonZeroI32 > ,
167
170
}
168
171
169
172
impl PgReplicationClient {
@@ -178,14 +181,6 @@ impl PgReplicationClient {
178
181
}
179
182
}
180
183
181
- // Convenience method to avoid having to access the client directly.
182
- async fn simple_query (
183
- & self ,
184
- query : & str ,
185
- ) -> Result < Vec < SimpleQueryMessage > , tokio_postgres:: Error > {
186
- self . client . 0 . simple_query ( query) . await
187
- }
188
-
189
184
/// Establishes a connection to Postgres without TLS encryption.
190
185
///
191
186
/// The connection is configured for logical replication mode.
@@ -194,12 +189,16 @@ impl PgReplicationClient {
194
189
config. replication_mode ( ReplicationMode :: Logical ) ;
195
190
196
191
let ( client, connection) = config. connect ( NoTls ) . await ?;
192
+
193
+ let server_version = Self :: extract_server_version ( & connection) ;
194
+
197
195
spawn_postgres_connection :: < NoTls > ( connection) ;
198
196
199
197
info ! ( "successfully connected to postgres without tls" ) ;
200
198
201
199
Ok ( PgReplicationClient {
202
- client : Arc :: new ( ( client, OnceCell :: new ( ) ) ) ,
200
+ client : Arc :: new ( client) ,
201
+ server_version,
203
202
} )
204
203
}
205
204
@@ -225,12 +224,16 @@ impl PgReplicationClient {
225
224
. with_no_client_auth ( ) ;
226
225
227
226
let ( client, connection) = config. connect ( MakeRustlsConnect :: new ( tls_config) ) . await ?;
227
+
228
+ let server_version = Self :: extract_server_version ( & connection) ;
229
+
228
230
spawn_postgres_connection :: < MakeRustlsConnect > ( connection) ;
229
231
230
232
info ! ( "successfully connected to postgres with tls" ) ;
231
233
232
234
Ok ( PgReplicationClient {
233
- client : Arc :: new ( ( client, OnceCell :: new ( ) ) ) ,
235
+ client : Arc :: new ( client) ,
236
+ server_version,
234
237
} )
235
238
}
236
239
@@ -262,7 +265,7 @@ impl PgReplicationClient {
262
265
quote_literal( slot_name)
263
266
) ;
264
267
265
- let results = self . simple_query ( & query) . await ?;
268
+ let results = self . client . simple_query ( & query) . await ?;
266
269
for result in results {
267
270
if let SimpleQueryMessage :: Row ( row) = result {
268
271
let confirmed_flush_lsn = Self :: get_row_value :: < PgLsn > (
@@ -324,7 +327,7 @@ impl PgReplicationClient {
324
327
quote_identifier( slot_name)
325
328
) ;
326
329
327
- match self . simple_query ( & query) . await {
330
+ match self . client . simple_query ( & query) . await {
328
331
Ok ( _) => {
329
332
info ! ( "successfully deleted replication slot '{}'" , slot_name) ;
330
333
@@ -362,7 +365,7 @@ impl PgReplicationClient {
362
365
"select 1 as exists from pg_publication where pubname = {};" ,
363
366
quote_literal( publication)
364
367
) ;
365
- for msg in self . simple_query ( & publication_exists_query) . await ? {
368
+ for msg in self . client . simple_query ( & publication_exists_query) . await ? {
366
369
if let SimpleQueryMessage :: Row ( _) = msg {
367
370
return Ok ( true ) ;
368
371
}
@@ -382,7 +385,7 @@ impl PgReplicationClient {
382
385
) ;
383
386
384
387
let mut table_names = vec ! [ ] ;
385
- for msg in self . simple_query ( & publication_query) . await ? {
388
+ for msg in self . client . simple_query ( & publication_query) . await ? {
386
389
if let SimpleQueryMessage :: Row ( row) = msg {
387
390
let schema =
388
391
Self :: get_row_value :: < String > ( & row, "schemaname" , "pg_publication_tables" )
@@ -412,7 +415,7 @@ impl PgReplicationClient {
412
415
) ;
413
416
414
417
let mut table_ids = vec ! [ ] ;
415
- for msg in self . simple_query ( & publication_query) . await ? {
418
+ for msg in self . client . simple_query ( & publication_query) . await ? {
416
419
if let SimpleQueryMessage :: Row ( row) = msg {
417
420
// For the sake of simplicity, we refer to the table oid as table id.
418
421
let table_id = Self :: get_row_value :: < TableId > ( & row, "oid" , "pg_class" ) . await ?;
@@ -450,11 +453,7 @@ impl PgReplicationClient {
450
453
options
451
454
) ;
452
455
453
- let copy_stream = self
454
- . client
455
- . 0
456
- . copy_both_simple :: < bytes:: Bytes > ( & query)
457
- . await ?;
456
+ let copy_stream = self . client . copy_both_simple :: < bytes:: Bytes > ( & query) . await ?;
458
457
let stream = LogicalReplicationStream :: new ( copy_stream) ;
459
458
460
459
Ok ( stream)
@@ -465,22 +464,23 @@ impl PgReplicationClient {
465
464
/// The transaction doesn't make any assumptions about the snapshot in use, since this is a
466
465
/// concern of the statements issued within the transaction.
467
466
async fn begin_tx ( & self ) -> EtlResult < ( ) > {
468
- self . simple_query ( "begin read only isolation level repeatable read;" )
467
+ self . client
468
+ . simple_query ( "begin read only isolation level repeatable read;" )
469
469
. await ?;
470
470
471
471
Ok ( ( ) )
472
472
}
473
473
474
474
/// Commits the current transaction.
475
475
async fn commit_tx ( & self ) -> EtlResult < ( ) > {
476
- self . simple_query ( "commit;" ) . await ?;
476
+ self . client . simple_query ( "commit;" ) . await ?;
477
477
478
478
Ok ( ( ) )
479
479
}
480
480
481
481
/// Rolls back the current transaction.
482
482
async fn rollback_tx ( & self ) -> EtlResult < ( ) > {
483
- self . simple_query ( "rollback;" ) . await ?;
483
+ self . client . simple_query ( "rollback;" ) . await ?;
484
484
485
485
Ok ( ( ) )
486
486
}
@@ -507,7 +507,7 @@ impl PgReplicationClient {
507
507
quote_identifier( slot_name) ,
508
508
snapshot_option
509
509
) ;
510
- match self . simple_query ( & query) . await {
510
+ match self . client . simple_query ( & query) . await {
511
511
Ok ( results) => {
512
512
for result in results {
513
513
if let SimpleQueryMessage :: Row ( row) = result {
@@ -607,7 +607,7 @@ impl PgReplicationClient {
607
607
where c.oid = {table_id}" ,
608
608
) ;
609
609
610
- for message in self . simple_query ( & table_info_query) . await ? {
610
+ for message in self . client . simple_query ( & table_info_query) . await ? {
611
611
if let SimpleQueryMessage :: Row ( row) = message {
612
612
let schema_name =
613
613
Self :: get_row_value :: < String > ( & row, "schema_name" , "pg_namespace" ) . await ?;
@@ -638,7 +638,12 @@ impl PgReplicationClient {
638
638
publication : Option < & str > ,
639
639
) -> EtlResult < Vec < ColumnSchema > > {
640
640
let ( pub_cte, pub_pred) = if let Some ( publication) = publication {
641
- let is_pg14_or_earlier = self . get_server_version ( ) . await . unwrap_or ( 0 ) < 150000 ;
641
+ let is_pg14_or_earlier = if let Some ( server_version) = self . server_version {
642
+ server_version. get ( ) < 150000
643
+ } else {
644
+ // be conservative by default
645
+ true
646
+ } ;
642
647
643
648
if !is_pg14_or_earlier {
644
649
(
@@ -702,7 +707,7 @@ impl PgReplicationClient {
702
707
703
708
let mut column_schemas = vec ! [ ] ;
704
709
705
- for message in self . simple_query ( & column_info_query) . await ? {
710
+ for message in self . client . simple_query ( & column_info_query) . await ? {
706
711
if let SimpleQueryMessage :: Row ( row) = message {
707
712
let name = Self :: get_row_value :: < String > ( & row, "attname" , "pg_attribute" ) . await ?;
708
713
let type_oid = Self :: get_row_value :: < u32 > ( & row, "atttypid" , "pg_attribute" ) . await ?;
@@ -728,36 +733,46 @@ impl PgReplicationClient {
728
733
Ok ( column_schemas)
729
734
}
730
735
731
- /// Gets the PostgreSQL server version.
736
+ /// Extracts the PostgreSQL server version from connection parameters.
737
+ ///
738
+ /// This method should be called during connection establishment to extract
739
+ /// the server version from the parameter status messages sent by the server.
732
740
///
733
741
/// Returns the version in the format: MAJOR * 10000 + MINOR * 100 + PATCH
742
+ /// This matches the format used by `SELECT version()`.
734
743
/// For example: PostgreSQL 14.2 = 140200, PostgreSQL 15.1 = 150100
735
- async fn get_server_version ( & self ) -> EtlResult < i32 > {
736
- let version = self
737
- . client
738
- . 1
739
- . get_or_try_init ( || async {
740
- let version_query = "SHOW server_version_num" ;
741
-
742
- for message in self . simple_query ( version_query) . await ? {
743
- if let SimpleQueryMessage :: Row ( row) = message {
744
- let version_str = Self :: get_row_value :: < String > (
745
- & row,
746
- "server_version_num" ,
747
- "server_version_num" ,
748
- )
749
- . await ?;
750
- let version: i32 = version_str. parse ( ) . unwrap_or ( 0 ) ;
751
- return Ok :: < _ , EtlError > ( version) ;
752
- }
753
- }
754
-
755
- // If we can't determine, return 0 (which will be treated as very old version)
756
- Ok ( 0 )
757
- } )
758
- . await ?;
759
-
760
- Ok ( * version)
744
+ fn extract_server_version < S > ( connection : & Connection < Socket , S > ) -> Option < NonZeroI32 >
745
+ where
746
+ S : AsyncRead + AsyncWrite + Unpin + Send ,
747
+ {
748
+ if let Some ( server_version_str) = connection. parameter ( "server_version" ) {
749
+ // Parse version string like "15.5 (Homebrew)" or "14.2"
750
+ let version_part = server_version_str
751
+ . split_whitespace ( )
752
+ . next ( )
753
+ . unwrap_or ( "0.0" ) ;
754
+
755
+ let version_components: Vec < & str > = version_part. split ( '.' ) . collect ( ) ;
756
+
757
+ let major = version_components
758
+ . get ( 0 )
759
+ . and_then ( |v| v. parse :: < i32 > ( ) . ok ( ) )
760
+ . unwrap_or ( 0 ) ;
761
+ let minor = version_components
762
+ . get ( 1 )
763
+ . and_then ( |v| v. parse :: < i32 > ( ) . ok ( ) )
764
+ . unwrap_or ( 0 ) ;
765
+ let patch = version_components
766
+ . get ( 2 )
767
+ . and_then ( |v| v. parse :: < i32 > ( ) . ok ( ) )
768
+ . unwrap_or ( 0 ) ;
769
+
770
+ let version = major * 10000 + minor * 100 + patch;
771
+
772
+ NonZeroI32 :: new ( version)
773
+ } else {
774
+ None
775
+ }
761
776
}
762
777
763
778
/// Creates a COPY stream for reading data from a table using its OID.
@@ -783,7 +798,7 @@ impl PgReplicationClient {
783
798
column_list
784
799
) ;
785
800
786
- let stream = self . client . 0 . copy_out_simple ( & copy_query) . await ?;
801
+ let stream = self . client . copy_out_simple ( & copy_query) . await ?;
787
802
788
803
Ok ( stream)
789
804
}
0 commit comments