Skip to content

Commit 2089cc0

Browse files
Use recursive query to get oids and write test
Signed-off-by: Abhi Agarwal <[email protected]>
1 parent 3524085 commit 2089cc0

File tree

2 files changed

+81
-36
lines changed

2 files changed

+81
-36
lines changed

etl/src/replication/client.rs

Lines changed: 40 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -410,47 +410,53 @@ impl PgReplicationClient {
410410
&self,
411411
publication_name: &str,
412412
) -> EtlResult<Vec<TableId>> {
413-
// Prefer pg_publication_rel (explicit tables in the publication, including partition roots)
414-
let rel_query = format!(
415-
r#"select r.prrelid as oid
416-
from pg_publication_rel r
417-
join pg_publication p on p.oid = r.prpubid
418-
where p.pubname = {}"#,
419-
quote_literal(publication_name)
420-
);
421-
422-
let mut table_ids = vec![];
423-
let mut has_rows = false;
424-
for msg in self.client.simple_query(&rel_query).await? {
425-
if let SimpleQueryMessage::Row(row) = msg {
426-
has_rows = true;
427-
let table_id =
428-
Self::get_row_value::<TableId>(&row, "oid", "pg_publication_rel").await?;
429-
table_ids.push(table_id);
430-
}
431-
}
432-
433-
if has_rows {
434-
return Ok(table_ids);
435-
}
436-
437-
// Fallback to pg_publication_tables (expanded view), used for publications like FOR ALL TABLES
438-
let publication_query = format!(
439-
"select c.oid from pg_publication_tables pt
440-
join pg_class c on c.relname = pt.tablename
441-
join pg_namespace n on n.oid = c.relnamespace AND n.nspname = pt.schemaname
442-
where pt.pubname = {};",
443-
quote_literal(publication_name)
413+
let query = format!(
414+
r#"
415+
with recursive has_rel as (
416+
select exists(
417+
select 1
418+
from pg_publication_rel r
419+
join pg_publication p on p.oid = r.prpubid
420+
where p.pubname = {pub}
421+
) as has
422+
),
423+
pub_tables as (
424+
select r.prrelid as oid
425+
from pg_publication_rel r
426+
join pg_publication p on p.oid = r.prpubid
427+
where p.pubname = {pub} and (select has from has_rel)
428+
union all
429+
select c.oid
430+
from pg_publication_tables pt
431+
join pg_class c on c.relname = pt.tablename
432+
join pg_namespace n on n.oid = c.relnamespace and n.nspname = pt.schemaname
433+
where pt.pubname = {pub} and not (select has from has_rel)
434+
),
435+
recurse(relid) as (
436+
select oid from pub_tables
437+
union all
438+
select i.inhparent
439+
from pg_inherits i
440+
join recurse r on r.relid = i.inhrelid
441+
)
442+
select distinct relid as oid
443+
from recurse r
444+
where not exists (
445+
select 1 from pg_inherits i where i.inhrelid = r.relid
446+
);
447+
"#,
448+
pub = quote_literal(publication_name)
444449
);
445450

446-
for msg in self.client.simple_query(&publication_query).await? {
451+
let mut roots = vec![];
452+
for msg in self.client.simple_query(&query).await? {
447453
if let SimpleQueryMessage::Row(row) = msg {
448454
let table_id = Self::get_row_value::<TableId>(&row, "oid", "pg_class").await?;
449-
table_ids.push(table_id);
455+
roots.push(table_id);
450456
}
451457
}
452458

453-
Ok(table_ids)
459+
Ok(roots)
454460
}
455461

456462
/// Starts a logical replication stream from the specified publication and slot.

etl/tests/replication.rs

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
#![cfg(feature = "test-utils")]
22

3+
use std::collections::HashSet;
4+
35
use etl::error::ErrorKind;
46
use etl::replication::client::PgReplicationClient;
57
use etl::test_utils::database::{spawn_source_database, test_table_name};
68
use etl::test_utils::pipeline::test_slot_name;
79
use etl::test_utils::table::assert_table_schema;
10+
use etl::test_utils::test_schema::create_partitioned_table;
811
use etl_postgres::tokio::test_utils::{TableModification, id_column_schema};
912
use etl_postgres::types::ColumnSchema;
1013
use etl_telemetry::tracing::init_test_tracing;
@@ -405,11 +408,47 @@ async fn test_publication_creation_and_check() {
405408
);
406409

407410
// We check the table ids of the tables in the publication.
408-
let table_ids = parent_client
411+
let table_ids: HashSet<_> = parent_client
409412
.get_publication_table_ids("my_publication")
413+
.await
414+
.unwrap()
415+
.into_iter()
416+
.collect();
417+
assert_eq!(table_ids, HashSet::from([table_1_id, table_2_id]));
418+
}
419+
420+
#[tokio::test(flavor = "multi_thread")]
421+
async fn test_publication_table_ids_collapse_partitioned_root() {
422+
init_test_tracing();
423+
let database = spawn_source_database().await;
424+
425+
let client = PgReplicationClient::connect(database.config.clone())
410426
.await
411427
.unwrap();
412-
assert_eq!(table_ids, vec![table_1_id, table_2_id]);
428+
429+
// We create a partitioned parent with two child partitions.
430+
let table_name = test_table_name("part_parent");
431+
let (parent_table_id, _children) = create_partitioned_table(
432+
&database,
433+
table_name.clone(),
434+
&[("p1", "from (1) to (100)"), ("p2", "from (100) to (200)")],
435+
)
436+
.await
437+
.unwrap();
438+
439+
let publication_name = "pub_part_root";
440+
database
441+
.create_publication(publication_name, std::slice::from_ref(&table_name))
442+
.await
443+
.unwrap();
444+
445+
let id = client
446+
.get_publication_table_ids(publication_name)
447+
.await
448+
.unwrap();
449+
450+
// We expect to get only the parent table id.
451+
assert_eq!(id, vec![parent_table_id]);
413452
}
414453

415454
#[tokio::test(flavor = "multi_thread")]

0 commit comments

Comments
 (0)