Skip to content

Commit b5db9c3

Browse files
Store server version in OnceCell
1 parent def3f2b commit b5db9c3

File tree

1 file changed

+58
-33
lines changed

1 file changed

+58
-33
lines changed

etl/src/replication/client.rs

Lines changed: 58 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use std::collections::HashMap;
1111
use std::fmt;
1212
use std::io::BufReader;
1313
use std::sync::Arc;
14+
use tokio::sync::OnceCell;
1415
use tokio_postgres::error::SqlState;
1516
use tokio_postgres::tls::MakeTlsConnect;
1617
use tokio_postgres::{
@@ -162,7 +163,7 @@ impl PgReplicationSlotTransaction {
162163
/// and streaming changes from the database.
163164
#[derive(Debug, Clone)]
164165
pub struct PgReplicationClient {
165-
client: Arc<Client>,
166+
client: Arc<(Client, OnceCell<i32>)>,
166167
}
167168

168169
impl PgReplicationClient {
@@ -177,6 +178,15 @@ impl PgReplicationClient {
177178
}
178179
}
179180

181+
182+
// Convenience method to avoid having to access the client directly.
183+
async fn simple_query(
184+
&self,
185+
query: &str,
186+
) -> Result<Vec<SimpleQueryMessage>, tokio_postgres::Error> {
187+
self.client.0.simple_query(query).await
188+
}
189+
180190
/// Establishes a connection to Postgres without TLS encryption.
181191
///
182192
/// The connection is configured for logical replication mode.
@@ -190,7 +200,7 @@ impl PgReplicationClient {
190200
info!("successfully connected to postgres without tls");
191201

192202
Ok(PgReplicationClient {
193-
client: Arc::new(client),
203+
client: Arc::new((client, OnceCell::new())),
194204
})
195205
}
196206

@@ -221,7 +231,7 @@ impl PgReplicationClient {
221231
info!("successfully connected to postgres with tls");
222232

223233
Ok(PgReplicationClient {
224-
client: Arc::new(client),
234+
client: Arc::new((client, OnceCell::new())),
225235
})
226236
}
227237

@@ -253,7 +263,7 @@ impl PgReplicationClient {
253263
quote_literal(slot_name)
254264
);
255265

256-
let results = self.client.simple_query(&query).await?;
266+
let results = self.simple_query(&query).await?;
257267
for result in results {
258268
if let SimpleQueryMessage::Row(row) = result {
259269
let confirmed_flush_lsn = Self::get_row_value::<PgLsn>(
@@ -315,7 +325,7 @@ impl PgReplicationClient {
315325
quote_identifier(slot_name)
316326
);
317327

318-
match self.client.simple_query(&query).await {
328+
match self.simple_query(&query).await {
319329
Ok(_) => {
320330
info!("successfully deleted replication slot '{}'", slot_name);
321331

@@ -353,7 +363,7 @@ impl PgReplicationClient {
353363
"select 1 as exists from pg_publication where pubname = {};",
354364
quote_literal(publication)
355365
);
356-
for msg in self.client.simple_query(&publication_exists_query).await? {
366+
for msg in self.simple_query(&publication_exists_query).await? {
357367
if let SimpleQueryMessage::Row(_) = msg {
358368
return Ok(true);
359369
}
@@ -373,7 +383,7 @@ impl PgReplicationClient {
373383
);
374384

375385
let mut table_names = vec![];
376-
for msg in self.client.simple_query(&publication_query).await? {
386+
for msg in self.simple_query(&publication_query).await? {
377387
if let SimpleQueryMessage::Row(row) = msg {
378388
let schema =
379389
Self::get_row_value::<String>(&row, "schemaname", "pg_publication_tables")
@@ -403,7 +413,7 @@ impl PgReplicationClient {
403413
);
404414

405415
let mut table_ids = vec![];
406-
for msg in self.client.simple_query(&publication_query).await? {
416+
for msg in self.simple_query(&publication_query).await? {
407417
if let SimpleQueryMessage::Row(row) = msg {
408418
// For the sake of simplicity, we refer to the table oid as table id.
409419
let table_id = Self::get_row_value::<TableId>(&row, "oid", "pg_class").await?;
@@ -441,7 +451,11 @@ impl PgReplicationClient {
441451
options
442452
);
443453

444-
let copy_stream = self.client.copy_both_simple::<bytes::Bytes>(&query).await?;
454+
let copy_stream = self
455+
.client
456+
.0
457+
.copy_both_simple::<bytes::Bytes>(&query)
458+
.await?;
445459
let stream = LogicalReplicationStream::new(copy_stream);
446460

447461
Ok(stream)
@@ -452,23 +466,22 @@ impl PgReplicationClient {
452466
/// The transaction doesn't make any assumptions about the snapshot in use, since this is a
453467
/// concern of the statements issued within the transaction.
454468
async fn begin_tx(&self) -> EtlResult<()> {
455-
self.client
456-
.simple_query("begin read only isolation level repeatable read;")
469+
self.simple_query("begin read only isolation level repeatable read;")
457470
.await?;
458471

459472
Ok(())
460473
}
461474

462475
/// Commits the current transaction.
463476
async fn commit_tx(&self) -> EtlResult<()> {
464-
self.client.simple_query("commit;").await?;
477+
self.simple_query("commit;").await?;
465478

466479
Ok(())
467480
}
468481

469482
/// Rolls back the current transaction.
470483
async fn rollback_tx(&self) -> EtlResult<()> {
471-
self.client.simple_query("rollback;").await?;
484+
self.simple_query("rollback;").await?;
472485

473486
Ok(())
474487
}
@@ -495,7 +508,7 @@ impl PgReplicationClient {
495508
quote_identifier(slot_name),
496509
snapshot_option
497510
);
498-
match self.client.simple_query(&query).await {
511+
match self.simple_query(&query).await {
499512
Ok(results) => {
500513
for result in results {
501514
if let SimpleQueryMessage::Row(row) = result {
@@ -595,7 +608,7 @@ impl PgReplicationClient {
595608
where c.oid = {table_id}",
596609
);
597610

598-
for message in self.client.simple_query(&table_info_query).await? {
611+
for message in self.simple_query(&table_info_query).await? {
599612
if let SimpleQueryMessage::Row(row) = message {
600613
let schema_name =
601614
Self::get_row_value::<String>(&row, "schema_name", "pg_namespace").await?;
@@ -626,7 +639,7 @@ impl PgReplicationClient {
626639
publication: Option<&str>,
627640
) -> EtlResult<Vec<ColumnSchema>> {
628641
let (pub_cte, pub_pred) = if let Some(publication) = publication {
629-
let is_pg14_or_earlier = self.is_postgres_14_or_earlier().await?;
642+
let is_pg14_or_earlier = self.get_server_version().await.unwrap_or(0) < 150000;
630643

631644
if !is_pg14_or_earlier {
632645
(
@@ -690,7 +703,7 @@ impl PgReplicationClient {
690703

691704
let mut column_schemas = vec![];
692705

693-
for message in self.client.simple_query(&column_info_query).await? {
706+
for message in self.simple_query(&column_info_query).await? {
694707
if let SimpleQueryMessage::Row(row) = message {
695708
let name = Self::get_row_value::<String>(&row, "attname", "pg_attribute").await?;
696709
let type_oid = Self::get_row_value::<u32>(&row, "atttypid", "pg_attribute").await?;
@@ -716,24 +729,36 @@ impl PgReplicationClient {
716729
Ok(column_schemas)
717730
}
718731

719-
async fn is_postgres_14_or_earlier(&self) -> EtlResult<bool> {
720-
let version_query = "SHOW server_version_num";
721-
722-
for message in self.client.simple_query(version_query).await? {
723-
if let SimpleQueryMessage::Row(row) = message {
724-
let version_str =
725-
Self::get_row_value::<String>(&row, "server_version_num", "server_version_num")
732+
/// Gets the PostgreSQL server version.
733+
///
734+
/// Returns the version in the format: MAJOR * 10000 + MINOR * 100 + PATCH
735+
/// For example: PostgreSQL 14.2 = 140200, PostgreSQL 15.1 = 150100
736+
async fn get_server_version(&self) -> EtlResult<i32> {
737+
let version = self
738+
.client
739+
.1
740+
.get_or_try_init(|| async {
741+
let version_query = "SHOW server_version_num";
742+
743+
for message in self.simple_query(version_query).await? {
744+
if let SimpleQueryMessage::Row(row) = message {
745+
let version_str = Self::get_row_value::<String>(
746+
&row,
747+
"server_version_num",
748+
"server_version_num",
749+
)
726750
.await?;
727-
let server_version: i32 = version_str.parse().unwrap_or(0);
751+
let version: i32 = version_str.parse().unwrap_or(0);
752+
return Ok::<_, EtlError>(version);
753+
}
754+
}
728755

729-
// PostgreSQL version format is typically: MAJOR * 10000 + MINOR * 100 + PATCH
730-
// For version 14.x.x, this would be 140000 + minor * 100 + patch
731-
// For version 15.x.x, this would be 150000 + minor * 100 + patch
732-
return Ok(server_version < 150000);
733-
}
734-
}
756+
// If we can't determine, return 0 (which will be treated as very old version)
757+
Ok(0)
758+
})
759+
.await?;
735760

736-
Ok(false)
761+
Ok(*version)
737762
}
738763

739764
/// Creates a COPY stream for reading data from a table using its OID.
@@ -759,7 +784,7 @@ impl PgReplicationClient {
759784
column_list
760785
);
761786

762-
let stream = self.client.copy_out_simple(&copy_query).await?;
787+
let stream = self.client.0.copy_out_simple(&copy_query).await?;
763788

764789
Ok(stream)
765790
}

0 commit comments

Comments
 (0)