Skip to content

Commit 49b53b6

Browse files
Write unit tests, disable some tests if PG<=15,
1 parent 1ca16cd commit 49b53b6

File tree

4 files changed

+194
-65
lines changed

4 files changed

+194
-65
lines changed

etl-postgres/src/replication/db.rs

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::num::NonZeroI32;
2+
13
use etl_config::shared::{IntoConnectOptions, PgConnectionConfig};
24
use sqlx::{PgPool, Row, postgres::PgPoolOptions};
35
use thiserror::Error;
@@ -68,3 +70,107 @@ pub async fn get_table_name_from_oid(
6870
None => Err(TableLookupError::TableNotFound(table_id)),
6971
}
7072
}
73+
74+
/// Extracts the PostgreSQL server version from a version string.
75+
///
76+
/// This function parses version strings like "15.5 (Homebrew)" or "14.2"
77+
/// and converts them to the numeric format used by PostgreSQL.
78+
///
79+
/// Returns the version in the format: MAJOR * 10000 + MINOR * 100 + PATCH
80+
/// For example: PostgreSQL 14.2 = 140200, PostgreSQL 15.1 = 150100
81+
///
82+
/// Returns `None` if the version string cannot be parsed or results in zero.
83+
pub fn extract_server_version(server_version_str: impl AsRef<str>) -> Option<NonZeroI32> {
84+
// Parse version string like "15.5 (Homebrew)" or "14.2"
85+
let version_part = server_version_str
86+
.as_ref()
87+
.split_whitespace()
88+
.next()
89+
.unwrap_or("0.0");
90+
91+
let version_components: Vec<&str> = version_part.split('.').collect();
92+
93+
let major = version_components
94+
.first()
95+
.and_then(|v| v.parse::<i32>().ok())
96+
.unwrap_or(0);
97+
let minor = version_components
98+
.get(1)
99+
.and_then(|v| v.parse::<i32>().ok())
100+
.unwrap_or(0);
101+
let patch = version_components
102+
.get(2)
103+
.and_then(|v| v.parse::<i32>().ok())
104+
.unwrap_or(0);
105+
106+
let version = major * 10000 + minor * 100 + patch;
107+
108+
NonZeroI32::new(version)
109+
}
110+
111+
#[cfg(test)]
112+
mod tests {
113+
use super::*;
114+
115+
#[test]
116+
fn test_extract_server_version_basic_versions() {
117+
assert_eq!(extract_server_version("15.5"), NonZeroI32::new(150500));
118+
assert_eq!(extract_server_version("14.2"), NonZeroI32::new(140200));
119+
assert_eq!(extract_server_version("13.0"), NonZeroI32::new(130000));
120+
assert_eq!(extract_server_version("16.1"), NonZeroI32::new(160100));
121+
}
122+
123+
#[test]
124+
fn test_extract_server_version_with_suffixes() {
125+
assert_eq!(
126+
extract_server_version("15.5 (Homebrew)"),
127+
NonZeroI32::new(150500)
128+
);
129+
assert_eq!(
130+
extract_server_version("14.2 on x86_64-pc-linux-gnu"),
131+
NonZeroI32::new(140200)
132+
);
133+
assert_eq!(
134+
extract_server_version("13.7 (Ubuntu 13.7-1.pgdg20.04+1)"),
135+
NonZeroI32::new(130700)
136+
);
137+
assert_eq!(
138+
extract_server_version("16.0 devel"),
139+
NonZeroI32::new(160000)
140+
);
141+
}
142+
143+
#[test]
144+
fn test_extract_server_version_patch_versions() {
145+
// Test versions with patch numbers
146+
assert_eq!(extract_server_version("15.5.1"), NonZeroI32::new(150501));
147+
assert_eq!(extract_server_version("14.10.3"), NonZeroI32::new(141003));
148+
assert_eq!(extract_server_version("13.12.25"), NonZeroI32::new(131225));
149+
}
150+
151+
#[test]
152+
fn test_extract_server_version_invalid_inputs() {
153+
// Test invalid inputs that should return None
154+
assert_eq!(extract_server_version(""), None);
155+
assert_eq!(extract_server_version("invalid"), None);
156+
assert_eq!(extract_server_version("not.a.version"), None);
157+
assert_eq!(extract_server_version("PostgreSQL"), None);
158+
assert_eq!(extract_server_version(" "), None);
159+
}
160+
161+
#[test]
162+
fn test_extract_server_version_zero_versions() {
163+
assert_eq!(extract_server_version("0.0.0"), None);
164+
assert_eq!(extract_server_version("0.0"), None);
165+
}
166+
167+
#[test]
168+
fn test_extract_server_version_whitespace_handling() {
169+
assert_eq!(extract_server_version(" 15.5 "), NonZeroI32::new(150500));
170+
assert_eq!(
171+
extract_server_version("15.5\t(Homebrew)"),
172+
NonZeroI32::new(150500)
173+
);
174+
assert_eq!(extract_server_version("15.5\n"), NonZeroI32::new(150500));
175+
}
176+
}

