Skip to content

Commit a509b73

Browse files
committed
chore(ups): add postgres auto-reconnect
1 parent 5e0abe2 commit a509b73

File tree

5 files changed

+430
-47
lines changed

5 files changed

+430
-47
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

packages/common/test-deps-docker/src/lib.rs

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,42 @@ impl DockerRunConfig {
7373
Ok(true)
7474
}
7575

76+
pub async fn restart(&self) -> Result<()> {
77+
let container_id = self
78+
.container_id
79+
.as_ref()
80+
.ok_or_else(|| anyhow!("No container ID found, container not started"))?;
81+
82+
tracing::debug!(
83+
container_name = %self.container_name,
84+
container_id = %container_id,
85+
"restarting docker container"
86+
);
87+
88+
let output = Command::new("docker")
89+
.arg("restart")
90+
.arg(container_id)
91+
.output()
92+
.await?;
93+
94+
if !output.status.success() {
95+
let stderr = String::from_utf8_lossy(&output.stderr);
96+
anyhow::bail!(
97+
"Failed to restart container {}: {}",
98+
self.container_name,
99+
stderr
100+
);
101+
}
102+
103+
tracing::debug!(
104+
container_name = %self.container_name,
105+
container_id = %container_id,
106+
"container restarted successfully"
107+
);
108+
109+
Ok(())
110+
}
111+
76112
/// Returns the stored container ID as a string slice, or `None` when no ID is stored.
pub fn container_id(&self) -> Option<&str> {
77113
self.container_id.as_deref()
78114
}

packages/common/universalpubsub/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ deadpool-postgres.workspace = true
1414
futures-util.workspace = true
1515
rivet-error.workspace = true
1616
rivet-ups-protocol.workspace = true
17+
rivet-util.workspace = true
1718
serde_json.workspace = true
1819
versioned-data-util.workspace = true
1920
serde.workspace = true

packages/common/universalpubsub/src/driver/postgres/mod.rs

Lines changed: 175 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
use std::collections::HashMap;
22
use std::hash::{DefaultHasher, Hash, Hasher};
3-
use std::sync::{Arc, Mutex};
3+
use std::sync::Arc;
44

55
use anyhow::*;
66
use async_trait::async_trait;
77
use base64::Engine;
88
use base64::engine::general_purpose::STANDARD_NO_PAD as BASE64;
99
use deadpool_postgres::{Config, ManagerConfig, Pool, PoolConfig, RecyclingMethod, Runtime};
1010
use futures_util::future::poll_fn;
11+
use rivet_util::backoff::Backoff;
12+
use tokio::sync::{Mutex, broadcast};
1113
use tokio_postgres::{AsyncMessage, NoTls};
1214
use tracing::Instrument;
1315

@@ -17,13 +19,13 @@ use crate::pubsub::DriverOutput;
1719
/// State kept for one subscribed channel: the broadcast sender that decoded notification
/// payloads are forwarded on, and a cancellation token shared by that channel's subscribers.
#[derive(Clone)]
1820
struct Subscription {
1921
// Channel to send messages to this subscription
20-
tx: tokio::sync::broadcast::Sender<Vec<u8>>,
22+
tx: broadcast::Sender<Vec<u8>>,
2123
// Cancellation token shared by all subscribers of this subject
2224
token: tokio_util::sync::CancellationToken,
2325
}
2426

