Skip to content

Commit d1a89d3

Browse files
Extract postgres version from connection parameters
1 parent e65513b commit d1a89d3

File tree

1 file changed

+72
-57
lines changed

1 file changed

+72
-57
lines changed

etl/src/replication/client.rs

Lines changed: 72 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@ use rustls::ClientConfig;
1010
use std::collections::HashMap;
1111
use std::fmt;
1212
use std::io::BufReader;
13+
use std::num::NonZeroI32;
1314
use std::sync::Arc;
14-
use tokio::sync::OnceCell;
15+
16+
use tokio::io::{AsyncRead, AsyncWrite};
1517
use tokio_postgres::error::SqlState;
1618
use tokio_postgres::tls::MakeTlsConnect;
1719
use tokio_postgres::{
@@ -163,7 +165,8 @@ impl PgReplicationSlotTransaction {
163165
/// and streaming changes from the database.
164166
#[derive(Debug, Clone)]
165167
pub struct PgReplicationClient {
166-
client: Arc<(Client, OnceCell<i32>)>,
168+
client: Arc<Client>,
169+
server_version: Option<NonZeroI32>,
167170
}
168171

169172
impl PgReplicationClient {
@@ -178,14 +181,6 @@ impl PgReplicationClient {
178181
}
179182
}
180183

181-
// Convenience method to avoid having to access the client directly.
182-
async fn simple_query(
183-
&self,
184-
query: &str,
185-
) -> Result<Vec<SimpleQueryMessage>, tokio_postgres::Error> {
186-
self.client.0.simple_query(query).await
187-
}
188-
189184
/// Establishes a connection to Postgres without TLS encryption.
190185
///
191186
/// The connection is configured for logical replication mode.
@@ -194,12 +189,16 @@ impl PgReplicationClient {
194189
config.replication_mode(ReplicationMode::Logical);
195190

196191
let (client, connection) = config.connect(NoTls).await?;
192+
193+
let server_version = Self::extract_server_version(&connection);
194+
197195
spawn_postgres_connection::<NoTls>(connection);
198196

199197
info!("successfully connected to postgres without tls");
200198

201199
Ok(PgReplicationClient {
202-
client: Arc::new((client, OnceCell::new())),
200+
client: Arc::new(client),
201+
server_version,
203202
})
204203
}
205204

@@ -225,12 +224,16 @@ impl PgReplicationClient {
225224
.with_no_client_auth();
226225

227226
let (client, connection) = config.connect(MakeRustlsConnect::new(tls_config)).await?;
227+
228+
let server_version = Self::extract_server_version(&connection);
229+
228230
spawn_postgres_connection::<MakeRustlsConnect>(connection);
229231

230232
info!("successfully connected to postgres with tls");
231233

232234
Ok(PgReplicationClient {
233-
client: Arc::new((client, OnceCell::new())),
235+
client: Arc::new(client),
236+
server_version,
234237
})
235238
}
236239

@@ -262,7 +265,7 @@ impl PgReplicationClient {
262265
quote_literal(slot_name)
263266
);
264267

265-
let results = self.simple_query(&query).await?;
268+
let results = self.client.simple_query(&query).await?;
266269
for result in results {
267270
if let SimpleQueryMessage::Row(row) = result {
268271
let confirmed_flush_lsn = Self::get_row_value::<PgLsn>(
@@ -324,7 +327,7 @@ impl PgReplicationClient {
324327
quote_identifier(slot_name)
325328
);
326329

327-
match self.simple_query(&query).await {
330+
match self.client.simple_query(&query).await {
328331
Ok(_) => {
329332
info!("successfully deleted replication slot '{}'", slot_name);
330333

@@ -362,7 +365,7 @@ impl PgReplicationClient {
362365
"select 1 as exists from pg_publication where pubname = {};",
363366
quote_literal(publication)
364367
);
365-
for msg in self.simple_query(&publication_exists_query).await? {
368+
for msg in self.client.simple_query(&publication_exists_query).await? {
366369
if let SimpleQueryMessage::Row(_) = msg {
367370
return Ok(true);
368371
}
@@ -382,7 +385,7 @@ impl PgReplicationClient {
382385
);
383386

384387
let mut table_names = vec![];
385-
for msg in self.simple_query(&publication_query).await? {
388+
for msg in self.client.simple_query(&publication_query).await? {
386389
if let SimpleQueryMessage::Row(row) = msg {
387390
let schema =
388391
Self::get_row_value::<String>(&row, "schemaname", "pg_publication_tables")
@@ -412,7 +415,7 @@ impl PgReplicationClient {
412415
);
413416

414417
let mut table_ids = vec![];
415-
for msg in self.simple_query(&publication_query).await? {
418+
for msg in self.client.simple_query(&publication_query).await? {
416419
if let SimpleQueryMessage::Row(row) = msg {
417420
// For the sake of simplicity, we refer to the table oid as table id.
418421
let table_id = Self::get_row_value::<TableId>(&row, "oid", "pg_class").await?;
@@ -450,11 +453,7 @@ impl PgReplicationClient {
450453
options
451454
);
452455

453-
let copy_stream = self
454-
.client
455-
.0
456-
.copy_both_simple::<bytes::Bytes>(&query)
457-
.await?;
456+
let copy_stream = self.client.copy_both_simple::<bytes::Bytes>(&query).await?;
458457
let stream = LogicalReplicationStream::new(copy_stream);
459458

460459
Ok(stream)
@@ -465,22 +464,23 @@ impl PgReplicationClient {
465464
/// The transaction doesn't make any assumptions about the snapshot in use, since this is a
466465
/// concern of the statements issued within the transaction.
467466
async fn begin_tx(&self) -> EtlResult<()> {
468-
self.simple_query("begin read only isolation level repeatable read;")
467+
self.client
468+
.simple_query("begin read only isolation level repeatable read;")
469469
.await?;
470470

471471
Ok(())
472472
}
473473

474474
/// Commits the current transaction.
475475
async fn commit_tx(&self) -> EtlResult<()> {
476-
self.simple_query("commit;").await?;
476+
self.client.simple_query("commit;").await?;
477477

478478
Ok(())
479479
}
480480

481481
/// Rolls back the current transaction.
482482
async fn rollback_tx(&self) -> EtlResult<()> {
483-
self.simple_query("rollback;").await?;
483+
self.client.simple_query("rollback;").await?;
484484

485485
Ok(())
486486
}
@@ -507,7 +507,7 @@ impl PgReplicationClient {
507507
quote_identifier(slot_name),
508508
snapshot_option
509509
);
510-
match self.simple_query(&query).await {
510+
match self.client.simple_query(&query).await {
511511
Ok(results) => {
512512
for result in results {
513513
if let SimpleQueryMessage::Row(row) = result {
@@ -607,7 +607,7 @@ impl PgReplicationClient {
607607
where c.oid = {table_id}",
608608
);
609609

610-
for message in self.simple_query(&table_info_query).await? {
610+
for message in self.client.simple_query(&table_info_query).await? {
611611
if let SimpleQueryMessage::Row(row) = message {
612612
let schema_name =
613613
Self::get_row_value::<String>(&row, "schema_name", "pg_namespace").await?;
@@ -638,7 +638,12 @@ impl PgReplicationClient {
638638
publication: Option<&str>,
639639
) -> EtlResult<Vec<ColumnSchema>> {
640640
let (pub_cte, pub_pred) = if let Some(publication) = publication {
641-
let is_pg14_or_earlier = self.get_server_version().await.unwrap_or(0) < 150000;
641+
let is_pg14_or_earlier = if let Some(server_version) = self.server_version {
642+
server_version.get() < 150000
643+
} else {
644+
// be conservative by default
645+
true
646+
};
642647

643648
if !is_pg14_or_earlier {
644649
(
@@ -702,7 +707,7 @@ impl PgReplicationClient {
702707

703708
let mut column_schemas = vec![];
704709

705-
for message in self.simple_query(&column_info_query).await? {
710+
for message in self.client.simple_query(&column_info_query).await? {
706711
if let SimpleQueryMessage::Row(row) = message {
707712
let name = Self::get_row_value::<String>(&row, "attname", "pg_attribute").await?;
708713
let type_oid = Self::get_row_value::<u32>(&row, "atttypid", "pg_attribute").await?;
@@ -728,36 +733,46 @@ impl PgReplicationClient {
728733
Ok(column_schemas)
729734
}
730735

731-
/// Gets the PostgreSQL server version.
736+
/// Extracts the PostgreSQL server version from connection parameters.
737+
///
738+
/// This method should be called during connection establishment to extract
739+
/// the server version from the parameter status messages sent by the server.
732740
///
733741
/// Returns the version in the format: MAJOR * 10000 + MINOR * 100 + PATCH
742+
/// This matches the format used by `SELECT version()`.
734743
/// For example: PostgreSQL 14.2 = 140200, PostgreSQL 15.1 = 150100
735-
async fn get_server_version(&self) -> EtlResult<i32> {
736-
let version = self
737-
.client
738-
.1
739-
.get_or_try_init(|| async {
740-
let version_query = "SHOW server_version_num";
741-
742-
for message in self.simple_query(version_query).await? {
743-
if let SimpleQueryMessage::Row(row) = message {
744-
let version_str = Self::get_row_value::<String>(
745-
&row,
746-
"server_version_num",
747-
"server_version_num",
748-
)
749-
.await?;
750-
let version: i32 = version_str.parse().unwrap_or(0);
751-
return Ok::<_, EtlError>(version);
752-
}
753-
}
754-
755-
// If we can't determine, return 0 (which will be treated as very old version)
756-
Ok(0)
757-
})
758-
.await?;
759-
760-
Ok(*version)
744+
fn extract_server_version<S>(connection: &Connection<Socket, S>) -> Option<NonZeroI32>
745+
where
746+
S: AsyncRead + AsyncWrite + Unpin + Send,
747+
{
748+
if let Some(server_version_str) = connection.parameter("server_version") {
749+
// Parse version string like "15.5 (Homebrew)" or "14.2"
750+
let version_part = server_version_str
751+
.split_whitespace()
752+
.next()
753+
.unwrap_or("0.0");
754+
755+
let version_components: Vec<&str> = version_part.split('.').collect();
756+
757+
let major = version_components
758+
.get(0)
759+
.and_then(|v| v.parse::<i32>().ok())
760+
.unwrap_or(0);
761+
let minor = version_components
762+
.get(1)
763+
.and_then(|v| v.parse::<i32>().ok())
764+
.unwrap_or(0);
765+
let patch = version_components
766+
.get(2)
767+
.and_then(|v| v.parse::<i32>().ok())
768+
.unwrap_or(0);
769+
770+
let version = major * 10000 + minor * 100 + patch;
771+
772+
NonZeroI32::new(version)
773+
} else {
774+
None
775+
}
761776
}
762777

763778
/// Creates a COPY stream for reading data from a table using its OID.
@@ -783,7 +798,7 @@ impl PgReplicationClient {
783798
column_list
784799
);
785800

786-
let stream = self.client.0.copy_out_simple(&copy_query).await?;
801+
let stream = self.client.copy_out_simple(&copy_query).await?;
787802

788803
Ok(stream)
789804
}

0 commit comments

Comments
 (0)