@@ -11,6 +11,7 @@ use std::collections::HashMap;
11
11
use std:: fmt;
12
12
use std:: io:: BufReader ;
13
13
use std:: sync:: Arc ;
14
+ use tokio:: sync:: OnceCell ;
14
15
use tokio_postgres:: error:: SqlState ;
15
16
use tokio_postgres:: tls:: MakeTlsConnect ;
16
17
use tokio_postgres:: {
@@ -162,7 +163,7 @@ impl PgReplicationSlotTransaction {
162
163
/// and streaming changes from the database.
163
164
#[ derive( Debug , Clone ) ]
164
165
pub struct PgReplicationClient {
165
- client : Arc < Client > ,
166
+ client : Arc < ( Client , OnceCell < i32 > ) > ,
166
167
}
167
168
168
169
impl PgReplicationClient {
@@ -177,6 +178,15 @@ impl PgReplicationClient {
177
178
}
178
179
}
179
180
181
+
182
+ // Convenience method to avoid having to access the client directly.
183
+ async fn simple_query (
184
+ & self ,
185
+ query : & str ,
186
+ ) -> Result < Vec < SimpleQueryMessage > , tokio_postgres:: Error > {
187
+ self . client . 0 . simple_query ( query) . await
188
+ }
189
+
180
190
/// Establishes a connection to Postgres without TLS encryption.
181
191
///
182
192
/// The connection is configured for logical replication mode.
@@ -190,7 +200,7 @@ impl PgReplicationClient {
190
200
info ! ( "successfully connected to postgres without tls" ) ;
191
201
192
202
Ok ( PgReplicationClient {
193
- client : Arc :: new ( client) ,
203
+ client : Arc :: new ( ( client, OnceCell :: new ( ) ) ) ,
194
204
} )
195
205
}
196
206
@@ -221,7 +231,7 @@ impl PgReplicationClient {
221
231
info ! ( "successfully connected to postgres with tls" ) ;
222
232
223
233
Ok ( PgReplicationClient {
224
- client : Arc :: new ( client) ,
234
+ client : Arc :: new ( ( client, OnceCell :: new ( ) ) ) ,
225
235
} )
226
236
}
227
237
@@ -253,7 +263,7 @@ impl PgReplicationClient {
253
263
quote_literal( slot_name)
254
264
) ;
255
265
256
- let results = self . client . simple_query ( & query) . await ?;
266
+ let results = self . simple_query ( & query) . await ?;
257
267
for result in results {
258
268
if let SimpleQueryMessage :: Row ( row) = result {
259
269
let confirmed_flush_lsn = Self :: get_row_value :: < PgLsn > (
@@ -315,7 +325,7 @@ impl PgReplicationClient {
315
325
quote_identifier( slot_name)
316
326
) ;
317
327
318
- match self . client . simple_query ( & query) . await {
328
+ match self . simple_query ( & query) . await {
319
329
Ok ( _) => {
320
330
info ! ( "successfully deleted replication slot '{}'" , slot_name) ;
321
331
@@ -353,7 +363,7 @@ impl PgReplicationClient {
353
363
"select 1 as exists from pg_publication where pubname = {};" ,
354
364
quote_literal( publication)
355
365
) ;
356
- for msg in self . client . simple_query ( & publication_exists_query) . await ? {
366
+ for msg in self . simple_query ( & publication_exists_query) . await ? {
357
367
if let SimpleQueryMessage :: Row ( _) = msg {
358
368
return Ok ( true ) ;
359
369
}
@@ -373,7 +383,7 @@ impl PgReplicationClient {
373
383
) ;
374
384
375
385
let mut table_names = vec ! [ ] ;
376
- for msg in self . client . simple_query ( & publication_query) . await ? {
386
+ for msg in self . simple_query ( & publication_query) . await ? {
377
387
if let SimpleQueryMessage :: Row ( row) = msg {
378
388
let schema =
379
389
Self :: get_row_value :: < String > ( & row, "schemaname" , "pg_publication_tables" )
@@ -403,7 +413,7 @@ impl PgReplicationClient {
403
413
) ;
404
414
405
415
let mut table_ids = vec ! [ ] ;
406
- for msg in self . client . simple_query ( & publication_query) . await ? {
416
+ for msg in self . simple_query ( & publication_query) . await ? {
407
417
if let SimpleQueryMessage :: Row ( row) = msg {
408
418
// For the sake of simplicity, we refer to the table oid as table id.
409
419
let table_id = Self :: get_row_value :: < TableId > ( & row, "oid" , "pg_class" ) . await ?;
@@ -441,7 +451,11 @@ impl PgReplicationClient {
441
451
options
442
452
) ;
443
453
444
- let copy_stream = self . client . copy_both_simple :: < bytes:: Bytes > ( & query) . await ?;
454
+ let copy_stream = self
455
+ . client
456
+ . 0
457
+ . copy_both_simple :: < bytes:: Bytes > ( & query)
458
+ . await ?;
445
459
let stream = LogicalReplicationStream :: new ( copy_stream) ;
446
460
447
461
Ok ( stream)
@@ -452,23 +466,22 @@ impl PgReplicationClient {
452
466
/// The transaction doesn't make any assumptions about the snapshot in use, since this is a
453
467
/// concern of the statements issued within the transaction.
454
468
async fn begin_tx ( & self ) -> EtlResult < ( ) > {
455
- self . client
456
- . simple_query ( "begin read only isolation level repeatable read;" )
469
+ self . simple_query ( "begin read only isolation level repeatable read;" )
457
470
. await ?;
458
471
459
472
Ok ( ( ) )
460
473
}
461
474
462
475
/// Commits the current transaction.
463
476
async fn commit_tx ( & self ) -> EtlResult < ( ) > {
464
- self . client . simple_query ( "commit;" ) . await ?;
477
+ self . simple_query ( "commit;" ) . await ?;
465
478
466
479
Ok ( ( ) )
467
480
}
468
481
469
482
/// Rolls back the current transaction.
470
483
async fn rollback_tx ( & self ) -> EtlResult < ( ) > {
471
- self . client . simple_query ( "rollback;" ) . await ?;
484
+ self . simple_query ( "rollback;" ) . await ?;
472
485
473
486
Ok ( ( ) )
474
487
}
@@ -495,7 +508,7 @@ impl PgReplicationClient {
495
508
quote_identifier( slot_name) ,
496
509
snapshot_option
497
510
) ;
498
- match self . client . simple_query ( & query) . await {
511
+ match self . simple_query ( & query) . await {
499
512
Ok ( results) => {
500
513
for result in results {
501
514
if let SimpleQueryMessage :: Row ( row) = result {
@@ -595,7 +608,7 @@ impl PgReplicationClient {
595
608
where c.oid = {table_id}" ,
596
609
) ;
597
610
598
- for message in self . client . simple_query ( & table_info_query) . await ? {
611
+ for message in self . simple_query ( & table_info_query) . await ? {
599
612
if let SimpleQueryMessage :: Row ( row) = message {
600
613
let schema_name =
601
614
Self :: get_row_value :: < String > ( & row, "schema_name" , "pg_namespace" ) . await ?;
@@ -626,7 +639,7 @@ impl PgReplicationClient {
626
639
publication : Option < & str > ,
627
640
) -> EtlResult < Vec < ColumnSchema > > {
628
641
let ( pub_cte, pub_pred) = if let Some ( publication) = publication {
629
- let is_pg14_or_earlier = self . is_postgres_14_or_earlier ( ) . await ? ;
642
+ let is_pg14_or_earlier = self . get_server_version ( ) . await . unwrap_or ( 0 ) < 150000 ;
630
643
631
644
if !is_pg14_or_earlier {
632
645
(
@@ -690,7 +703,7 @@ impl PgReplicationClient {
690
703
691
704
let mut column_schemas = vec ! [ ] ;
692
705
693
- for message in self . client . simple_query ( & column_info_query) . await ? {
706
+ for message in self . simple_query ( & column_info_query) . await ? {
694
707
if let SimpleQueryMessage :: Row ( row) = message {
695
708
let name = Self :: get_row_value :: < String > ( & row, "attname" , "pg_attribute" ) . await ?;
696
709
let type_oid = Self :: get_row_value :: < u32 > ( & row, "atttypid" , "pg_attribute" ) . await ?;
@@ -716,24 +729,36 @@ impl PgReplicationClient {
716
729
Ok ( column_schemas)
717
730
}
718
731
719
- async fn is_postgres_14_or_earlier ( & self ) -> EtlResult < bool > {
720
- let version_query = "SHOW server_version_num" ;
721
-
722
- for message in self . client . simple_query ( version_query) . await ? {
723
- if let SimpleQueryMessage :: Row ( row) = message {
724
- let version_str =
725
- Self :: get_row_value :: < String > ( & row, "server_version_num" , "server_version_num" )
732
+ /// Gets the PostgreSQL server version.
733
+ ///
734
+ /// Returns the version in the format: MAJOR * 10000 + MINOR * 100 + PATCH
735
+ /// For example: PostgreSQL 14.2 = 140200, PostgreSQL 15.1 = 150100
736
+ async fn get_server_version ( & self ) -> EtlResult < i32 > {
737
+ let version = self
738
+ . client
739
+ . 1
740
+ . get_or_try_init ( || async {
741
+ let version_query = "SHOW server_version_num" ;
742
+
743
+ for message in self . simple_query ( version_query) . await ? {
744
+ if let SimpleQueryMessage :: Row ( row) = message {
745
+ let version_str = Self :: get_row_value :: < String > (
746
+ & row,
747
+ "server_version_num" ,
748
+ "server_version_num" ,
749
+ )
726
750
. await ?;
727
- let server_version: i32 = version_str. parse ( ) . unwrap_or ( 0 ) ;
751
+ let version: i32 = version_str. parse ( ) . unwrap_or ( 0 ) ;
752
+ return Ok :: < _ , EtlError > ( version) ;
753
+ }
754
+ }
728
755
729
- // PostgreSQL version format is typically: MAJOR * 10000 + MINOR * 100 + PATCH
730
- // For version 14.x.x, this would be 140000 + minor * 100 + patch
731
- // For version 15.x.x, this would be 150000 + minor * 100 + patch
732
- return Ok ( server_version < 150000 ) ;
733
- }
734
- }
756
+ // If we can't determine, return 0 (which will be treated as very old version)
757
+ Ok ( 0 )
758
+ } )
759
+ . await ?;
735
760
736
- Ok ( false )
761
+ Ok ( * version )
737
762
}
738
763
739
764
/// Creates a COPY stream for reading data from a table using its OID.
@@ -759,7 +784,7 @@ impl PgReplicationClient {
759
784
column_list
760
785
) ;
761
786
762
- let stream = self . client . copy_out_simple ( & copy_query) . await ?;
787
+ let stream = self . client . 0 . copy_out_simple ( & copy_query) . await ?;
763
788
764
789
Ok ( stream)
765
790
}
0 commit comments