2527
impl Subscription {
26-
fn new(tx: tokio::sync::broadcast::Sender<Vec<u8>>) -> Self {
28+
fn new(tx: broadcast::Sender<Vec<u8>>) -> Self {
2729
let token = tokio_util::sync::CancellationToken::new();
2830
Self { tx, token }
2931
}
@@ -48,8 +50,9 @@ pub const POSTGRES_MAX_MESSAGE_SIZE: usize =
4850
/// Postgres-backed pub/sub driver handle.
///
/// - `pool`: shared deadpool connection pool (how it is used is outside this view).
/// - `client`: the current LISTEN client; the connection lifecycle task stores `Some` once
///   a connection is established and `None` after that connection is lost.
/// - `subscriptions`: active subscriptions keyed by the hashed subject that is used as the
///   Postgres channel name.
/// - `client_ready`: watch receiver that the lifecycle task sets to `true` on connect and
///   `false` on disconnect.
#[derive(Clone)]
4951
pub struct PostgresDriver {
5052
pool: Arc<Pool>,
51-
client: Arc<tokio_postgres::Client>,
53+
client: Arc<Mutex<Option<Arc<tokio_postgres::Client>>>>,
5254
subscriptions: Arc<Mutex<HashMap<String, Subscription>>>,
55+
client_ready: tokio::sync::watch::Receiver<bool>,
5356
}
5457

5558
impl PostgresDriver {
@@ -76,48 +79,168 @@ impl PostgresDriver {
7679

7780
let subscriptions: Arc<Mutex<HashMap<String, Subscription>>> =
7881
Arc::new(Mutex::new(HashMap::new()));
79-
let subscriptions2 = subscriptions.clone();
82+
let client: Arc<Mutex<Option<Arc<tokio_postgres::Client>>>> = Arc::new(Mutex::new(None));
8083

81-
let (client, mut conn) = tokio_postgres::connect(&conn_str, tokio_postgres::NoTls).await?;
82-
tokio::spawn(async move {
83-
// NOTE: This loop will stop automatically when client is dropped
84-
loop {
85-
match poll_fn(|cx| conn.poll_message(cx)).await {
86-
Some(std::result::Result::Ok(AsyncMessage::Notification(note))) => {
87-
if let Some(sub) =
88-
subscriptions2.lock().unwrap().get(note.channel()).cloned()
89-
{
90-
let bytes = match BASE64.decode(note.payload()) {
91-
std::result::Result::Ok(b) => b,
92-
std::result::Result::Err(err) => {
93-
tracing::error!(?err, "failed decoding base64");
94-
break;
95-
}
96-
};
97-
let _ = sub.tx.send(bytes);
98-
}
99-
}
100-
Some(std::result::Result::Ok(_)) => {
101-
// Ignore other async messages
84+
// Create channel for client ready notifications
85+
let (ready_tx, client_ready) = tokio::sync::watch::channel(false);
86+
87+
// Spawn connection lifecycle task
88+
tokio::spawn(Self::spawn_connection_lifecycle(
89+
conn_str.clone(),
90+
subscriptions.clone(),
91+
client.clone(),
92+
ready_tx,
93+
));
94+
95+
let driver = Self {
96+
pool: Arc::new(pool),
97+
client,
98+
subscriptions,
99+
client_ready,
100+
};
101+
102+
// Wait for initial connection to be established
103+
driver.wait_for_client().await?;
104+
105+
Ok(driver)
106+
}
107+
108+
/// Manages the connection lifecycle with automatic reconnection
109+
async fn spawn_connection_lifecycle(
110+
conn_str: String,
111+
subscriptions: Arc<Mutex<HashMap<String, Subscription>>>,
112+
client: Arc<Mutex<Option<Arc<tokio_postgres::Client>>>>,
113+
ready_tx: tokio::sync::watch::Sender<bool>,
114+
) {
115+
let mut backoff = Backoff::new(8, None, 1_000, 1_000);
116+
117+
loop {
118+
match tokio_postgres::connect(&conn_str, tokio_postgres::NoTls).await {
119+
Result::Ok((new_client, conn)) => {
120+
tracing::info!("postgres listen connection established");
121+
// Reset backoff on successful connection
122+
backoff = Backoff::new(8, None, 1_000, 1_000);
123+
124+
let new_client = Arc::new(new_client);
125+
126+
// Update the client reference immediately
127+
*client.lock().await = Some(new_client.clone());
128+
// Notify that client is ready
129+
let _ = ready_tx.send(true);
130+
131+
// Get channels to re-subscribe to
132+
let channels: Vec<String> =
133+
subscriptions.lock().await.keys().cloned().collect();
134+
let needs_resubscribe = !channels.is_empty();
135+
136+
if needs_resubscribe {
137+
tracing::debug!(
138+
?channels,
139+
"will re-subscribe to channels after connection starts"
140+
);
102141
}
103-
Some(std::result::Result::Err(err)) => {
104-
tracing::error!(?err, "async postgres error");
105-
break;
142+
143+
// Spawn a task to re-subscribe after a short delay
144+
if needs_resubscribe {
145+
let client_for_resub = new_client.clone();
146+
let channels_clone = channels.clone();
147+
tokio::spawn(async move {
148+
tracing::debug!(
149+
?channels_clone,
150+
"re-subscribing to channels after reconnection"
151+
);
152+
for channel in &channels_clone {
153+
if let Result::Err(e) = client_for_resub
154+
.execute(&format!("LISTEN \"{}\"", channel), &[])
155+
.await
156+
{
157+
tracing::error!(?e, %channel, "failed to re-subscribe to channel");
158+
} else {
159+
tracing::debug!(%channel, "successfully re-subscribed to channel");
160+
}
161+
}
162+
});
106163
}
107-
None => {
108-
tracing::debug!("async postgres connection closed");
109-
break;
164+
165+
// Poll the connection until it closes
166+
Self::poll_connection(conn, subscriptions.clone()).await;
167+
168+
// Clear the client reference on disconnect
169+
*client.lock().await = None;
170+
// Notify that client is disconnected
171+
let _ = ready_tx.send(false);
172+
}
173+
Result::Err(e) => {
174+
tracing::error!(?e, "failed to connect to postgres, retrying");
175+
backoff.tick().await;
176+
}
177+
}
178+
}
179+
}
180+
181+
/// Polls the connection for notifications until it closes or errors
182+
async fn poll_connection(
183+
mut conn: tokio_postgres::Connection<
184+
tokio_postgres::Socket,
185+
tokio_postgres::tls::NoTlsStream,
186+
>,
187+
subscriptions: Arc<Mutex<HashMap<String, Subscription>>>,
188+
) {
189+
loop {
190+
match poll_fn(|cx| conn.poll_message(cx)).await {
191+
Some(std::result::Result::Ok(AsyncMessage::Notification(note))) => {
192+
tracing::trace!(channel = %note.channel(), "received notification");
193+
if let Some(sub) = subscriptions.lock().await.get(note.channel()).cloned() {
194+
let bytes = match BASE64.decode(note.payload()) {
195+
std::result::Result::Ok(b) => b,
196+
std::result::Result::Err(err) => {
197+
tracing::error!(?err, "failed decoding base64");
198+
continue;
199+
}
200+
};
201+
tracing::trace!(channel = %note.channel(), bytes_len = bytes.len(), "sending to broadcast channel");
202+
let _ = sub.tx.send(bytes);
203+
} else {
204+
tracing::warn!(channel = %note.channel(), "received notification for unknown channel");
110205
}
111206
}
207+
Some(std::result::Result::Ok(_)) => {
208+
// Ignore other async messages
209+
}
210+
Some(std::result::Result::Err(err)) => {
211+
tracing::error!(?err, "postgres connection error, reconnecting");
212+
break; // Exit loop to reconnect
213+
}
214+
None => {
215+
tracing::warn!("postgres connection closed, reconnecting");
216+
break; // Exit loop to reconnect
217+
}
112218
}
113-
tracing::debug!("listen connection closed");
114-
});
219+
}
220+
}
115221

116-
Ok(Self {
117-
pool: Arc::new(pool),
118-
client: Arc::new(client),
119-
subscriptions,
222+
/// Wait for the client to be connected
223+
async fn wait_for_client(&self) -> Result<Arc<tokio_postgres::Client>> {
224+
let mut ready_rx = self.client_ready.clone();
225+
tokio::time::timeout(tokio::time::Duration::from_secs(5), async {
226+
loop {
227+
// Subscribe to changed before attempting to access client
228+
let changed_fut = ready_rx.changed();
229+
230+
// Check if client is already available
231+
if let Some(client) = self.client.lock().await.clone() {
232+
return Ok(client);
233+
}
234+
235+
// Wait for change, will return client if exists on next iteration
236+
changed_fut
237+
.await
238+
.map_err(|_| anyhow!("connection lifecycle task ended"))?;
239+
tracing::debug!("client does not exist immediately after receive ready");
240+
}
120241
})
242+
.await
243+
.map_err(|_| anyhow!("timeout waiting for postgres client connection"))?
121244
}
122245

123246
fn hash_subject(&self, subject: &str) -> String {
@@ -147,7 +270,7 @@ impl PubSubDriver for PostgresDriver {
147270

148271
// Check if we already have a subscription for this channel
149272
let (rx, drop_guard) =
150-
if let Some(existing_sub) = self.subscriptions.lock().unwrap().get(&hashed).cloned() {
273+
if let Some(existing_sub) = self.subscriptions.lock().await.get(&hashed).cloned() {
151274
// Reuse the existing broadcast channel
152275
let rx = existing_sub.tx.subscribe();
153276
let drop_guard = existing_sub.token.clone().drop_guard();
@@ -160,13 +283,15 @@ impl PubSubDriver for PostgresDriver {
160283
// Register subscription
161284
self.subscriptions
162285
.lock()
163-
.unwrap()
286+
.await
164287
.insert(hashed.clone(), subscription.clone());
165288

166289
// Execute LISTEN command on the async client (for receiving notifications)
167290
// This only needs to be done once per channel
291+
// Wait for client to be connected with retry logic
292+
let client = self.wait_for_client().await?;
168293
let span = tracing::trace_span!("pg_listen");
169-
self.client
294+
client
170295
.execute(&format!("LISTEN \"{hashed}\""), &[])
171296
.instrument(span)
172297
.await?;
@@ -179,13 +304,16 @@ impl PubSubDriver for PostgresDriver {
179304
tokio::spawn(async move {
180305
token_clone.cancelled().await;
181306
if tx_clone.receiver_count() == 0 {
182-
let sql = format!("UNLISTEN \"{}\"", hashed_clone);
183-
if let Err(err) = driver.client.execute(sql.as_str(), &[]).await {
184-
tracing::warn!(?err, %hashed_clone, "failed to UNLISTEN channel");
185-
} else {
186-
tracing::trace!(%hashed_clone, "unlistened channel");
307+
let client = driver.client.lock().await.clone();
308+
if let Some(client) = client {
309+
let sql = format!("UNLISTEN \"{}\"", hashed_clone);
310+
if let Err(err) = client.execute(sql.as_str(), &[]).await {
311+
tracing::warn!(?err, %hashed_clone, "failed to UNLISTEN channel");
312+
} else {
313+
tracing::trace!(%hashed_clone, "unlistened channel");
314+
}
187315
}
188-
driver.subscriptions.lock().unwrap().remove(&hashed_clone);
316+
driver.subscriptions.lock().await.remove(&hashed_clone);
189317
}
190318
});
191319

0 commit comments

Comments
 (0)