From e4721924acdb6d4d2a24e8a82b02df62c9ae83aa Mon Sep 17 00:00:00 2001 From: MasterPtato Date: Wed, 3 Sep 2025 14:02:03 -0700 Subject: [PATCH] fix(tunnel): fix ups race condition --- Cargo.lock | 1 + docker/dev-host/docker-compose.yml | 2 +- .../dev-multidc-multinode/docker-compose.yml | 18 +- docker/dev-multidc/docker-compose.yml | 6 +- docker/dev-multinode/docker-compose.yml | 6 +- docker/dev/docker-compose.yml | 2 +- docker/dev/rivet-engine/config.jsonc | 7 + docker/template/src/docker-compose.ts | 2 +- packages/common/config/src/config/mod.rs | 1 + packages/common/universalpubsub/Cargo.toml | 9 +- .../src/driver/postgres/mod.rs | 464 ++++++++---------- packages/common/universalpubsub/src/pubsub.rs | 6 - packages/core/pegboard-gateway/src/lib.rs | 29 +- packages/core/pegboard-tunnel/src/lib.rs | 26 +- sdks/typescript/runner/src/mod.ts | 16 +- 15 files changed, 261 insertions(+), 334 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 3836093ca0..36194f04ef 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6042,6 +6042,7 @@ dependencies = [ "base64 0.22.1", "deadpool-postgres", "futures-util", + "moka", "rivet-config", "rivet-env", "rivet-error", diff --git a/docker/dev-host/docker-compose.yml b/docker/dev-host/docker-compose.yml index d2555d4e77..75e32cac41 100644 --- a/docker/dev-host/docker-compose.yml +++ b/docker/dev-host/docker-compose.yml @@ -150,7 +150,7 @@ services: - CMD - curl - '-f' - - http://127.0.0.1:6421/health + - http://127.0.0.1:6420/health interval: 2s timeout: 10s retries: 10 diff --git a/docker/dev-multidc-multinode/docker-compose.yml b/docker/dev-multidc-multinode/docker-compose.yml index 718843e72e..8893d32457 100644 --- a/docker/dev-multidc-multinode/docker-compose.yml +++ b/docker/dev-multidc-multinode/docker-compose.yml @@ -184,7 +184,7 @@ services: - CMD - curl - '-f' - - http://127.0.0.1:6421/health + - http://127.0.0.1:6420/health interval: 2s timeout: 10s retries: 10 @@ -224,7 +224,7 @@ services: - CMD - curl - '-f' - - http://127.0.0.1:6421/health + - http://127.0.0.1:6420/health interval: 2s timeout: 10s retries: 10 @@ -264,7 +264,7 @@ services: - CMD - curl - '-f' - - http://127.0.0.1:6421/health + - http://127.0.0.1:6420/health interval: 2s timeout: 10s retries: 10 @@ -444,7 +444,7 @@ services: - CMD - curl - '-f' - - http://127.0.0.1:6421/health + - http://127.0.0.1:6420/health interval: 2s timeout: 10s retries: 10 @@ -484,7 +484,7 @@ services: - CMD - curl - '-f' - - http://127.0.0.1:6421/health + - http://127.0.0.1:6420/health interval: 2s timeout: 10s retries: 10 @@ -524,7 +524,7 @@ services: - CMD - curl - '-f' - - http://127.0.0.1:6421/health + - http://127.0.0.1:6420/health interval: 2s timeout: 10s retries: 10 @@ -702,7 +702,7 @@ services: - CMD - curl - '-f' - - http://127.0.0.1:6421/health + - http://127.0.0.1:6420/health interval: 2s timeout: 10s retries: 10 @@ -742,7 +742,7 @@ services: - CMD - curl - '-f' - - http://127.0.0.1:6421/health + - http://127.0.0.1:6420/health interval: 2s timeout: 10s retries: 10 @@ -782,7 +782,7 @@ services: - CMD - curl - '-f' - - http://127.0.0.1:6421/health + - http://127.0.0.1:6420/health interval: 2s timeout: 10s retries: 10 diff --git a/docker/dev-multidc/docker-compose.yml b/docker/dev-multidc/docker-compose.yml index 00182d067e..ba3d226472 100644 --- a/docker/dev-multidc/docker-compose.yml +++ b/docker/dev-multidc/docker-compose.yml @@ -182,7 +182,7 @@ services: - CMD - curl - '-f' - - http://127.0.0.1:6421/health + - http://127.0.0.1:6420/health interval: 2s timeout: 10s retries: 10 @@ -330,7 +330,7 @@ 
services: - CMD - curl - '-f' - - http://127.0.0.1:6421/health + - http://127.0.0.1:6420/health interval: 2s timeout: 10s retries: 10 @@ -476,7 +476,7 @@ services: - CMD - curl - '-f' - - http://127.0.0.1:6421/health + - http://127.0.0.1:6420/health interval: 2s timeout: 10s retries: 10 diff --git a/docker/dev-multinode/docker-compose.yml b/docker/dev-multinode/docker-compose.yml index ad1d9a5eb7..faedb297cf 100644 --- a/docker/dev-multinode/docker-compose.yml +++ b/docker/dev-multinode/docker-compose.yml @@ -173,7 +173,7 @@ services: - CMD - curl - '-f' - - http://127.0.0.1:6421/health + - http://127.0.0.1:6420/health interval: 2s timeout: 10s retries: 10 @@ -210,7 +210,7 @@ services: - CMD - curl - '-f' - - http://127.0.0.1:6421/health + - http://127.0.0.1:6420/health interval: 2s timeout: 10s retries: 10 @@ -247,7 +247,7 @@ services: - CMD - curl - '-f' - - http://127.0.0.1:6421/health + - http://127.0.0.1:6420/health interval: 2s timeout: 10s retries: 10 diff --git a/docker/dev/docker-compose.yml b/docker/dev/docker-compose.yml index eaf6561792..6803272ebe 100644 --- a/docker/dev/docker-compose.yml +++ b/docker/dev/docker-compose.yml @@ -173,7 +173,7 @@ services: - CMD - curl - '-f' - - http://127.0.0.1:6421/health + - http://127.0.0.1:6420/health interval: 2s timeout: 10s retries: 10 diff --git a/docker/dev/rivet-engine/config.jsonc b/docker/dev/rivet-engine/config.jsonc index ba8f2b4dea..810d5be6e5 100644 --- a/docker/dev/rivet-engine/config.jsonc +++ b/docker/dev/rivet-engine/config.jsonc @@ -32,6 +32,13 @@ "postgres": { "url": "postgresql://postgres:postgres@postgres:5432/rivet_engine" }, + "postgres_notify": { + "url": "postgresql://postgres:postgres@postgres:5432/rivet_engine", + "memory_optimization": false + }, + // "memory": { + // "channel": "default" + // }, "cache": { "driver": "in_memory" }, diff --git a/docker/template/src/docker-compose.ts b/docker/template/src/docker-compose.ts index 1cdd55d105..d60cd53e27 100644 --- a/docker/template/src/docker-compose.ts +++ b/docker/template/src/docker-compose.ts @@ -294,7 +294,7 @@ export function generateDockerCompose(context: TemplateContext) { ], ports: isPrimary && i === 0 ? 
[`6420:6420`] : undefined, healthcheck: { - test: ["CMD", "curl", "-f", "http://127.0.0.1:6421/health"], + test: ["CMD", "curl", "-f", "http://127.0.0.1:6420/health"], interval: "2s", timeout: "10s", retries: 10, diff --git a/packages/common/config/src/config/mod.rs b/packages/common/config/src/config/mod.rs index b291de6875..b1bab1f223 100644 --- a/packages/common/config/src/config/mod.rs +++ b/packages/common/config/src/config/mod.rs @@ -193,6 +193,7 @@ impl Root { { self.pubsub = Some(PubSub::PostgresNotify(pubsub::Postgres { url: pg.url.clone(), + memory_optimization: true, })); } diff --git a/packages/common/universalpubsub/Cargo.toml b/packages/common/universalpubsub/Cargo.toml index 0627eaba30..645a3e1785 100644 --- a/packages/common/universalpubsub/Cargo.toml +++ b/packages/common/universalpubsub/Cargo.toml @@ -9,17 +9,18 @@ edition.workspace = true anyhow.workspace = true async-nats.workspace = true async-trait.workspace = true +base64.workspace = true deadpool-postgres.workspace = true futures-util.workspace = true +moka.workspace = true rivet-error.workspace = true -serde.workspace = true serde_json.workspace = true +serde.workspace = true +sha2.workspace = true +tokio-postgres.workspace = true tokio.workspace = true tracing.workspace = true uuid.workspace = true -tokio-postgres.workspace = true -base64 = "0.22" -sha2.workspace = true [dev-dependencies] rivet-config.workspace = true diff --git a/packages/common/universalpubsub/src/driver/postgres/mod.rs b/packages/common/universalpubsub/src/driver/postgres/mod.rs index 2d4c774058..f43143405c 100644 --- a/packages/common/universalpubsub/src/driver/postgres/mod.rs +++ b/packages/common/universalpubsub/src/driver/postgres/mod.rs @@ -9,6 +9,7 @@ use base64::Engine; use base64::engine::general_purpose::STANDARD as BASE64; use deadpool_postgres::{Config, ManagerConfig, Pool, PoolConfig, RecyclingMethod, Runtime}; use futures_util::future::poll_fn; +use moka::future::Cache; use serde::{Deserialize, Serialize}; use sha2::{Digest, Sha256}; use tokio::sync::RwLock; @@ -19,6 +20,12 @@ use crate::driver::{PubSubDriver, SubscriberDriver, SubscriberDriverHandle}; use crate::errors; use crate::pubsub::{Message, NextOutput, Response}; +#[derive(Clone)] +struct Subscription { + // Channel to send requests to this subscription + tx: tokio::sync::broadcast::Sender<(Vec, Option)>, +} + // Represents a local subscription that can handle request/response struct LocalSubscription { // Channel to send requests to this subscription @@ -34,9 +41,12 @@ struct LocalRequest { #[derive(Clone)] pub struct PostgresDriver { - conn_str: String, - pool: Arc, memory_optimization: bool, + pool: Arc, + client: Arc, + + subscriptions: Cache, + // Maps subject to local subscription on this node for fast path local_subscriptions: Arc>>, } @@ -44,10 +54,10 @@ pub struct PostgresDriver { #[derive(Serialize, Deserialize)] struct Envelope { // Base64-encoded payload - p: String, - // Optional reply subject - #[serde(skip_serializing_if = "Option::is_none")] - r: Option, + #[serde(rename = "p")] + payload: String, + #[serde(rename = "r", skip_serializing_if = "Option::is_none")] + reply_subject: Option, } impl PostgresDriver { @@ -72,34 +82,57 @@ impl PostgresDriver { .context("failed to create postgres pool")?; tracing::debug!("postgres pool created successfully"); - let driver = Self { - conn_str, - pool: Arc::new(pool), - memory_optimization, - local_subscriptions: Arc::new(RwLock::new(HashMap::new())), - }; - tracing::info!("postgres driver connected successfully"); - 
Ok(driver) - } + let subscriptions: Cache = + Cache::builder().initial_capacity(5).build(); + let subscriptions2 = subscriptions.clone(); - fn quote_ident(subject: &str) -> String { - // Double-quote and escape any embedded quotes for safe identifier usage - let escaped = subject.replace('"', "\"\""); - format!("\"{}\"", escaped) - } + let (client, mut conn) = tokio_postgres::connect(&conn_str, tokio_postgres::NoTls).await?; + tokio::spawn(async move { + // NOTE: This loop will stop automatically when client is dropped + loop { + match poll_fn(|cx| conn.poll_message(cx)).await { + Some(std::result::Result::Ok(AsyncMessage::Notification(note))) => { + if let Some(sub) = subscriptions2.get(note.channel()).await { + let env = match serde_json::from_str::(¬e.payload()) { + std::result::Result::Ok(env) => env, + std::result::Result::Err(err) => { + tracing::error!(?err, "failed deserializing envelope"); + break; + } + }; + let payload = match BASE64 + .decode(env.payload) + .context("invalid base64 payload") + { + std::result::Result::Ok(p) => p, + std::result::Result::Err(err) => { + tracing::error!(?err, "failed deserializing envelope"); + break; + } + }; + + let _ = sub.tx.send((payload, env.reply_subject)); + } + } + Some(std::result::Result::Ok(_)) => continue, + Some(std::result::Result::Err(err)) => { + tracing::error!(?err, "ups poll loop failed"); + break; + } + None => break, + } + } + + tracing::info!("ups poll loop stopped"); + }); - /// Convert a subject name to a PostgreSQL advisory lock ID - /// Uses SHA256 hash truncated to 63 bits to avoid collisions - fn subject_to_lock_id(subject: &str) -> i64 { - let mut hasher = Sha256::new(); - hasher.update(subject.as_bytes()); - let hash = hasher.finalize(); - - // Take first 8 bytes and convert to i64, using only 63 bits to avoid sign issues - let mut bytes = [0u8; 8]; - bytes.copy_from_slice(&hash[0..8]); - let hash_u64 = u64::from_be_bytes(bytes); - (hash_u64 & 0x7FFFFFFFFFFFFFFF) as i64 + Ok(Self { + memory_optimization, + pool: Arc::new(pool), + client: Arc::new(client), + subscriptions, + local_subscriptions: Arc::new(RwLock::new(HashMap::new())), + }) } } @@ -108,17 +141,6 @@ impl PubSubDriver for PostgresDriver { #[tracing::instrument(skip(self), fields(subject))] async fn subscribe(&self, subject: &str) -> Result { tracing::debug!(%subject, "starting subscription"); - // Get the lock ID for this subject - let lock_id = Self::subject_to_lock_id(subject); - tracing::debug!(%subject, ?lock_id, "calculated advisory lock id"); - - // Create a single connection for both subscription and lock holding - let (client, mut connection) = - tokio_postgres::connect(&self.conn_str, tokio_postgres::NoTls).await?; - - // Set up message forwarding - let (tx, rx) = tokio::sync::mpsc::unbounded_channel::(); - let subject_owned = subject.to_string(); // Set up local request handling channel if memory optimization is enabled let local_request_rx = if self.memory_optimization { @@ -131,7 +153,7 @@ impl PubSubDriver for PostgresDriver { let local_rx = subs .entry(subject.to_string()) .or_insert_with(|| LocalSubscription { - tx: tokio::sync::broadcast::channel::(64).0, + tx: tokio::sync::broadcast::channel(64).0, }) .tx .subscribe(); @@ -145,125 +167,52 @@ impl PubSubDriver for PostgresDriver { None }; - // Create channels for coordinating initialization - let (listen_done_tx, listen_done_rx) = tokio::sync::oneshot::channel(); - let (lock_done_tx, lock_done_rx) = tokio::sync::oneshot::channel(); - - // We need to wrap the client in Arc for sharing with 
Drop impl - let client = Arc::new(client); - let client_clone = client.clone(); - let listen_subject = subject_owned.clone(); - - // Spawn task to handle connection, lock acquisition, and LISTEN - tokio::spawn(async move { - // First acquire the lock while polling the connection - let lock_sql = format!("SELECT pg_try_advisory_lock_shared({})", lock_id); - let lock_future = client_clone.query_one(&lock_sql, &[]); - tokio::pin!(lock_future); - - let mut lock_done = false; - let mut lock_done_tx = Some(lock_done_tx); - let mut listen_started = false; - - // Convert subject to base64 hash string because Postgres identifiers can only be 63 bytes - let mut hasher = DefaultHasher::new(); - listen_subject.hash(&mut hasher); - let subject = BASE64.encode(&hasher.finish().to_be_bytes()); - - // Execute LISTEN while polling the connection - let sql = format!("LISTEN {}", Self::quote_ident(&subject)); - let listen_future = client_clone.batch_execute(&sql); - tokio::pin!(listen_future); - - let mut listen_done = false; - let mut listen_done_tx = Some(listen_done_tx); + // Get the lock ID for this subject + let lock_id = subject_to_lock_id(subject); + tracing::debug!(%subject, ?lock_id, "calculated advisory lock id"); - loop { - tokio::select! { - // First acquire the lock - result = &mut lock_future, if !lock_done => { - lock_done = true; - if let Some(tx) = lock_done_tx.take() { - let lock_acquired = result.as_ref().map(|row| row.get::<_, bool>(0)).unwrap_or(false); - let _ = tx.send(result.map(|_| lock_acquired).map_err(|e| anyhow::Error::new(e))); - } - listen_started = true; - } - // Then execute LISTEN - result = &mut listen_future, if listen_started && !listen_done => { - listen_done = true; - if let Some(tx) = listen_done_tx.take() { - let _ = tx.send(result.map_err(|e| anyhow::Error::new(e))); - } - } - // Poll messages - msg = poll_fn(|cx| connection.poll_message(cx)) => { - match msg { - Some(std::result::Result::Ok(AsyncMessage::Notification(note))) => { - if note.channel() == subject { - let _ = tx.send(note.payload().to_string()); - } - } - Some(std::result::Result::Ok(_)) => continue, - Some(std::result::Result::Err(_)) => break, - None => break, - } - } + // Convert subject to base64 hash string because Postgres identifiers can only be 63 bytes + let mut hasher = DefaultHasher::new(); + subject.hash(&mut hasher); + let subject_hash = BASE64.encode(&hasher.finish().to_be_bytes()); + + let rx = self + .subscriptions + .entry(subject_hash.clone()) + .or_insert_with(async { + Subscription { + tx: tokio::sync::broadcast::channel(128).0, } - } - }); - - // Wait for lock acquisition to complete - tracing::debug!(%subject, ?lock_id, "waiting for advisory lock acquisition"); - match lock_done_rx.await { - std::result::Result::Ok(std::result::Result::Ok(true)) => { - tracing::debug!(%subject, ?lock_id, "advisory lock acquired successfully"); - } - std::result::Result::Ok(std::result::Result::Ok(false)) => { - tracing::warn!( - %subject, - ?lock_id, - "failed to acquire advisory lock - another subscriber may already exist" - ); - return Err(anyhow!("Failed to acquire advisory lock for subject")); - } - std::result::Result::Ok(std::result::Result::Err(err)) => { - return Err(err); - } - std::result::Result::Err(_) => { - return Err(anyhow!("Failed to acquire lock")); - } + }) + .await + .value() + .tx + .subscribe(); + + let lock_sql = format!("SELECT pg_try_advisory_lock_shared({})", lock_id); + let lock_res = self.client.query_one(&lock_sql, &[]).await?; + let lock_acquired = lock_res.get::<_, 
bool>(0); + ensure!(lock_acquired, "Failed to acquire advisory lock for subject"); + + let sql = format!("LISTEN {}", quote_ident(&subject_hash)); + let listen_res = self.client.batch_execute(&sql).await; + + if listen_res.is_err() { + // Release lock on error + let _ = self + .client + .execute("SELECT pg_advisory_unlock_shared($1)", &[&lock_id]) + .await; } - // Wait for LISTEN to complete - tracing::debug!(%subject, "waiting for LISTEN command to complete"); - match listen_done_rx.await { - std::result::Result::Ok(std::result::Result::Ok(())) => { - tracing::debug!(%subject, "LISTEN command completed successfully"); - } - std::result::Result::Ok(std::result::Result::Err(err)) => { - // Release lock on error - let _ = client - .execute("SELECT pg_advisory_unlock_shared($1)", &[&lock_id]) - .await; - return Err(err); - } - std::result::Result::Err(_) => { - // Release lock on error - let _ = client - .execute("SELECT pg_advisory_unlock_shared($1)", &[&lock_id]) - .await; - return Err(anyhow!("Failed to confirm LISTEN")); - } - } + listen_res?; - tracing::info!(%subject, "subscription established successfully"); + tracing::debug!(%subject, "subscription established successfully"); Ok(Box::new(PostgresSubscriber { driver: self.clone(), rx, local_request_rx, lock_id, - client, subject: subject.to_string(), })) } @@ -281,12 +230,12 @@ impl PubSubDriver for PostgresDriver { // Convert subject to base64 hash string because Postgres identifiers can only be 63 bytes let mut hasher = DefaultHasher::new(); subject.hash(&mut hasher); - let subject = BASE64.encode(&hasher.finish().to_be_bytes()); + let subject_hash = BASE64.encode(&hasher.finish().to_be_bytes()); // Encode payload let env = Envelope { - p: BASE64.encode(message), - r: None, + payload: BASE64.encode(message), + reply_subject: None, }; let payload = serde_json::to_string(&env)?; @@ -295,7 +244,7 @@ impl PubSubDriver for PostgresDriver { let escaped_payload = payload.replace('\'', "''"); let sql = format!( "NOTIFY {}, '{}'", - Self::quote_ident(&subject), + quote_ident(&subject_hash), escaped_payload ); conn.batch_execute(&sql) @@ -323,10 +272,11 @@ impl PubSubDriver for PostgresDriver { ?timeout, "starting request" ); + // Memory fast path: check if we have local subscribers first if self.memory_optimization { let subs = self.local_subscriptions.read().await; - if let Some(local_tx) = subs.get(subject) { + if let Some(local_sub) = subs.get(subject) { tracing::debug!( %subject, "using memory fast path for request" @@ -342,7 +292,7 @@ impl PubSubDriver for PostgresDriver { }; // Try to send the request - if local_tx.tx.send(request).is_ok() { + if local_sub.tx.send(request).is_ok() { // Drop early to clear lock drop(subs); @@ -383,7 +333,7 @@ impl PubSubDriver for PostgresDriver { .context("failed to get connection from pool")?; // First check if there are any listeners for this subject - let lock_id = Self::subject_to_lock_id(subject); + let lock_id = subject_to_lock_id(subject); // Check if there are any shared advisory locks (listeners) for this subject // Query pg_locks directly to avoid lock acquisition overhead @@ -419,59 +369,7 @@ impl PubSubDriver for PostgresDriver { // Create a temporary reply subject and a dedicated listener connection let reply_subject = format!("_INBOX.{}", uuid::Uuid::new_v4()); - let (client, mut connection) = - tokio_postgres::connect(&self.conn_str, tokio_postgres::NoTls).await?; - - // Setup connection and LISTEN in a task - let (listen_done_tx, listen_done_rx) = tokio::sync::oneshot::channel(); - let 
reply_subject_clone = reply_subject.clone(); - - // Spawn task to handle connection and LISTEN - let (response_tx, mut response_rx) = tokio::sync::mpsc::unbounded_channel(); - tokio::spawn(async move { - // Convert subject to base64 hash string because Postgres identifiers can only be 63 bytes - let mut hasher = DefaultHasher::new(); - reply_subject_clone.hash(&mut hasher); - let reply_subject = BASE64.encode(&hasher.finish().to_be_bytes()); - - // LISTEN reply subject first to avoid race - let listen_sql = format!("LISTEN {}", Self::quote_ident(&reply_subject)); - let listen_future = client.batch_execute(&listen_sql); - tokio::pin!(listen_future); - - let mut listen_done = false; - let mut listen_done_tx = Some(listen_done_tx); - - loop { - tokio::select! { - result = &mut listen_future, if !listen_done => { - listen_done = true; - if let Some(tx) = listen_done_tx.take() { - let _ = tx.send(result.map_err(|e| anyhow::Error::new(e))); - } - } - msg = poll_fn(|cx| connection.poll_message(cx)) => { - match msg { - Some(std::result::Result::Ok(tokio_postgres::AsyncMessage::Notification(note))) => { - if note.channel() == reply_subject { - let _ = response_tx.send(note.payload().to_string()); - } - } - Some(std::result::Result::Ok(_)) => continue, - Some(std::result::Result::Err(_)) => break, - None => break, - } - } - } - } - }); - - // Wait for LISTEN to complete - match listen_done_rx.await { - std::result::Result::Ok(std::result::Result::Ok(())) => {} - std::result::Result::Ok(std::result::Result::Err(err)) => return Err(err), - std::result::Result::Err(_) => return Err(anyhow!("Failed to setup LISTEN")), - } + let mut reply_sub = self.subscribe(&reply_subject).await?; // Get another connection from pool to publish the request let publish_conn = self @@ -483,33 +381,36 @@ impl PubSubDriver for PostgresDriver { // Convert subject to base64 hash string because Postgres identifiers can only be 63 bytes let mut hasher = DefaultHasher::new(); subject.hash(&mut hasher); - let subject = BASE64.encode(&hasher.finish().to_be_bytes()); + let subject_hash = BASE64.encode(&hasher.finish().to_be_bytes()); + let mut hasher = DefaultHasher::new(); + reply_subject.hash(&mut hasher); + let reply_subject_hash = BASE64.encode(&hasher.finish().to_be_bytes()); // Publish request with reply subject encoded let env = Envelope { - p: BASE64.encode(payload), - r: Some(reply_subject.clone()), + payload: BASE64.encode(payload), + reply_subject: Some(reply_subject_hash.clone()), }; let env_payload = serde_json::to_string(&env)?; + // NOTIFY doesn't support parameterized queries let escaped_payload = env_payload.replace('\'', "''"); + let notify_sql = format!( "NOTIFY {}, '{}'", - Self::quote_ident(&subject), + quote_ident(&subject_hash), escaped_payload ); publish_conn.batch_execute(¬ify_sql).await?; // Wait for response with optional timeout let response_future = async { - match response_rx.recv().await { - Some(payload_str) => { - let env: Envelope = serde_json::from_str(&payload_str)?; - let bytes = BASE64.decode(env.p).context("invalid base64 payload")?; - Ok(Response { payload: bytes }) - } - None => Err(anyhow!("subscription closed")), - } + Ok(Response { + payload: match reply_sub.next().await? 
{ + NextOutput::Message(msg) => msg.payload, + NextOutput::Unsubscribed => bail!("reply subscription unsubscribed"), + }, + }) }; // Apply timeout if specified @@ -523,6 +424,7 @@ impl PubSubDriver for PostgresDriver { } } + // NOTE: The reply argument here is already a base64 encoded hash async fn send_request_reply(&self, reply: &str, payload: &[u8]) -> Result<()> { // Get a connection from the pool let conn = self @@ -531,24 +433,15 @@ impl PubSubDriver for PostgresDriver { .await .context("failed to get connection from pool")?; - // Convert subject to base64 hash string because Postgres identifiers can only be 63 bytes - let mut hasher = DefaultHasher::new(); - reply.hash(&mut hasher); - let reply_subject = BASE64.encode(&hasher.finish().to_be_bytes()); - // Publish reply without nested reply let env = Envelope { - p: BASE64.encode(payload), - r: None, + payload: BASE64.encode(payload), + reply_subject: None, }; let payload = serde_json::to_string(&env)?; // NOTIFY doesn't support parameterized queries let escaped_payload = payload.replace('\'', "''"); - let sql = format!( - "NOTIFY {}, '{}'", - Self::quote_ident(&reply_subject), - escaped_payload - ); + let sql = format!("NOTIFY {}, '{}'", quote_ident(reply), escaped_payload); conn.batch_execute(&sql).await?; Ok(()) } @@ -594,10 +487,9 @@ impl PubSubDriver for LocalReplyDriver { pub struct PostgresSubscriber { driver: PostgresDriver, - rx: tokio::sync::mpsc::UnboundedReceiver, + rx: tokio::sync::broadcast::Receiver<(Vec, Option)>, local_request_rx: Option>, lock_id: i64, - client: Arc, subject: String, } @@ -641,19 +533,16 @@ impl SubscriberDriver for PostgresSubscriber { // Check for regular PostgreSQL messages msg = self.rx.recv() => { match msg { - Some(payload_str) => { - let env: Envelope = serde_json::from_str(&payload_str)?; - let bytes = BASE64.decode(env.p).context("invalid base64 payload")?; - - tracing::debug!(len=?bytes.len(), "received message"); + std::result::Result::Ok((payload, reply_subject)) => { + tracing::debug!(len=?payload.len(), "received message"); Ok(NextOutput::Message(Message { driver: Arc::new(self.driver.clone()), - payload: bytes, - reply: env.r, + payload, + reply: reply_subject, })) } - None => { + std::result::Result::Err(_) => { tracing::debug!(?self.subject, ?self.lock_id, "subscription closed"); Ok(NextOutput::Unsubscribed) @@ -664,19 +553,16 @@ impl SubscriberDriver for PostgresSubscriber { } else { // No memory optimization, just poll regular messages match self.rx.recv().await { - Some(payload_str) => { - let env: Envelope = serde_json::from_str(&payload_str)?; - let bytes = BASE64.decode(env.p).context("invalid base64 payload")?; - - tracing::debug!(len=?bytes.len(), "received message"); + std::result::Result::Ok((payload, reply_subject)) => { + tracing::debug!(len=?payload.len(), "received message"); Ok(NextOutput::Message(Message { driver: Arc::new(self.driver.clone()), - payload: bytes, - reply: env.r, + payload, + reply: reply_subject, })) } - None => { + std::result::Result::Err(_) => { tracing::debug!("subscription closed"); Ok(NextOutput::Unsubscribed) @@ -689,32 +575,66 @@ impl SubscriberDriver for PostgresSubscriber { impl Drop for PostgresSubscriber { fn drop(&mut self) { tracing::debug!(subject = %self.subject, ?self.lock_id, "dropping postgres subscriber"); - // Release the advisory lock when the subscriber is dropped - let lock_id = self.lock_id; - let client = self.client.clone(); - // Clean up local subscription registration if memory optimization is enabled - if 
self.local_request_rx.is_some() { - let subject = self.subject.clone(); - let local_subs = self.driver.local_subscriptions.clone(); + let lock_id = self.lock_id; + let driver = self.driver.clone(); + let subject = self.subject.clone(); + let has_local_rx = self.local_request_rx.is_some(); - tokio::spawn(async move { - let mut subs = local_subs.write().await; - if let Some(local_tx) = subs.get_mut(&subject) { + // Spawn a task to release the lock + tokio::spawn(async move { + // Clean up local subscription registration if memory optimization is enabled + if has_local_rx { + let mut subs = driver.local_subscriptions.write().await; + if let Some(local_sub) = subs.get_mut(&subject) { // If no more subscriptions for this subject, remove the entry - if local_tx.tx.receiver_count() == 0 { + if local_sub.tx.receiver_count() == 0 { subs.remove(&subject); } } - }); - } + } - // We need to release the lock in a blocking context since Drop is not async - // Spawn a task to release the lock - tokio::spawn(async move { - let _ = client + if let Some(sub) = driver.subscriptions.get(&subject).await { + if sub.tx.receiver_count() == 0 { + driver.subscriptions.invalidate(&subject).await; + + let mut hasher = DefaultHasher::new(); + subject.hash(&mut hasher); + let subject_hash = BASE64.encode(&hasher.finish().to_be_bytes()); + + let sql = format!("UNLISTEN {}", quote_ident(&subject_hash)); + let unlisten_res = driver.client.batch_execute(&sql).await; + + if let std::result::Result::Err(err) = unlisten_res { + tracing::error!(%subject, ?err, "failed to unlisten subject"); + } + } + } + + let _ = driver + .client .execute("SELECT pg_advisory_unlock_shared($1)", &[&lock_id]) .await; }); } } + +fn quote_ident(subject: &str) -> String { + // Double-quote and escape any embedded quotes for safe identifier usage + let escaped = subject.replace('"', "\"\""); + format!("\"{}\"", escaped) +} + +/// Convert a subject name to a PostgreSQL advisory lock ID +/// Uses SHA256 hash truncated to 63 bits to avoid collisions +fn subject_to_lock_id(subject: &str) -> i64 { + let mut hasher = Sha256::new(); + hasher.update(subject.as_bytes()); + let hash = hasher.finalize(); + + // Take first 8 bytes and convert to i64, using only 63 bits to avoid sign issues + let mut bytes = [0u8; 8]; + bytes.copy_from_slice(&hash[0..8]); + let hash_u64 = u64::from_be_bytes(bytes); + (hash_u64 & 0x7FFFFFFFFFFFFFFF) as i64 +} diff --git a/packages/common/universalpubsub/src/pubsub.rs b/packages/common/universalpubsub/src/pubsub.rs index 7e650c4352..a481e79e50 100644 --- a/packages/common/universalpubsub/src/pubsub.rs +++ b/packages/common/universalpubsub/src/pubsub.rs @@ -79,9 +79,3 @@ impl Message { pub struct Response { pub payload: Vec, } - -impl Drop for Message { - fn drop(&mut self) { - tracing::info!("dropping message"); - } -} diff --git a/packages/core/pegboard-gateway/src/lib.rs b/packages/core/pegboard-gateway/src/lib.rs index 09a97856cf..32a507ea39 100644 --- a/packages/core/pegboard-gateway/src/lib.rs +++ b/packages/core/pegboard-gateway/src/lib.rs @@ -320,6 +320,21 @@ impl PegboardGateway { } } + let ups = match self.ctx.ups() { + Result::Ok(u) => u, + Err(err) => return Err((client_ws, err.into())), + }; + + // Subscribe to messages from server before informing server that a client websocket is connecting to + // prevent race conditions. 
+ let ws_subject = + TunnelHttpWebSocketSubject::new(self.runner_id, &self.port_name, websocket_id); + let response_topic = ws_subject.to_string(); + let mut subscriber = match ups.subscribe(&response_topic).await { + Result::Ok(sub) => sub, + Err(err) => return Err((client_ws, err.into())), + }; + // Build pubsub topic let tunnel_subject = TunnelHttpRunnerSubject::new(self.runner_id, &self.port_name); let topic = tunnel_subject.to_string(); @@ -345,10 +360,6 @@ impl PegboardGateway { } }; - let ups = match self.ctx.ups() { - Result::Ok(u) => u, - Err(err) => return Err((client_ws, err.into())), - }; if let Err(err) = ups .request_with_timeout(&topic, &serialized, UPS_REQ_TIMEOUT) .await @@ -356,16 +367,6 @@ impl PegboardGateway { return Err((client_ws, err.into())); } - // Subscribe to messages from server before accepting the client websocket so that - // failures here can be retried by the proxy. - let ws_subject = - TunnelHttpWebSocketSubject::new(self.runner_id, &self.port_name, websocket_id); - let response_topic = ws_subject.to_string(); - let mut subscriber = match ups.subscribe(&response_topic).await { - Result::Ok(sub) => sub, - Err(err) => return Err((client_ws, err.into())), - }; - // Accept the WebSocket let ws_stream = match client_ws.await { Result::Ok(ws) => ws, diff --git a/packages/core/pegboard-tunnel/src/lib.rs b/packages/core/pegboard-tunnel/src/lib.rs index 0238aed4c5..4a91f93d1b 100644 --- a/packages/core/pegboard-tunnel/src/lib.rs +++ b/packages/core/pegboard-tunnel/src/lib.rs @@ -85,24 +85,14 @@ impl CustomServeTrait for PegboardTunnelCustomServe { let connections = self.connections.clone(); // Extract runner_id from query parameters - let runner_id = if let Some(query_start) = path.find('?') { - let query_string = &path[query_start + 1..]; - let params: Vec<_> = query_string.split('&').collect(); - let mut found_runner_id = None; - - for param in params { - if let Some(eq_pos) = param.find('=') { - let (key, value) = param.split_at(eq_pos); - if key == "runner_id" { - // Remove the leading '=' from value - let id_str = &value[1..]; - found_runner_id = id_str.parse::().ok(); - break; - } - } - } - - found_runner_id.unwrap_or(Id::nil()) + let runner_id = if let std::result::Result::Ok(url) = + url::Url::parse(&format!("ws://placeholder/{path}")) + { + url.query_pairs() + .find_map(|(n, v)| (n == "runner_id").then_some(v)) + .as_ref() + .and_then(|id| Id::parse(id).ok()) + .unwrap_or(Id::nil()) } else { Id::nil() }; diff --git a/sdks/typescript/runner/src/mod.ts b/sdks/typescript/runner/src/mod.ts index 352ace7b5e..2ef6381721 100644 --- a/sdks/typescript/runner/src/mod.ts +++ b/sdks/typescript/runner/src/mod.ts @@ -291,10 +291,22 @@ export class Runner { pegboardWebSocket.readyState, ); - this.#sendToServer({ + // NOTE: We don't use #sendToServer here because that function checks if the runner is + // shut down + const encoded = protocol.encodeToServer({ tag: "ToServerStopping", val: null, }); + if ( + this.#pegboardWebSocket && + this.#pegboardWebSocket.readyState === WebSocket.OPEN + ) { + this.#pegboardWebSocket.send(encoded); + } else { + console.error( + "WebSocket not available or not open for sending data", + ); + } const closePromise = new Promise((resolve) => { if (!pegboardWebSocket) @@ -360,7 +372,7 @@ export class Runner { console.log("[RUNNER] Opening tunnel to:", url); console.log("[RUNNER] Current runner ID:", this.runnerId || "none"); console.log("[RUNNER] Active actors count:", this.#actors.size); - + this.#tunnel = new Tunnel(url); 
 		this.#tunnel.setCallbacks({
 			fetch: this.#config.fetch,
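
[Editorial note, appended after the patch; not part of the commit.] As far as the diff shows, the race is closed in two places: pegboard-gateway now subscribes to the websocket reply subject (`ups.subscribe(&response_topic)`) before it publishes the request with `request_with_timeout`, and the Postgres UPS driver now registers an in-process broadcast channel for a subject before issuing LISTEN, fanning all notifications out from one shared connection instead of one connection per subscription. The sketch below illustrates that subscribe-before-publish dispatch shape. It is a minimal illustration assuming only tokio; `Dispatcher`, `subscribe`, and `dispatch` are made-up names for this sketch rather than the crate's API, and the real driver keys a moka `Cache` by the base64-hashed subject instead of using a `Mutex<HashMap>`.

// Editorial sketch only -- not part of the patch.
use std::collections::HashMap;
use std::sync::Arc;

use tokio::sync::{broadcast, Mutex};

#[derive(Clone, Default)]
struct Dispatcher {
	// Subject -> broadcast sender. The real driver uses a moka Cache keyed by the
	// base64-hashed subject; a Mutex<HashMap> keeps this sketch dependency-free.
	channels: Arc<Mutex<HashMap<String, broadcast::Sender<Vec<u8>>>>>,
}

impl Dispatcher {
	// Register interest in `subject` and return a receiver. Doing this before any
	// LISTEN/NOTIFY (or request) is what closes the race: by the time a publisher
	// can observe the subscriber, the receiver already exists.
	async fn subscribe(&self, subject: &str) -> broadcast::Receiver<Vec<u8>> {
		self.channels
			.lock()
			.await
			.entry(subject.to_string())
			.or_insert_with(|| broadcast::channel(128).0)
			.subscribe()
	}

	// Called by the single shared notification loop for every incoming payload.
	async fn dispatch(&self, subject: &str, payload: Vec<u8>) {
		if let Some(tx) = self.channels.lock().await.get(subject) {
			// Ignore "no receivers", mirroring the patch's `let _ = sub.tx.send(..)`.
			let _ = tx.send(payload);
		}
	}
}

#[tokio::main]
async fn main() {
	let dispatcher = Dispatcher::default();

	// Subscribe first, then publish: the reply cannot be dropped.
	let mut rx = dispatcher.subscribe("_INBOX.reply").await;
	dispatcher.dispatch("_INBOX.reply", b"pong".to_vec()).await;
	assert_eq!(rx.recv().await.unwrap(), b"pong".to_vec());
}

With the channel registered before anything is published, a reply that arrives immediately after the request cannot be lost, which is the same ordering the gateway change enforces at the HTTP/WebSocket layer.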