Skip to content

Commit 5a2d233

Browse files
committed
feat: make orchestrator clone-safe between threads
1 parent 5e32525 commit 5a2d233

File tree

3 files changed

+60
-21
lines changed

3 files changed

+60
-21
lines changed

aggregation_mode/db/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
pub mod orchestrator;
2-
mod retry;
2+
pub mod retry;
33
pub mod types;

aggregation_mode/db/src/orchestrator.rs

Lines changed: 57 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,11 @@
1-
use std::{future::Future, time::Duration};
1+
use std::{
2+
future::Future,
3+
sync::{
4+
atomic::{AtomicBool, Ordering},
5+
Arc,
6+
},
7+
time::Duration,
8+
};
29

310
use sqlx::{postgres::PgPoolOptions, Pool, Postgres};
411

@@ -10,27 +17,56 @@ enum Operation {
1017
Write,
1118
}
1219

20+
/// A single DB node: connection pool plus shared health flags (used to prioritize nodes).
21+
1322
#[derive(Debug)]
1423
struct DbNode {
1524
pool: Pool<Postgres>,
16-
last_read_failed: bool,
17-
last_write_failed: bool,
25+
last_read_failed: AtomicBool,
26+
last_write_failed: AtomicBool,
1827
}
1928

20-
#[derive(Debug)]
29+
/// Database orchestrator for running reads/writes across multiple PostgreSQL nodes with retry/backoff.
30+
///
31+
/// `DbOrchestartor` holds a list of database nodes (connection pools) and will:
32+
/// - try nodes in a preferred order (healthy nodes first, then recently-failed nodes),
33+
/// - mark nodes as failed on connection-type errors,
34+
/// - retry transient failures with exponential backoff based on `retry_config`.
35+
///
36+
/// ## Thread-safe `Clone`
37+
/// This type is cheap and thread-safe to clone:
38+
/// - `nodes` is `Vec<Arc<DbNode>>`, so cloning only increments `Arc` ref-counts and shares the same pools/nodes,
39+
/// - `sqlx::Pool<Postgres>` is internally reference-counted and designed to be cloned and used concurrently,
40+
/// - the node health flags are `AtomicBool`, so updates are safe from multiple threads/tasks.
41+
///
42+
/// Clones share health state (the atomics) and the underlying pools, so all clones observe and influence
43+
/// the same “preferred node” ordering decisions.
44+
#[derive(Debug, Clone)]
2145
pub struct DbOrchestartor {
22-
nodes: Vec<DbNode>,
46+
nodes: Vec<Arc<DbNode>>,
2347
retry_config: RetryConfig,
2448
}
2549

50+
#[derive(Debug)]
2651
pub enum DbOrchestartorError {
2752
InvalidNumberOfConnectionUrls,
2853
Sqlx(sqlx::Error),
2954
}
3055

56+
impl std::fmt::Display for DbOrchestartorError {
57+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58+
match self {
59+
Self::InvalidNumberOfConnectionUrls => {
60+
write!(f, "invalid number of connection URLs")
61+
}
62+
Self::Sqlx(e) => write!(f, "{e}"),
63+
}
64+
}
65+
}
66+
3167
impl DbOrchestartor {
3268
pub fn try_new(
33-
connection_urls: Vec<String>,
69+
connection_urls: &[&str],
3470
retry_config: RetryConfig,
3571
) -> Result<Self, DbOrchestartorError> {
3672
if connection_urls.is_empty() {
@@ -40,13 +76,13 @@ impl DbOrchestartor {
4076
let nodes = connection_urls
4177
.into_iter()
4278
.map(|url| {
43-
let pool = PgPoolOptions::new().max_connections(5).connect_lazy(&url)?;
79+
let pool = PgPoolOptions::new().max_connections(5).connect_lazy(url)?;
4480

45-
Ok(DbNode {
81+
Ok(Arc::new(DbNode {
4682
pool,
47-
last_read_failed: false,
48-
last_write_failed: false,
49-
})
83+
last_read_failed: AtomicBool::new(false),
84+
last_write_failed: AtomicBool::new(false),
85+
}))
5086
})
5187
.collect::<Result<Vec<_>, sqlx::Error>>()
5288
.map_err(|e| DbOrchestartorError::Sqlx(e))?;
@@ -126,22 +162,25 @@ impl DbOrchestartor {
126162
let mut last_error = None;
127163

128164
for idx in self.preferred_order(operation) {
129-
let pool = self.nodes[idx].pool.clone();
165+
let node = &self.nodes[idx];
166+
let pool = node.pool.clone();
130167

131168
match query_fn(pool).await {
132169
Ok(res) => {
133170
match operation {
134-
Operation::Read => self.nodes[idx].last_read_failed = false,
135-
Operation::Write => self.nodes[idx].last_write_failed = false,
171+
Operation::Read => node.last_read_failed.store(false, Ordering::Relaxed),
172+
Operation::Write => node.last_write_failed.store(false, Ordering::Relaxed),
136173
};
137174
return Ok(res);
138175
}
139176
Err(err) => {
140177
if Self::is_connection_error(&err) {
141178
tracing::warn!(node_index = idx, error = ?err, "database query failed");
142179
match operation {
143-
Operation::Read => self.nodes[idx].last_read_failed = true,
144-
Operation::Write => self.nodes[idx].last_write_failed = true,
180+
Operation::Read => node.last_read_failed.store(true, Ordering::Relaxed),
181+
Operation::Write => {
182+
node.last_write_failed.store(true, Ordering::Relaxed)
183+
}
145184
};
146185
last_error = Some(err);
147186
} else {
@@ -162,8 +201,8 @@ impl DbOrchestartor {
162201

163202
for (idx, node) in self.nodes.iter().enumerate() {
164203
let failed = match operation {
165-
Operation::Read => node.last_read_failed,
166-
Operation::Write => node.last_write_failed,
204+
Operation::Read => node.last_read_failed.load(Ordering::Relaxed),
205+
Operation::Write => node.last_write_failed.load(Ordering::Relaxed),
167206
};
168207

169208
if failed {

aggregation_mode/db/src/retry.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#[derive(Debug)]
2-
pub enum RetryError<E> {
2+
pub(super) enum RetryError<E> {
33
Transient(E),
44
Permanent(E),
55
}
@@ -15,7 +15,7 @@ impl<E: std::fmt::Display> std::fmt::Display for RetryError<E> {
1515

1616
impl<E: std::fmt::Display> std::error::Error for RetryError<E> where E: std::fmt::Debug {}
1717

18-
#[derive(Debug)]
18+
#[derive(Debug, Clone)]
1919
pub struct RetryConfig {
2020
/// * `min_delay_millis` - Initial delay before first retry attempt (in milliseconds)
2121
pub min_delay_millis: u64,

0 commit comments

Comments
 (0)