etl-postgres/src/tokio/test_utils.rs

Lines changed: 67 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1+
use std::num::NonZeroI32;
2+
13
use etl_config::shared::{IntoConnectOptions, PgConnectionConfig};
24
use tokio::runtime::Handle;
35
use tokio_postgres::types::{ToSql, Type};
46
use tokio_postgres::{Client, GenericClient, NoTls, Transaction};
57
use tracing::info;
68

9+
use crate::replication::extract_server_version;
710
use crate::types::{ColumnSchema, TableId, TableName};
811

912
/// Table modification operations for ALTER TABLE statements.
@@ -34,10 +37,15 @@ pub enum TableModification<'a> {
3437
pub struct PgDatabase<G> {
3538
pub config: PgConnectionConfig,
3639
pub client: Option<G>,
40+
server_version: Option<NonZeroI32>,
3741
destroy_on_drop: bool,
3842
}
3943

4044
impl<G: GenericClient> PgDatabase<G> {
45+
pub fn server_version(&self) -> Option<NonZeroI32> {
46+
self.server_version
47+
}
48+
4149
/// Creates a Postgres publication for the specified tables.
4250
///
4351
/// Sets up logical replication by creating a publication that includes
@@ -71,19 +79,51 @@ impl<G: GenericClient> PgDatabase<G> {
7179
publication_name: &str,
7280
schema: Option<&str>,
7381
) -> Result<(), tokio_postgres::Error> {
74-
let create_publication_query = match schema {
75-
Some(schema_name) => format!(
76-
"create publication {} for tables in schema {}",
77-
publication_name, schema_name
78-
),
79-
None => format!("create publication {} for all tables", publication_name),
80-
};
81-
82-
self.client
83-
.as_ref()
84-
.unwrap()
85-
.execute(&create_publication_query, &[])
86-
.await?;
82+
let client = self.client.as_ref().unwrap();
83+
84+
if let Some(server_version) = self.server_version
85+
&& server_version.get() >= 150000
86+
{
87+
// PostgreSQL 15+ supports FOR ALL TABLES IN SCHEMA syntax
88+
let create_publication_query = match schema {
89+
Some(schema_name) => format!(
90+
"create publication {} for tables in schema {}",
91+
publication_name, schema_name
92+
),
93+
None => format!("create publication {} for all tables", publication_name),
94+
};
95+
96+
client.execute(&create_publication_query, &[]).await?;
97+
} else {
98+
// PostgreSQL 14 and earlier: create publication and add tables individually
99+
match schema {
100+
Some(schema_name) => {
101+
let create_pub_query = format!("create publication {}", publication_name);
102+
client.execute(&create_pub_query, &[]).await?;
103+
104+
let tables_query = format!(
105+
"select schemaname, tablename from pg_tables where schemaname = '{}'",
106+
schema_name
107+
);
108+
let rows = client.query(&tables_query, &[]).await?;
109+
110+
for row in rows {
111+
let schema: String = row.get(0);
112+
let table: String = row.get(1);
113+
let add_table_query = format!(
114+
"alter publication {} add table {}.{}",
115+
publication_name, schema, table
116+
);
117+
client.execute(&add_table_query, &[]).await?;
118+
}
119+
}
120+
None => {
121+
let create_publication_query =
122+
format!("create publication {} for all tables", publication_name);
123+
client.execute(&create_publication_query, &[]).await?;
124+
}
125+
}
126+
}
87127

88128
Ok(())
89129
}
@@ -369,7 +409,8 @@ impl PgDatabase<Client> {
369409

370410
Self {
371411
config,
372-
client: Some(client),
412+
client: Some(client.0),
413+
server_version: client.1,
373414
destroy_on_drop: true,
374415
}
375416
}
@@ -386,7 +427,8 @@ impl PgDatabase<Client> {
386427

387428
Self {
388429
config,
389-
client: Some(client),
430+
client: Some(client.0),
431+
server_version: client.1,
390432
destroy_on_drop: true,
391433
}
392434
}
@@ -401,6 +443,7 @@ impl PgDatabase<Client> {
401443
PgDatabase {
402444
config: self.config.clone(),
403445
client: Some(transaction),
446+
server_version: self.server_version,
404447
destroy_on_drop: false,
405448
}
406449
}
@@ -450,7 +493,7 @@ pub fn id_column_schema() -> ColumnSchema {
450493
///
451494
/// # Panics
452495
/// Panics if connection or database creation fails.
453-
pub async fn create_pg_database(config: &PgConnectionConfig) -> Client {
496+
pub async fn create_pg_database(config: &PgConnectionConfig) -> (Client, Option<NonZeroI32>) {
454497
// Create the database via a single connection
455498
let (client, connection) = {
456499
let config: tokio_postgres::Config = config.without_db();
@@ -474,14 +517,16 @@ pub async fn create_pg_database(config: &PgConnectionConfig) -> Client {
474517
.expect("Failed to create database");
475518

476519
// Connects to the actual Postgres database
477-
connect_to_pg_database(config).await
520+
let (client, server_version) = connect_to_pg_database(config).await;
521+
522+
(client, server_version)
478523
}
479524

480525
/// Connects to an existing Postgres database.
481526
///
482527
/// Establishes a client connection to the database specified in the configuration.
483528
/// Assumes the database already exists.
484-
pub async fn connect_to_pg_database(config: &PgConnectionConfig) -> Client {
529+
pub async fn connect_to_pg_database(config: &PgConnectionConfig) -> (Client, Option<NonZeroI32>) {
485530
// Create a new client connected to the created database
486531
let (client, connection) = {
487532
let config: tokio_postgres::Config = config.with_db();
@@ -490,6 +535,9 @@ pub async fn connect_to_pg_database(config: &PgConnectionConfig) -> Client {
490535
.await
491536
.expect("Failed to connect to Postgres")
492537
};
538+
let server_version = connection
539+
.parameter("server_version")
540+
.and_then(extract_server_version);
493541

494542
// Spawn the connection on a new task
495543
tokio::spawn(async move {
@@ -498,7 +546,7 @@ pub async fn connect_to_pg_database(config: &PgConnectionConfig) -> Client {
498546
}
499547
});
500548

501-
client
549+
(client, server_version)
502550
}
503551

504552
/// Drops a Postgres database and cleans up all resources.

etl/src/replication/client.rs

Lines changed: 7 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use crate::error::{ErrorKind, EtlError, EtlResult};
22
use crate::utils::tokio::MakeRustlsConnect;
33
use crate::{bail, etl_error};
44
use etl_config::shared::{IntoConnectOptions, PgConnectionConfig};
5+
use etl_postgres::replication::extract_server_version;
56
use etl_postgres::types::convert_type_oid_to_type;
67
use etl_postgres::types::{ColumnSchema, TableId, TableName, TableSchema};
78
use pg_escape::{quote_identifier, quote_literal};
@@ -13,7 +14,6 @@ use std::io::BufReader;
1314
use std::num::NonZeroI32;
1415
use std::sync::Arc;
1516

16-
use tokio::io::{AsyncRead, AsyncWrite};
1717
use tokio_postgres::error::SqlState;
1818
use tokio_postgres::tls::MakeTlsConnect;
1919
use tokio_postgres::{
@@ -190,7 +190,9 @@ impl PgReplicationClient {
190190

191191
let (client, connection) = config.connect(NoTls).await?;
192192

193-
let server_version = Self::extract_server_version(&connection);
193+
let server_version = connection
194+
.parameter("server_version")
195+
.and_then(extract_server_version);
194196

195197
spawn_postgres_connection::<NoTls>(connection);
196198

@@ -225,7 +227,9 @@ impl PgReplicationClient {
225227

226228
let (client, connection) = config.connect(MakeRustlsConnect::new(tls_config)).await?;
227229

228-
let server_version = Self::extract_server_version(&connection);
230+
let server_version = connection
231+
.parameter("server_version")
232+
.and_then(extract_server_version);
229233

230234
spawn_postgres_connection::<MakeRustlsConnect>(connection);
231235

@@ -732,49 +736,6 @@ impl PgReplicationClient {
732736

733737
Ok(column_schemas)
734738
}
735-
736-
/// Extracts the PostgreSQL server version from connection parameters.
737-
///
738-
/// This method should be called during connection establishment to extract
739-
/// the server version from the parameter status messages sent by the server.
740-
///
741-
/// Returns the version in the format: MAJOR * 10000 + MINOR * 100 + PATCH
742-
/// This matches the format used by `SELECT version()`.
743-
/// For example: PostgreSQL 14.2 = 140200, PostgreSQL 15.1 = 150100
744-
fn extract_server_version<S>(connection: &Connection<Socket, S>) -> Option<NonZeroI32>
745-
where
746-
S: AsyncRead + AsyncWrite + Unpin + Send,
747-
{
748-
if let Some(server_version_str) = connection.parameter("server_version") {
749-
// Parse version string like "15.5 (Homebrew)" or "14.2"
750-
let version_part = server_version_str
751-
.split_whitespace()
752-
.next()
753-
.unwrap_or("0.0");
754-
755-
let version_components: Vec<&str> = version_part.split('.').collect();
756-
757-
let major = version_components
758-
.first()
759-
.and_then(|v| v.parse::<i32>().ok())
760-
.unwrap_or(0);
761-
let minor = version_components
762-
.get(1)
763-
.and_then(|v| v.parse::<i32>().ok())
764-
.unwrap_or(0);
765-
let patch = version_components
766-
.get(2)
767-
.and_then(|v| v.parse::<i32>().ok())
768-
.unwrap_or(0);
769-
770-
let version = major * 10000 + minor * 100 + patch;
771-
772-
NonZeroI32::new(version)
773-
} else {
774-
None
775-
}
776-
}
777-
778739
/// Creates a COPY stream for reading data from a table using its OID.
779740
///
780741
/// The stream will include only the specified columns and use text format.

etl/tests/integration/pipeline_test.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,13 @@ async fn publication_changes_are_correctly_handled() {
190190
table_1_done.notified().await;
191191
table_2_done.notified().await;
192192

193+
if let Some(server_version) = database.server_version()
194+
&& server_version.get() <= 150000
195+
{
196+
println!("Skipping test for PostgreSQL version <= 15");
197+
return;
198+
}
199+
193200
// Insert one row in each table and wait for two insert events.
194201
let inserts_notify = destination
195202
.wait_for_events_count(vec![(EventType::Insert, 2)])
@@ -353,6 +360,13 @@ async fn publication_for_all_tables_in_schema_ignores_new_tables_until_restart()
353360
assert!(table_schemas.contains_key(&table_1_id));
354361
assert!(!table_schemas.contains_key(&table_2_id));
355362

363+
if let Some(server_version) = database.server_version()
364+
&& server_version.get() <= 150000
365+
{
366+
println!("Skipping test for PostgreSQL version <= 15");
367+
return;
368+
}
369+
356370
// We restart the pipeline and verify that the new table is now processed.
357371
let mut pipeline = create_pipeline(
358372
&database.config,

0 commit comments

Comments
 (0)