
Commit d217a5c

chore(ups): handle edge cases with postgres listen/unlisten/notify when disconnected/reconnecting
1 parent 3bd57c8 commit d217a5c

6 files changed: +395, -83 lines


out/errors/ups.publish_failed.json

Lines changed: 5 additions & 0 deletions
Some generated files are not rendered by default.

packages/common/test-deps-docker/src/lib.rs

Lines changed: 72 additions & 0 deletions
@@ -109,6 +109,78 @@ impl DockerRunConfig {
 		Ok(())
 	}
 
+	pub async fn stop_container(&self) -> Result<()> {
+		let container_id = self
+			.container_id
+			.as_ref()
+			.ok_or_else(|| anyhow!("No container ID found, container not started"))?;
+
+		tracing::debug!(
+			container_name = %self.container_name,
+			container_id = %container_id,
+			"stopping docker container"
+		);
+
+		let output = Command::new("docker")
+			.arg("stop")
+			.arg(container_id)
+			.output()
+			.await?;
+
+		if !output.status.success() {
+			let stderr = String::from_utf8_lossy(&output.stderr);
+			anyhow::bail!(
+				"Failed to stop container {}: {}",
+				self.container_name,
+				stderr
+			);
+		}
+
+		tracing::debug!(
+			container_name = %self.container_name,
+			container_id = %container_id,
+			"container stopped successfully"
+		);
+
+		Ok(())
+	}
+
+	pub async fn start_container(&self) -> Result<()> {
+		let container_id = self
+			.container_id
+			.as_ref()
+			.ok_or_else(|| anyhow!("No container ID found, container not started"))?;
+
+		tracing::debug!(
+			container_name = %self.container_name,
+			container_id = %container_id,
+			"starting docker container"
+		);
+
+		let output = Command::new("docker")
+			.arg("start")
+			.arg(container_id)
+			.output()
+			.await?;
+
+		if !output.status.success() {
+			let stderr = String::from_utf8_lossy(&output.stderr);
+			anyhow::bail!(
+				"Failed to start container {}: {}",
+				self.container_name,
+				stderr
+			);
+		}
+
+		tracing::debug!(
+			container_name = %self.container_name,
+			container_id = %container_id,
+			"container started successfully"
+		);
+
+		Ok(())
+	}
+
 	pub fn container_id(&self) -> Option<&str> {
 		self.container_id.as_deref()
 	}
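
A plausible way these helpers get used (a hypothetical test sketch, not code from this commit) is to bounce the Postgres container so the pubsub driver's disconnect and reconnect paths can be exercised:

// Hypothetical sketch: simulate an outage by stopping and restarting the
// container behind an already-started DockerRunConfig.
async fn simulate_postgres_outage(docker: &DockerRunConfig) -> anyhow::Result<()> {
	// Sever existing connections.
	docker.stop_container().await?;

	// ...publish/subscribe here to hit the disconnected code paths...

	// Bring the database back so the reconnection logic can recover.
	docker.start_container().await?;
	Ok(())
}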

packages/common/universalpubsub/src/driver/postgres/mod.rs

Lines changed: 116 additions & 50 deletions
@@ -112,21 +112,23 @@ impl PostgresDriver {
 		client: Arc<Mutex<Option<Arc<tokio_postgres::Client>>>>,
 		ready_tx: tokio::sync::watch::Sender<bool>,
 	) {
-		let mut backoff = Backoff::new(8, None, 1_000, 1_000);
+		let mut backoff = Backoff::default();
 
 		loop {
 			match tokio_postgres::connect(&conn_str, tokio_postgres::NoTls).await {
 				Result::Ok((new_client, conn)) => {
 					tracing::info!("postgres listen connection established");
 					// Reset backoff on successful connection
-					backoff = Backoff::new(8, None, 1_000, 1_000);
+					backoff = Backoff::default();
 
 					let new_client = Arc::new(new_client);
 
-					// Update the client reference immediately
-					*client.lock().await = Some(new_client.clone());
-					// Notify that client is ready
-					let _ = ready_tx.send(true);
+					// Spawn the polling task immediately
+					// This must be done before any operations on the client
+					let subscriptions_clone = subscriptions.clone();
+					let poll_handle = tokio::spawn(async move {
+						Self::poll_connection(conn, subscriptions_clone).await;
+					});
 
 					// Get channels to re-subscribe to
 					let channels: Vec<String> =
@@ -135,38 +137,41 @@ impl PostgresDriver {
 
 					if needs_resubscribe {
 						tracing::debug!(
-							?channels,
+							channels=?channels.len(),
 							"will re-subscribe to channels after connection starts"
 						);
 					}
 
-					// Spawn a task to re-subscribe after a short delay
+					// Re-subscribe to channels
 					if needs_resubscribe {
-						let client_for_resub = new_client.clone();
-						let channels_clone = channels.clone();
-						tokio::spawn(async move {
-							tracing::debug!(
-								?channels_clone,
-								"re-subscribing to channels after reconnection"
-							);
-							for channel in &channels_clone {
-								if let Result::Err(e) = client_for_resub
-									.execute(&format!("LISTEN \"{}\"", channel), &[])
-									.await
-								{
-									tracing::error!(?e, %channel, "failed to re-subscribe to channel");
-								} else {
-									tracing::debug!(%channel, "successfully re-subscribed to channel");
-								}
+						tracing::debug!(
+							channels=?channels.len(),
+							"re-subscribing to channels after reconnection"
+						);
+						for channel in &channels {
+							tracing::info!(?channel, "re-subscribing to channel");
+							if let Result::Err(e) = new_client
+								.execute(&format!("LISTEN \"{}\"", channel), &[])
+								.await
+							{
+								tracing::error!(?e, %channel, "failed to re-subscribe to channel");
+							} else {
+								tracing::debug!(%channel, "successfully re-subscribed to channel");
 							}
-						});
+						}
 					}
 
-					// Poll the connection until it closes
-					Self::poll_connection(conn, subscriptions.clone()).await;
+					// Update the client reference and signal ready
+					// Do this AFTER re-subscribing to ensure LISTEN is complete
+					*client.lock().await = Some(new_client.clone());
+					let _ = ready_tx.send(true);
+
+					// Wait for the polling task to complete (when the connection closes)
+					let _ = poll_handle.await;
 
 					// Clear the client reference on disconnect
 					*client.lock().await = None;
+
 					// Notify that client is disconnected
 					let _ = ready_tx.send(false);
 				}
@@ -208,12 +213,12 @@ impl PostgresDriver {
 					// Ignore other async messages
 				}
 				Some(std::result::Result::Err(err)) => {
-					tracing::error!(?err, "postgres connection error, reconnecting");
-					break; // Exit loop to reconnect
+					tracing::error!(?err, "postgres connection error");
+					break;
 				}
 				None => {
-					tracing::warn!("postgres connection closed, reconnecting");
-					break; // Exit loop to reconnect
+					tracing::warn!("postgres connection closed");
+					break;
 				}
 			}
 		}
@@ -224,19 +229,16 @@ impl PostgresDriver {
 		let mut ready_rx = self.client_ready.clone();
 		tokio::time::timeout(tokio::time::Duration::from_secs(5), async {
 			loop {
-				// Subscribe to changed before attempting to access client
-				let changed_fut = ready_rx.changed();
-
 				// Check if client is already available
 				if let Some(client) = self.client.lock().await.clone() {
 					return Ok(client);
 				}
 
-				// Wait for change, will return client if exists on next iteration
-				changed_fut
+				// Wait for the ready signal to change
+				ready_rx
+					.changed()
 					.await
 					.map_err(|_| anyhow!("connection lifecycle task ended"))?;
-				tracing::debug!("client does not exist immediately after receive ready");
 			}
 		})
 		.await
@@ -288,13 +290,25 @@ impl PubSubDriver for PostgresDriver {
 
 		// Execute LISTEN command on the async client (for receiving notifications)
 		// This only needs to be done once per channel
-		// Wait for client to be connected with retry logic
-		let client = self.wait_for_client().await?;
-		let span = tracing::trace_span!("pg_listen");
-		client
-			.execute(&format!("LISTEN \"{hashed}\""), &[])
-			.instrument(span)
-			.await?;
+		// Try to LISTEN if client is available, but don't fail if disconnected
+		// The reconnection logic will handle re-subscribing
+		if let Some(client) = self.client.lock().await.clone() {
+			let span = tracing::trace_span!("pg_listen");
+			match client
+				.execute(&format!("LISTEN \"{hashed}\""), &[])
+				.instrument(span)
+				.await
+			{
+				Result::Ok(_) => {
+					tracing::debug!(%hashed, "successfully subscribed to channel");
+				}
+				Result::Err(e) => {
+					tracing::warn!(?e, %hashed, "failed to LISTEN, will retry on reconnection");
+				}
+			}
+		} else {
+			tracing::debug!(%hashed, "client not connected, will LISTEN on reconnection");
+		}
 
 		// Spawn a single cleanup task for this subscription waiting on its token
 		let driver = self.clone();
@@ -333,14 +347,66 @@ impl PubSubDriver for PostgresDriver {
 
 		// Encode payload to base64 and send NOTIFY
 		let encoded = BASE64.encode(payload);
-		let conn = self.pool.get().await?;
 		let hashed = self.hash_subject(subject);
-		let span = tracing::trace_span!("pg_notify");
-		conn.execute(&format!("NOTIFY \"{hashed}\", '{encoded}'"), &[])
-			.instrument(span)
-			.await?;
 
-		Ok(())
+		tracing::debug!("attempting to get connection for publish");
+
+		// Wait for listen connection to be ready first if this channel has subscribers
+		// This ensures that if we're reconnecting, the LISTEN is re-registered before NOTIFY
+		if self.subscriptions.lock().await.contains_key(&hashed) {
+			self.wait_for_client().await?;
+		}
+
+		// Retry getting a connection from the pool with backoff in case the connection is
+		// currently disconnected
+		let mut backoff = Backoff::default();
+		let mut last_error = None;
+
+		loop {
+			match self.pool.get().await {
+				Result::Ok(conn) => {
+					// Test the connection with a simple query before using it
+					match conn.execute("SELECT 1", &[]).await {
+						Result::Ok(_) => {
+							// Connection is good, use it for NOTIFY
+							let span = tracing::trace_span!("pg_notify");
+							match conn
+								.execute(&format!("NOTIFY \"{hashed}\", '{encoded}'"), &[])
+								.instrument(span)
+								.await
+							{
+								Result::Ok(_) => return Ok(()),
+								Result::Err(e) => {
+									tracing::debug!(
+										?e,
+										"NOTIFY failed, retrying with new connection"
+									);
+									last_error = Some(e.into());
+								}
+							}
+						}
+						Result::Err(e) => {
+							tracing::debug!(
+								?e,
+								"connection test failed, retrying with new connection"
+							);
+							last_error = Some(e.into());
+						}
+					}
+				}
+				Result::Err(e) => {
+					tracing::debug!(?e, "failed to get connection from pool, retrying");
+					last_error = Some(e.into());
+				}
+			}
+
+			// Check if we should continue retrying
+			if !backoff.tick().await {
+				return Err(
+					last_error.unwrap_or_else(|| anyhow!("failed to publish after retries"))
+				);
+			}
+		}
 	}
 
 	async fn flush(&self) -> Result<()> {
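
Both the connection-lifecycle loop and the publish path above rely on the same retry primitive: Backoff::default() plus tick().await, which (as used in this diff) waits out the next delay and returns false once the retry budget is exhausted. A condensed sketch of that pattern, under those assumptions and with a made-up helper name:

// Illustrative only: the generic shape of the backoff loops used above.
async fn retry_with_backoff<T, F, Fut>(mut attempt: F) -> anyhow::Result<T>
where
	F: FnMut() -> Fut,
	Fut: std::future::Future<Output = anyhow::Result<T>>,
{
	let mut backoff = rivet_util::backoff::Backoff::default();
	let mut last_error = None;
	loop {
		match attempt().await {
			Ok(value) => return Ok(value),
			Err(err) => last_error = Some(err),
		}
		// tick() sleeps for the next delay; false means no retries remain.
		if !backoff.tick().await {
			return Err(last_error.unwrap_or_else(|| anyhow::anyhow!("retries exhausted")));
		}
	}
}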

packages/common/universalpubsub/src/errors.rs

Lines changed: 2 additions & 0 deletions
@@ -6,4 +6,6 @@ use serde::{Deserialize, Serialize};
 pub enum Ups {
 	#[error("request_timeout", "Request timeout.")]
 	RequestTimeout,
+	#[error("publish_failed", "Failed to publish message after retries")]
+	PublishFailed,
 }
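
The new variant is what publish_with_backoff in pubsub.rs below returns once its backoff budget runs out. A minimal sketch of that construction, using the same .build().into() chain as the diff but wrapped in a hypothetical helper:

// Hypothetical helper: convert the new variant into the anyhow error that
// publish_with_backoff surfaces after exhausting its retries.
fn publish_failed_error() -> anyhow::Error {
	crate::errors::Ups::PublishFailed.build().into()
}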

packages/common/universalpubsub/src/pubsub.rs

Lines changed: 28 additions & 3 deletions
@@ -8,6 +8,8 @@ use tokio::sync::broadcast;
 use tokio::sync::{RwLock, oneshot};
 use uuid::Uuid;
 
+use rivet_util::backoff::Backoff;
+
 use crate::chunking::{ChunkTracker, encode_chunk, split_payload_into_chunks};
 use crate::driver::{PubSubDriverHandle, PublishOpts, SubscriberDriverHandle};
1315

@@ -131,7 +133,8 @@ impl PubSub {
131133
break;
132134
}
133135
} else {
134-
self.driver.publish(subject, &encoded).await?;
136+
// Use backoff when publishing through the driver
137+
self.publish_with_backoff(subject, &encoded).await?;
135138
}
136139
}
137140
Ok(())
@@ -174,7 +177,26 @@ impl PubSub {
 					break;
 				}
 			} else {
-				self.driver.publish(subject, &encoded).await?;
+				// Use backoff when publishing through the driver
+				self.publish_with_backoff(subject, &encoded).await?;
+			}
+		}
+		Ok(())
+	}
+
+	async fn publish_with_backoff(&self, subject: &str, encoded: &[u8]) -> Result<()> {
+		let mut backoff = Backoff::default();
+		loop {
+			match self.driver.publish(subject, encoded).await {
+				Result::Ok(_) => break,
+				Err(err) if !backoff.tick().await => {
+					tracing::info!(?err, "error publishing, cannot retry again");
+					return Err(crate::errors::Ups::PublishFailed.build().into());
+				}
+				Err(err) => {
+					tracing::info!(?err, "error publishing, retrying");
+					// Continue retrying
+				}
 			}
 		}
 		Ok(())
@@ -293,7 +315,10 @@ impl Subscriber {
 	pub async fn next(&mut self) -> Result<NextOutput> {
 		loop {
 			match self.driver.next().await? {
-				DriverOutput::Message { subject, payload } => {
+				DriverOutput::Message {
+					subject: _,
+					payload,
+				} => {
 					// Process chunks
 					let mut tracker = self.pubsub.chunk_tracker.lock().unwrap();
 					match tracker.process_chunk(&payload) {