diff --git a/docker/dev-host/grafana/dashboards/cache.json b/docker/dev-host/grafana/dashboards/cache.json index 5d762023ea..f5806c749e 100644 --- a/docker/dev-host/grafana/dashboards/cache.json +++ b/docker/dev-host/grafana/dashboards/cache.json @@ -1170,7 +1170,7 @@ "timepicker": {}, "timezone": "browser", "title": "Rivet Guard", - "uid": "cen785ige8fswd", + "uid": "cen785ige8fswd2", "version": 1, "weekStart": "" -} +} \ No newline at end of file diff --git a/docker/dev-multidc-multinode/core/grafana/dashboards/cache.json b/docker/dev-multidc-multinode/core/grafana/dashboards/cache.json index 5d762023ea..f5806c749e 100644 --- a/docker/dev-multidc-multinode/core/grafana/dashboards/cache.json +++ b/docker/dev-multidc-multinode/core/grafana/dashboards/cache.json @@ -1170,7 +1170,7 @@ "timepicker": {}, "timezone": "browser", "title": "Rivet Guard", - "uid": "cen785ige8fswd", + "uid": "cen785ige8fswd2", "version": 1, "weekStart": "" -} +} \ No newline at end of file diff --git a/docker/dev-multidc/core/grafana/dashboards/cache.json b/docker/dev-multidc/core/grafana/dashboards/cache.json index 5d762023ea..f5806c749e 100644 --- a/docker/dev-multidc/core/grafana/dashboards/cache.json +++ b/docker/dev-multidc/core/grafana/dashboards/cache.json @@ -1170,7 +1170,7 @@ "timepicker": {}, "timezone": "browser", "title": "Rivet Guard", - "uid": "cen785ige8fswd", + "uid": "cen785ige8fswd2", "version": 1, "weekStart": "" -} +} \ No newline at end of file diff --git a/docker/dev-multinode/grafana/dashboards/cache.json b/docker/dev-multinode/grafana/dashboards/cache.json index 5d762023ea..f5806c749e 100644 --- a/docker/dev-multinode/grafana/dashboards/cache.json +++ b/docker/dev-multinode/grafana/dashboards/cache.json @@ -1170,7 +1170,7 @@ "timepicker": {}, "timezone": "browser", "title": "Rivet Guard", - "uid": "cen785ige8fswd", + "uid": "cen785ige8fswd2", "version": 1, "weekStart": "" -} +} \ No newline at end of file diff --git a/docker/dev/grafana/dashboards/cache.json 
b/docker/dev/grafana/dashboards/cache.json index 5d762023ea..f5806c749e 100644 --- a/docker/dev/grafana/dashboards/cache.json +++ b/docker/dev/grafana/dashboards/cache.json @@ -1170,7 +1170,7 @@ "timepicker": {}, "timezone": "browser", "title": "Rivet Guard", - "uid": "cen785ige8fswd", + "uid": "cen785ige8fswd2", "version": 1, "weekStart": "" -} +} \ No newline at end of file diff --git a/packages/common/config/src/config/mod.rs b/packages/common/config/src/config/mod.rs index d8aa3fbe17..b291de6875 100644 --- a/packages/common/config/src/config/mod.rs +++ b/packages/common/config/src/config/mod.rs @@ -187,15 +187,14 @@ impl Root { } pub fn validate_and_set_defaults(&mut self) -> Result<()> { - // TODO: Add back - //// Set default pubsub to Postgres if configured for database - //if self.pubsub.is_none() - // && let Some(Database::Postgres(pg)) = &self.database - //{ - // self.pubsub = Some(PubSub::PostgresNotify(pubsub::Postgres { - // url: pg.url.clone(), - // })); - //} + // Set default pubsub to Postgres if configured for database + if self.pubsub.is_none() + && let Some(Database::Postgres(pg)) = &self.database + { + self.pubsub = Some(PubSub::PostgresNotify(pubsub::Postgres { + url: pg.url.clone(), + })); + } Ok(()) } diff --git a/packages/common/config/src/config/pegboard_tunnel.rs b/packages/common/config/src/config/pegboard_tunnel.rs index 8940974e32..8be7d8f146 100644 --- a/packages/common/config/src/config/pegboard_tunnel.rs +++ b/packages/common/config/src/config/pegboard_tunnel.rs @@ -3,7 +3,7 @@ use std::net::IpAddr; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -/// The tunnel service that forwards tunnel-protocol messages between NATS and WebSocket connections. +/// The tunnel service that forwards tunnel-protocol messages between pubsub and WebSocket connections. 
#[derive(Debug, Serialize, Deserialize, Clone, Default, JsonSchema)] #[serde(deny_unknown_fields)] pub struct PegboardTunnel { diff --git a/packages/common/config/src/config/pubsub.rs b/packages/common/config/src/config/pubsub.rs index 81a938c1ff..e805d127e6 100644 --- a/packages/common/config/src/config/pubsub.rs +++ b/packages/common/config/src/config/pubsub.rs @@ -23,12 +23,15 @@ pub struct Postgres { /// /// See: https://docs.rs/postgres/0.19.10/postgres/config/struct.Config.html#url pub url: Secret, + #[serde(default = "default_mem_opt")] + pub memory_optimization: bool, } impl Default for Postgres { fn default() -> Self { Self { url: Secret::new("postgresql://postgres:postgres@127.0.0.1:5432/postgres".into()), + memory_optimization: true, } } } @@ -81,3 +84,7 @@ impl Memory { "default".to_string() } } + +fn default_mem_opt() -> bool { + true +} diff --git a/packages/common/gasoline/core/src/ctx/message.rs b/packages/common/gasoline/core/src/ctx/message.rs index d64ad9068d..d8228c100d 100644 --- a/packages/common/gasoline/core/src/ctx/message.rs +++ b/packages/common/gasoline/core/src/ctx/message.rs @@ -12,14 +12,14 @@ use universalpubsub::{NextOutput, Subscriber}; use crate::{ error::{WorkflowError, WorkflowResult}, - message::{Message, NatsMessage, NatsMessageWrapper}, + message::{Message, PubsubMessage, PubsubMessageWrapper}, utils::{self, tags::AsTags}, }; #[derive(Clone)] pub struct MessageCtx { - /// The connection used to communicate with NATS. - nats: UpsPool, + /// The connection used to communicate with pubsub. 
+ pubsub: UpsPool, ray_id: Id, @@ -35,7 +35,7 @@ impl MessageCtx { ray_id: Id, ) -> WorkflowResult { Ok(MessageCtx { - nats: pools.ups().map_err(WorkflowError::PoolsGeneric)?, + pubsub: pools.ups().map_err(WorkflowError::PoolsGeneric)?, ray_id, config: config.clone(), }) @@ -44,7 +44,7 @@ impl MessageCtx { // MARK: Publishing messages impl MessageCtx { - /// Publishes a message to NATS and to a durable message stream if a topic is + /// Publishes a message to pubsub and to a durable message stream if a topic is /// set. /// /// Use `subscribe` to consume these messages ephemerally and `tail` to read @@ -94,7 +94,7 @@ impl MessageCtx { where M: Message, { - let nats_subject = M::nats_subject(); + let subject = M::subject(); let duration_since_epoch = std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .unwrap_or_else(|err| unreachable!("time is broken: {}", err)); @@ -109,7 +109,7 @@ impl MessageCtx { // Serialize message let req_id = Id::new_v1(self.config.dc_label()); - let message = NatsMessageWrapper { + let message = PubsubMessageWrapper { req_id, ray_id: self.ray_id, tags: tags.as_tags()?, @@ -119,7 +119,7 @@ impl MessageCtx { let message_buf = serde_json::to_vec(&message).map_err(WorkflowError::SerializeMessage)?; tracing::debug!( - %nats_subject, + %subject, body_bytes = ?body_buf_len, message_bytes = ?message_buf.len(), "publish message" @@ -128,15 +128,15 @@ impl MessageCtx { // It's important to write to the stream as fast as possible in order to // ensure messages are handled quickly. let message_buf = Arc::new(message_buf); - self.message_publish_nats::(&nats_subject, message_buf) + self.message_publish_pubsub::(&subject, message_buf) .await; Ok(()) } - /// Publishes the message to NATS. + /// Publishes the message to pubsub. 
#[tracing::instrument(level = "debug", skip_all)] - async fn message_publish_nats(&self, nats_subject: &str, message_buf: Arc>) + async fn message_publish_pubsub(&self, subject: &str, message_buf: Arc>) where M: Message, { @@ -146,19 +146,19 @@ impl MessageCtx { // Ignore for infinite backoff backoff.tick().await; - let nats_subject = nats_subject.to_owned(); + let subject = subject.to_owned(); tracing::trace!( - %nats_subject, + %subject, message_len = message_buf.len(), - "publishing message to nats" + "publishing message to pubsub" ); - if let Err(err) = self.nats.publish(&nats_subject, &(*message_buf)).await { + if let Err(err) = self.pubsub.publish(&subject, &(*message_buf)).await { tracing::warn!(?err, "publish message failed, trying again"); continue; } - tracing::debug!("publish nats message succeeded"); + tracing::debug!("publish pubsub message succeeded"); break; } } @@ -172,7 +172,7 @@ impl MessageCtx { // MARK: Subscriptions impl MessageCtx { - /// Listens for gasoline messages globally on NATS. + /// Listens for gasoline messages globally on pubsub. #[tracing::instrument(skip_all, fields(message = M::NAME))] pub async fn subscribe(&self, tags: impl AsTags) -> WorkflowResult> where @@ -180,13 +180,13 @@ impl MessageCtx { { self.subscribe_opt::(SubscribeOpts { tags: tags.as_tags()?, - flush_nats: true, + flush: true, }) .in_current_span() .await } - /// Listens for gasoline messages globally on NATS. + /// Listens for gasoline messages globally on pubsub. #[tracing::instrument(skip_all, fields(message = M::NAME))] pub async fn subscribe_opt( &self, @@ -195,24 +195,24 @@ impl MessageCtx { where M: Message, { - let nats_subject = M::nats_subject(); + let subject = M::subject(); // Create subscription and flush immediately. 
- tracing::debug!(%nats_subject, tags = ?opts.tags, "creating subscription"); + tracing::debug!(%subject, tags = ?opts.tags, "creating subscription"); let subscription = self - .nats - .subscribe(&nats_subject) + .pubsub + .subscribe(&subject) .await .map_err(|x| WorkflowError::CreateSubscription(x.into()))?; - if opts.flush_nats { - self.nats + if opts.flush { + self.pubsub .flush() .await - .map_err(|x| WorkflowError::FlushNats(x.into()))?; + .map_err(|x| WorkflowError::FlushPubsub(x.into()))?; } // Return handle - let subscription = SubscriptionHandle::new(nats_subject, subscription, opts.tags.clone()); + let subscription = SubscriptionHandle::new(subject, subscription, opts.tags.clone()); Ok(subscription) } } @@ -220,7 +220,7 @@ impl MessageCtx { #[derive(Debug)] pub struct SubscribeOpts { pub tags: serde_json::Value, - pub flush_nats: bool, + pub flush: bool, } /// Used to receive messages from other contexts. @@ -291,7 +291,7 @@ where /// /// This future can be safely dropped. #[tracing::instrument(name="message_next", skip_all, fields(message = M::NAME))] - pub async fn next(&mut self) -> WorkflowResult> { + pub async fn next(&mut self) -> WorkflowResult> { tracing::debug!("waiting for message"); loop { @@ -299,7 +299,7 @@ where // // Use blocking threads instead of `try_next`, since I'm not sure // try_next works as intended. 
- let nats_message = match self.subscription.next().await { + let message = match self.subscription.next().await { Ok(NextOutput::Message(msg)) => msg, Ok(NextOutput::Unsubscribed) => { tracing::debug!("unsubscribed"); @@ -311,11 +311,11 @@ where } }; - let message_wrapper = NatsMessage::::deserialize_wrapper(&nats_message.payload)?; + let message_wrapper = PubsubMessage::::deserialize_wrapper(&message.payload)?; // Check if the subscription tags match a subset of the message tags if utils::is_value_subset(&self.tags, &message_wrapper.tags) { - let message = NatsMessage::::deserialize_from_wrapper(message_wrapper)?; + let message = PubsubMessage::::deserialize_from_wrapper(message_wrapper)?; tracing::debug!(?message, "received message"); return Ok(message); @@ -326,7 +326,7 @@ where } /// Converts the subscription in to a stream. - pub fn into_stream(self) -> impl futures_util::Stream>> { + pub fn into_stream(self) -> impl futures_util::Stream>> { futures_util::stream::try_unfold(self, |mut sub| { async move { let message = sub.next().await?; diff --git a/packages/common/gasoline/core/src/db/kv/mod.rs b/packages/common/gasoline/core/src/db/kv/mod.rs index 04c46435cb..247b538e1b 100644 --- a/packages/common/gasoline/core/src/db/kv/mod.rs +++ b/packages/common/gasoline/core/src/db/kv/mod.rs @@ -1,4 +1,4 @@ -//! Implementation of a workflow database driver with UniversalDB and NATS. +//! Implementation of a workflow database driver with UniversalDB and UniversalPubSub. // TODO: Move code to smaller functions for readability use std::{ @@ -43,7 +43,7 @@ mod keys; const WORKER_INSTANCE_LOST_THRESHOLD_MS: i64 = rivet_util::duration::seconds(30); /// How long before overwriting an existing metrics lock. const METRICS_LOCK_TIMEOUT_MS: i64 = rivet_util::duration::seconds(30); -/// For NATS wake mechanism. +/// For pubsub wake mechanism. 
const WORKER_WAKE_SUBJECT: &str = "gasoline.worker.wake"; pub struct DatabaseKv { @@ -52,17 +52,17 @@ pub struct DatabaseKv { } impl DatabaseKv { - /// Spawns a new thread and publishes a worker wake message to nats. + /// Spawns a new thread and publishes a worker wake message to pubsub. fn wake_worker(&self) { - let Ok(nats) = self.pools.ups() else { - tracing::debug!("failed to acquire nats pool"); + let Ok(pubsub) = self.pools.ups() else { + tracing::debug!("failed to acquire pubsub pool"); return; }; let spawn_res = tokio::task::Builder::new().name("wake").spawn( async move { // Fail gracefully - if let Err(err) = nats.publish(WORKER_WAKE_SUBJECT, &Vec::new()).await { + if let Err(err) = pubsub.publish(WORKER_WAKE_SUBJECT, &Vec::new()).await { tracing::warn!(?err, "failed to publish wake message"); } } diff --git a/packages/common/gasoline/core/src/error.rs b/packages/common/gasoline/core/src/error.rs index 16c65c4214..7c79ae4617 100644 --- a/packages/common/gasoline/core/src/error.rs +++ b/packages/common/gasoline/core/src/error.rs @@ -110,8 +110,8 @@ pub enum WorkflowError { #[error("failed to create subscription: {0}")] CreateSubscription(#[source] anyhow::Error), - #[error("failed to flush nats: {0}")] - FlushNats(#[source] anyhow::Error), + #[error("failed to flush pubsub: {0}")] + FlushPubsub(#[source] anyhow::Error), #[error("subscription unsubscribed")] SubscriptionUnsubscribed, diff --git a/packages/common/gasoline/core/src/message.rs b/packages/common/gasoline/core/src/message.rs index 20da8916f1..4c7bfa26ca 100644 --- a/packages/common/gasoline/core/src/message.rs +++ b/packages/common/gasoline/core/src/message.rs @@ -9,14 +9,14 @@ pub trait Message: Debug + Send + Sync + Serialize + DeserializeOwned + 'static const NAME: &'static str; const TAIL_TTL: std::time::Duration; - fn nats_subject() -> String { + fn subject() -> String { format!("gasoline.msg.{}", Self::NAME) } } -/// A message received from a NATS subscription. 
+/// A message received from a pubsub subscription. #[derive(Debug)] -pub struct NatsMessage +pub struct PubsubMessage where M: Message, { @@ -26,19 +26,19 @@ where pub(crate) body: M, } -impl NatsMessage +impl PubsubMessage where M: Message, { #[tracing::instrument(skip_all)] pub(crate) fn deserialize_from_wrapper( - wrapper: NatsMessageWrapper<'_>, + wrapper: PubsubMessageWrapper<'_>, ) -> WorkflowResult { // Deserialize the body let body = serde_json::from_str(wrapper.body.get()) .map_err(WorkflowError::DeserializeMessageBody)?; - Ok(NatsMessage { + Ok(PubsubMessage { ray_id: wrapper.ray_id, req_id: wrapper.req_id, ts: wrapper.ts, @@ -48,12 +48,14 @@ where // Only returns the message wrapper #[tracing::instrument(skip_all)] - pub(crate) fn deserialize_wrapper<'a>(buf: &'a [u8]) -> WorkflowResult> { + pub(crate) fn deserialize_wrapper<'a>( + buf: &'a [u8], + ) -> WorkflowResult> { serde_json::from_slice(buf).map_err(WorkflowError::DeserializeMessage) } } -impl std::ops::Deref for NatsMessage +impl std::ops::Deref for PubsubMessage where M: Message, { @@ -64,7 +66,7 @@ where } } -impl NatsMessage +impl PubsubMessage where M: Message, { @@ -91,7 +93,7 @@ where } #[derive(Serialize, Deserialize)] -pub(crate) struct NatsMessageWrapper<'a> { +pub(crate) struct PubsubMessageWrapper<'a> { pub(crate) ray_id: Id, pub(crate) req_id: Id, pub(crate) tags: serde_json::Value, diff --git a/packages/common/pools/src/db/ups.rs b/packages/common/pools/src/db/ups.rs index f77067a9b1..697ed31881 100644 --- a/packages/common/pools/src/db/ups.rs +++ b/packages/common/pools/src/db/ups.rs @@ -72,7 +72,11 @@ pub async fn setup(config: Config, client_name: String) -> Result { config::PubSub::PostgresNotify(pg) => { tracing::debug!("creating postgres pubsub driver"); Arc::new( - ups::driver::postgres::PostgresDriver::connect(pg.url.read().clone(), true).await?, + ups::driver::postgres::PostgresDriver::connect( + pg.url.read().clone(), + pg.memory_optimization, + ) + .await?, ) as 
ups::PubSubDriverHandle } config::PubSub::Memory(memory) => { diff --git a/packages/common/pools/src/error.rs b/packages/common/pools/src/error.rs index e0ed669a4e..89b322ad1d 100644 --- a/packages/common/pools/src/error.rs +++ b/packages/common/pools/src/error.rs @@ -1,7 +1,7 @@ #[derive(Debug, thiserror::Error)] pub enum Error { - #[error("missing nats pool")] - MissingNatsPool, + #[error("missing ups pool")] + MissingUpsPool, #[error("missing clickhouse pool")] MissingClickHousePool, diff --git a/packages/common/pools/src/pools.rs b/packages/common/pools/src/pools.rs index 5703a8648d..2f354a14ab 100644 --- a/packages/common/pools/src/pools.rs +++ b/packages/common/pools/src/pools.rs @@ -10,7 +10,7 @@ use crate::{ClickHousePool, Error, UdbPool, UpsPool}; // TODO: Automatically shutdown all pools on drop pub(crate) struct PoolsInner { pub(crate) _guard: DropGuard, - pub(crate) nats: Option, + pub(crate) ups: Option, pub(crate) clickhouse: Option, pub(crate) clickhouse_inserter: Option, pub(crate) udb: Option, @@ -27,7 +27,7 @@ impl Pools { let client_name = "rivet".to_string(); let token = CancellationToken::new(); - let (nats, udb) = tokio::try_join!( + let (ups, udb) = tokio::try_join!( crate::db::ups::setup(config.clone(), client_name.clone()), crate::db::udb::setup(config.clone()), )?; @@ -44,7 +44,7 @@ impl Pools { let pool = Pools(Arc::new(PoolsInner { _guard: token.clone().drop_guard(), - nats: Some(nats), + ups: Some(ups), clickhouse, clickhouse_inserter, udb, @@ -61,7 +61,7 @@ impl Pools { let client_name = "rivet".to_string(); let token = CancellationToken::new(); - let (nats, udb) = tokio::try_join!( + let (ups, udb) = tokio::try_join!( crate::db::ups::setup(config.clone(), client_name.clone()), crate::db::udb::setup(config.clone()), )?; @@ -69,7 +69,7 @@ impl Pools { // Test setup doesn't use ClickHouse inserter let pool = Pools(Arc::new(PoolsInner { _guard: token.clone().drop_guard(), - nats: Some(nats), + ups: Some(ups), clickhouse: None, 
clickhouse_inserter: None, udb, @@ -80,13 +80,13 @@ impl Pools { } // MARK: Getters - pub fn nats_option(&self) -> &Option { - &self.0.nats + pub fn ups_option(&self) -> &Option { + &self.0.ups } // MARK: Pool lookups pub fn ups(&self) -> Result { - self.0.nats.clone().ok_or(Error::MissingNatsPool.into()) + self.0.ups.clone().ok_or(Error::MissingUpsPool.into()) } pub fn clickhouse_enabled(&self) -> bool { diff --git a/packages/common/universalpubsub/src/driver/postgres/mod.rs b/packages/common/universalpubsub/src/driver/postgres/mod.rs index f4d489346d..2d4c774058 100644 --- a/packages/common/universalpubsub/src/driver/postgres/mod.rs +++ b/packages/common/universalpubsub/src/driver/postgres/mod.rs @@ -1,4 +1,5 @@ use std::collections::HashMap; +use std::hash::{DefaultHasher, Hash, Hasher}; use std::sync::Arc; use std::time::Duration; @@ -7,10 +8,12 @@ use async_trait::async_trait; use base64::Engine; use base64::engine::general_purpose::STANDARD as BASE64; use deadpool_postgres::{Config, ManagerConfig, Pool, PoolConfig, RecyclingMethod, Runtime}; +use futures_util::future::poll_fn; use serde::{Deserialize, Serialize}; use sha2::{Digest, Sha256}; -use tokio::sync::{Mutex, RwLock}; +use tokio::sync::RwLock; use tokio_postgres::{AsyncMessage, NoTls}; +use tracing::Instrument; use crate::driver::{PubSubDriver, SubscriberDriver, SubscriberDriverHandle}; use crate::errors; @@ -19,13 +22,14 @@ use crate::pubsub::{Message, NextOutput, Response}; // Represents a local subscription that can handle request/response struct LocalSubscription { // Channel to send requests to this subscription - tx: tokio::sync::mpsc::UnboundedSender, + tx: tokio::sync::broadcast::Sender, } // Request sent to a local subscription +#[derive(Clone)] struct LocalRequest { payload: Vec, - reply_tx: tokio::sync::oneshot::Sender>, + reply_tx: tokio::sync::mpsc::Sender>, } #[derive(Clone)] @@ -33,8 +37,8 @@ pub struct PostgresDriver { conn_str: String, pool: Arc, memory_optimization: bool, - // Maps 
subject to local subscriptions on this node for fast path - local_subscriptions: Arc>>>, + // Maps subject to local subscription on this node for fast path + local_subscriptions: Arc>>, } #[derive(Serialize, Deserialize)] @@ -47,7 +51,9 @@ struct Envelope { } impl PostgresDriver { + #[tracing::instrument(skip(conn_str), fields(memory_optimization))] pub async fn connect(conn_str: String, memory_optimization: bool) -> Result { + tracing::debug!(?memory_optimization, "connecting to postgres"); // Create deadpool config from connection string let mut config = Config::new(); config.url = Some(conn_str.clone()); @@ -60,16 +66,20 @@ impl PostgresDriver { }); // Create the pool + tracing::debug!("creating postgres pool"); let pool = config .create_pool(Some(Runtime::Tokio1), NoTls) .context("failed to create postgres pool")?; + tracing::debug!("postgres pool created successfully"); - Ok(Self { + let driver = Self { conn_str, pool: Arc::new(pool), memory_optimization, local_subscriptions: Arc::new(RwLock::new(HashMap::new())), - }) + }; + tracing::info!("postgres driver connected successfully"); + Ok(driver) } fn quote_ident(subject: &str) -> String { @@ -95,9 +105,12 @@ impl PostgresDriver { #[async_trait] impl PubSubDriver for PostgresDriver { + #[tracing::instrument(skip(self), fields(subject))] async fn subscribe(&self, subject: &str) -> Result { + tracing::debug!(%subject, "starting subscription"); // Get the lock ID for this subject let lock_id = Self::subject_to_lock_id(subject); + tracing::debug!(%subject, ?lock_id, "calculated advisory lock id"); // Create a single connection for both subscription and lock holding let (client, mut connection) = @@ -109,13 +122,23 @@ impl PubSubDriver for PostgresDriver { // Set up local request handling channel if memory optimization is enabled let local_request_rx = if self.memory_optimization { - let (local_tx, local_rx) = tokio::sync::mpsc::unbounded_channel::(); - // Register this subscription in the local map + 
tracing::debug!( + %subject, + "registering local subscription for memory optimization" + ); let mut subs = self.local_subscriptions.write().await; - subs.entry(subject.to_string()) - .or_insert_with(Vec::new) - .push(LocalSubscription { tx: local_tx }); + let local_rx = subs + .entry(subject.to_string()) + .or_insert_with(|| LocalSubscription { + tx: tokio::sync::broadcast::channel::(64).0, + }) + .tx + .subscribe(); + tracing::debug!( + %subject, + "local subscription registered" + ); Some(local_rx) } else { @@ -142,15 +165,19 @@ impl PubSubDriver for PostgresDriver { let mut lock_done_tx = Some(lock_done_tx); let mut listen_started = false; + // Convert subject to base64 hash string because Postgres identifiers can only be 63 bytes + let mut hasher = DefaultHasher::new(); + listen_subject.hash(&mut hasher); + let subject = BASE64.encode(&hasher.finish().to_be_bytes()); + // Execute LISTEN while polling the connection - let sql = format!("LISTEN {}", Self::quote_ident(&listen_subject)); + let sql = format!("LISTEN {}", Self::quote_ident(&subject)); let listen_future = client_clone.batch_execute(&sql); tokio::pin!(listen_future); let mut listen_done = false; let mut listen_done_tx = Some(listen_done_tx); - use futures_util::future::poll_fn; loop { tokio::select! 
{ // First acquire the lock @@ -173,7 +200,7 @@ impl PubSubDriver for PostgresDriver { msg = poll_fn(|cx| connection.poll_message(cx)) => { match msg { Some(std::result::Result::Ok(AsyncMessage::Notification(note))) => { - if note.channel() == subject_owned { + if note.channel() == subject { let _ = tx.send(note.payload().to_string()); } } @@ -187,9 +214,17 @@ impl PubSubDriver for PostgresDriver { }); // Wait for lock acquisition to complete + tracing::debug!(%subject, ?lock_id, "waiting for advisory lock acquisition"); match lock_done_rx.await { - std::result::Result::Ok(std::result::Result::Ok(true)) => {} + std::result::Result::Ok(std::result::Result::Ok(true)) => { + tracing::debug!(%subject, ?lock_id, "advisory lock acquired successfully"); + } std::result::Result::Ok(std::result::Result::Ok(false)) => { + tracing::warn!( + %subject, + ?lock_id, + "failed to acquire advisory lock - another subscriber may already exist" + ); return Err(anyhow!("Failed to acquire advisory lock for subject")); } std::result::Result::Ok(std::result::Result::Err(err)) => { @@ -201,8 +236,11 @@ impl PubSubDriver for PostgresDriver { } // Wait for LISTEN to complete + tracing::debug!(%subject, "waiting for LISTEN command to complete"); match listen_done_rx.await { - std::result::Result::Ok(std::result::Result::Ok(())) => {} + std::result::Result::Ok(std::result::Result::Ok(())) => { + tracing::debug!(%subject, "LISTEN command completed successfully"); + } std::result::Result::Ok(std::result::Result::Err(err)) => { // Release lock on error let _ = client @@ -219,6 +257,7 @@ impl PubSubDriver for PostgresDriver { } } + tracing::info!(%subject, "subscription established successfully"); Ok(Box::new(PostgresSubscriber { driver: self.clone(), rx, @@ -229,7 +268,9 @@ impl PubSubDriver for PostgresDriver { })) } + #[tracing::instrument(skip(self, message), fields(subject, message_len = message.len()))] async fn publish(&self, subject: &str, message: &[u8]) -> Result<()> { + 
tracing::debug!(%subject, message_len = message.len(), "publishing message"); // Get a connection from the pool let conn = self .pool @@ -237,6 +278,11 @@ impl PubSubDriver for PostgresDriver { .await .context("failed to get connection from pool")?; + // Convert subject to base64 hash string because Postgres identifiers can only be 63 bytes + let mut hasher = DefaultHasher::new(); + subject.hash(&mut hasher); + let subject = BASE64.encode(&hasher.finish().to_be_bytes()); + // Encode payload let env = Envelope { p: BASE64.encode(message), @@ -249,10 +295,13 @@ impl PubSubDriver for PostgresDriver { let escaped_payload = payload.replace('\'', "''"); let sql = format!( "NOTIFY {}, '{}'", - Self::quote_ident(subject), + Self::quote_ident(&subject), escaped_payload ); - conn.batch_execute(&sql).await?; + conn.batch_execute(&sql) + .instrument(tracing::debug_span!("notify_execute", %subject)) + .await?; + tracing::debug!(%subject, "message published successfully"); Ok(()) } @@ -261,63 +310,71 @@ impl PubSubDriver for PostgresDriver { Ok(()) } + #[tracing::instrument(skip(self, payload), fields(subject, payload_len = payload.len(), ?timeout))] async fn request( &self, subject: &str, payload: &[u8], timeout: Option, ) -> Result { + tracing::debug!( + %subject, + payload_len = payload.len(), + ?timeout, + "starting request" + ); // Memory fast path: check if we have local subscribers first if self.memory_optimization { let subs = self.local_subscriptions.read().await; - if let Some(local_subs) = subs.get(subject) { - if !local_subs.is_empty() { - tracing::debug!("using memory fast path for request on subject: {}", subject); - // Use the first available local subscription - let local_sub = &local_subs[0]; - - // Create a channel for the reply - let (reply_tx, reply_rx) = tokio::sync::oneshot::channel(); - - // Send the request to the local subscription - let request = LocalRequest { - payload: payload.to_vec(), - reply_tx, + if let Some(local_tx) = subs.get(subject) { + 
tracing::debug!( + %subject, + "using memory fast path for request" + ); + + // Create a channel for the reply + let (reply_tx, mut reply_rx) = tokio::sync::mpsc::channel(1); + + // Send the request to the local subscription + let request = LocalRequest { + payload: payload.to_vec(), + reply_tx, + }; + + // Try to send the request + if local_tx.tx.send(request).is_ok() { + // Drop early to clear lock + drop(subs); + + // Wait for response with optional timeout + let response_future = async { + match reply_rx.recv().await { + Some(response_payload) => Ok(Response { + payload: response_payload, + }), + None => Err(anyhow!("local subscription closed")), + } }; - // Try to send the request - if local_sub.tx.send(request).is_ok() { - // Wait for response with optional timeout - let response_future = async { - match reply_rx.await { - std::result::Result::Ok(response_payload) => Ok(Response { - payload: response_payload, - }), - std::result::Result::Err(_) => { - Err(anyhow!("local subscription closed")) - } + // Apply timeout if specified + if let Some(dur) = timeout { + return match tokio::time::timeout(dur, response_future).await { + std::result::Result::Ok(resp) => resp, + std::result::Result::Err(_) => { + Err(errors::Ups::RequestTimeout.build().into()) } }; - - // Apply timeout if specified - if let Some(dur) = timeout { - return match tokio::time::timeout(dur, response_future).await { - std::result::Result::Ok(resp) => resp, - std::result::Result::Err(_) => { - Err(errors::Ups::RequestTimeout.build().into()) - } - }; - } else { - return response_future.await; - } + } else { + return response_future.await; } - // If send failed, the subscription might be dead, clean it up later - // and fall through to normal path } + // If send failed, the subscription might be dead, clean it up later + // and fall through to normal path } } // Normal path: check for listeners via database + tracing::debug!(%subject, "checking for remote listeners via database"); // Get a 
connection from the pool for checking listeners let conn = self .pool @@ -345,8 +402,14 @@ impl PubSubDriver for PostgresDriver { "; let row = conn.query_one(check_sql, &[&classid, &objid]).await?; let has_listeners: bool = row.get(0); + tracing::debug!( + %subject, + ?has_listeners, + "checked for listeners in database" + ); if !has_listeners { + tracing::warn!(%subject, "no listeners found for subject"); return Err(errors::Ups::NoResponders.build().into()); } @@ -366,15 +429,19 @@ impl PubSubDriver for PostgresDriver { // Spawn task to handle connection and LISTEN let (response_tx, mut response_rx) = tokio::sync::mpsc::unbounded_channel(); tokio::spawn(async move { + // Convert subject to base64 hash string because Postgres identifiers can only be 63 bytes + let mut hasher = DefaultHasher::new(); + reply_subject_clone.hash(&mut hasher); + let reply_subject = BASE64.encode(&hasher.finish().to_be_bytes()); + // LISTEN reply subject first to avoid race - let listen_sql = format!("LISTEN {}", Self::quote_ident(&reply_subject_clone)); + let listen_sql = format!("LISTEN {}", Self::quote_ident(&reply_subject)); let listen_future = client.batch_execute(&listen_sql); tokio::pin!(listen_future); let mut listen_done = false; let mut listen_done_tx = Some(listen_done_tx); - use futures_util::future::poll_fn; loop { tokio::select! 
{ result = &mut listen_future, if !listen_done => { @@ -386,7 +453,7 @@ impl PubSubDriver for PostgresDriver { msg = poll_fn(|cx| connection.poll_message(cx)) => { match msg { Some(std::result::Result::Ok(tokio_postgres::AsyncMessage::Notification(note))) => { - if note.channel() == reply_subject_clone { + if note.channel() == reply_subject { let _ = response_tx.send(note.payload().to_string()); } } @@ -413,6 +480,11 @@ impl PubSubDriver for PostgresDriver { .await .context("failed to get connection from pool")?; + // Convert subject to base64 hash string because Postgres identifiers can only be 63 bytes + let mut hasher = DefaultHasher::new(); + subject.hash(&mut hasher); + let subject = BASE64.encode(&hasher.finish().to_be_bytes()); + // Publish request with reply subject encoded let env = Envelope { p: BASE64.encode(payload), @@ -423,7 +495,7 @@ impl PubSubDriver for PostgresDriver { let escaped_payload = env_payload.replace('\'', "''"); let notify_sql = format!( "NOTIFY {}, '{}'", - Self::quote_ident(subject), + Self::quote_ident(&subject), escaped_payload ); publish_conn.batch_execute(&notify_sql).await?; @@ -459,6 +531,11 @@ impl PubSubDriver for PostgresDriver { .await .context("failed to get connection from pool")?; + // Convert subject to base64 hash string because Postgres identifiers can only be 63 bytes + let mut hasher = DefaultHasher::new(); + reply.hash(&mut hasher); + let reply_subject = BASE64.encode(&hasher.finish().to_be_bytes()); + // Publish reply without nested reply let env = Envelope { p: BASE64.encode(payload), @@ -467,7 +544,11 @@ impl PubSubDriver for PostgresDriver { let payload = serde_json::to_string(&env)?; // NOTIFY doesn't support parameterized queries let escaped_payload = payload.replace('\'', "''"); - let sql = format!("NOTIFY {}, '{}'", Self::quote_ident(reply), escaped_payload); + let sql = format!( + "NOTIFY {}, '{}'", + Self::quote_ident(&reply_subject), + escaped_payload + ); conn.batch_execute(&sql).await?; Ok(()) } @@ -475,7 
+556,7 @@ impl PubSubDriver for PostgresDriver { // Special driver for handling local replies struct LocalReplyDriver { - reply_tx: Arc>>>>, + reply_tx: tokio::sync::mpsc::Sender>, } #[async_trait] @@ -502,11 +583,11 @@ impl PubSubDriver for LocalReplyDriver { } async fn send_request_reply(&self, _reply: &str, payload: &[u8]) -> Result<()> { + tracing::debug!("sending local request reply"); + // Send the reply through the local channel - let mut tx_opt = self.reply_tx.lock().await; - if let Some(tx) = tx_opt.take() { - let _ = tx.send(payload.to_vec()); - } + let _ = self.reply_tx.send(payload.to_vec()).await; + Ok(()) } } @@ -514,7 +595,7 @@ impl PubSubDriver for LocalReplyDriver { pub struct PostgresSubscriber { driver: PostgresDriver, rx: tokio::sync::mpsc::UnboundedReceiver, - local_request_rx: Option>, + local_request_rx: Option>, lock_id: i64, client: Arc, subject: String, @@ -522,22 +603,27 @@ pub struct PostgresSubscriber { #[async_trait] impl SubscriberDriver for PostgresSubscriber { + #[tracing::instrument(skip(self), fields(subject = %self.subject, lock_id = %self.lock_id))] async fn next(&mut self) -> Result { + tracing::debug!("waiting for message"); + // If we have a local request receiver, poll both channels if let Some(ref mut local_rx) = self.local_request_rx { tokio::select! 
{ // Check for local requests (memory fast path) local_req = local_rx.recv() => { match local_req { - Some(req) => { + std::result::Result::Ok(req) => { // Create a synthetic reply subject for local request let reply_subject = format!("_LOCAL.{}", uuid::Uuid::new_v4()); // Create a wrapper driver that will handle the reply let local_driver = LocalReplyDriver { - reply_tx: Arc::new(Mutex::new(Some(req.reply_tx))), + reply_tx: req.reply_tx, }; + tracing::debug!(len=?req.payload.len(), "received local message"); + // Return the request as a message with the local reply driver Ok(NextOutput::Message(Message { driver: Arc::new(local_driver), @@ -545,7 +631,11 @@ impl SubscriberDriver for PostgresSubscriber { reply: Some(reply_subject), })) } - None => Ok(NextOutput::Unsubscribed), + std::result::Result::Err(_) => { + tracing::debug!("no local subscription senders"); + + Ok(NextOutput::Unsubscribed) + } } } // Check for regular PostgreSQL messages @@ -554,13 +644,20 @@ impl SubscriberDriver for PostgresSubscriber { Some(payload_str) => { let env: Envelope = serde_json::from_str(&payload_str)?; let bytes = BASE64.decode(env.p).context("invalid base64 payload")?; + + tracing::debug!(len=?bytes.len(), "received message"); + Ok(NextOutput::Message(Message { driver: Arc::new(self.driver.clone()), payload: bytes, reply: env.r, })) } - None => Ok(NextOutput::Unsubscribed), + None => { + tracing::debug!(?self.subject, ?self.lock_id, "subscription closed"); + + Ok(NextOutput::Unsubscribed) + } } } } @@ -570,13 +667,20 @@ impl SubscriberDriver for PostgresSubscriber { Some(payload_str) => { let env: Envelope = serde_json::from_str(&payload_str)?; let bytes = BASE64.decode(env.p).context("invalid base64 payload")?; + + tracing::debug!(len=?bytes.len(), "received message"); + Ok(NextOutput::Message(Message { driver: Arc::new(self.driver.clone()), payload: bytes, reply: env.r, })) } - None => Ok(NextOutput::Unsubscribed), + None => { + tracing::debug!("subscription closed"); + + 
Ok(NextOutput::Unsubscribed) + } } } } @@ -584,27 +688,21 @@ impl SubscriberDriver for PostgresSubscriber { impl Drop for PostgresSubscriber { fn drop(&mut self) { + tracing::debug!(subject = %self.subject, ?self.lock_id, "dropping postgres subscriber"); // Release the advisory lock when the subscriber is dropped let lock_id = self.lock_id; let client = self.client.clone(); // Clean up local subscription registration if memory optimization is enabled - if self.driver.memory_optimization { + if self.local_request_rx.is_some() { let subject = self.subject.clone(); let local_subs = self.driver.local_subscriptions.clone(); tokio::spawn(async move { let mut subs = local_subs.write().await; - // Remove this subscription from the local map - // Note: In a real implementation, we'd need to track which specific - // subscription to remove, perhaps using an ID - if let Some(subject_subs) = subs.get_mut(&subject) { - // For now, just remove the first one (this is a simplification) - if !subject_subs.is_empty() { - subject_subs.remove(0); - } + if let Some(local_tx) = subs.get_mut(&subject) { // If no more subscriptions for this subject, remove the entry - if subject_subs.is_empty() { + if local_tx.tx.receiver_count() == 0 { subs.remove(&subject); } } diff --git a/packages/common/universalpubsub/src/pubsub.rs b/packages/common/universalpubsub/src/pubsub.rs index a481e79e50..7e650c4352 100644 --- a/packages/common/universalpubsub/src/pubsub.rs +++ b/packages/common/universalpubsub/src/pubsub.rs @@ -79,3 +79,9 @@ impl Message { pub struct Response { pub payload: Vec, } + +impl Drop for Message { + fn drop(&mut self) { + tracing::info!("dropping message"); + } +} diff --git a/packages/common/universalpubsub/tests/integration.rs b/packages/common/universalpubsub/tests/integration.rs index 7d80ab9595..0883295c73 100644 --- a/packages/common/universalpubsub/tests/integration.rs +++ b/packages/common/universalpubsub/tests/integration.rs @@ -1,4 +1,5 @@ use anyhow::*; +use 
futures_util::StreamExt; use rivet_error::RivetError; use rivet_test_deps_docker::{TestDatabase, TestPubSub}; use std::{ @@ -58,7 +59,7 @@ async fn test_postgres_driver_with_memory() { let (db_config, docker_config) = TestDatabase::Postgres.config(test_id, 1).await.unwrap(); let mut docker = docker_config.unwrap(); docker.start().await.unwrap(); - tokio::time::sleep(tokio::time::Duration::from_secs(2)).await; + tokio::time::sleep(tokio::time::Duration::from_secs(5)).await; let rivet_config::config::Database::Postgres(pg) = db_config else { unreachable!(); @@ -81,7 +82,7 @@ async fn test_postgres_driver_without_memory() { let (db_config, docker_config) = TestDatabase::Postgres.config(test_id, 1).await.unwrap(); let mut docker = docker_config.unwrap(); docker.start().await.unwrap(); - tokio::time::sleep(tokio::time::Duration::from_secs(2)).await; + tokio::time::sleep(tokio::time::Duration::from_secs(5)).await; let rivet_config::config::Database::Postgres(pg) = db_config else { unreachable!(); @@ -129,6 +130,10 @@ async fn test_inner(pubsub: &PubSub) { test_request_response(&pubsub).await.unwrap(); tracing::info!(duration_ms = ?start.elapsed().as_millis(), "test_request_response completed"); + let start = Instant::now(); + test_multiple_request_response(&pubsub).await.unwrap(); + tracing::info!(duration_ms = ?start.elapsed().as_millis(), "test_multiple_request_response completed"); + let start = Instant::now(); test_request_timeout(&pubsub).await.unwrap(); tracing::info!(duration_ms = ?start.elapsed().as_millis(), "test_request_timeout completed"); @@ -269,6 +274,42 @@ async fn test_request_response(pubsub: &PubSub) -> Result<()> { Ok(()) } +async fn test_multiple_request_response(pubsub: &PubSub) -> Result<()> { + tracing::info!("testing multiple request/response"); + + // TODO: This fails on postgres for high numbers (too many clients) + futures_util::stream::iter(0..5) + .map(|i| async move { + let mut payload = b"request payload ".to_vec(); + 
payload.extend(i.to_string().as_bytes()); + + { + let pubsub = pubsub.clone(); + let (ready_tx, ready_rx) = tokio::sync::oneshot::channel(); + tokio::spawn(async move { + let mut sub = pubsub.subscribe("test.request_response").await.unwrap(); + ready_tx.send(()).unwrap(); + while let NextOutput::Message(msg) = sub.next().await.unwrap() { + // Reply with the same payload back + let _ = msg.reply(&msg.payload).await; + } + }); + ready_rx.await.unwrap(); + } + + let req = pubsub + .request("test.request_response", &payload) + .await + .unwrap(); + assert_eq!(req.payload, payload); + }) + .buffer_unordered(50) + .collect::<()>() + .await; + + Ok(()) +} + async fn test_request_timeout(pubsub: &PubSub) -> Result<()> { tracing::info!("testing request timeout"); diff --git a/packages/core/pegboard-gateway/src/lib.rs b/packages/core/pegboard-gateway/src/lib.rs index 5602cc5a68..09a97856cf 100644 --- a/packages/core/pegboard-gateway/src/lib.rs +++ b/packages/core/pegboard-gateway/src/lib.rs @@ -172,16 +172,16 @@ impl PegboardGateway { let serialized = versioned::TunnelMessage::serialize(versioned::TunnelMessage::V1(message)) .map_err(|e| anyhow!("failed to serialize message: {}", e))?; - // Build NATS topic + // Build pubsub topic let tunnel_subject = TunnelHttpRunnerSubject::new(self.runner_id, &self.port_name); let topic = tunnel_subject.to_string(); tracing::info!( - ?topic, + %topic, ?self.runner_id, - ?self.port_name, + %self.port_name, ?request_id, - "publishing request to NATS" + "publishing request to pubsub" ); // Create response channel @@ -208,16 +208,16 @@ impl PegboardGateway { tracing::info!("starting response handler task"); while let ResultOk(NextOutput::Message(msg)) = subscriber.next().await { // Ack message - //match msg.reply(&[]).await { - // Result::Ok(_) => {} - // Err(err) => { - // tracing::warn!(?err, "failed to ack gateway request response message") - // } - //}; + match msg.reply(&[]).await { + Result::Ok(_) => {} + Err(err) => { + 
tracing::warn!(?err, "failed to ack gateway request response message") + } + }; tracing::info!( payload_len = msg.payload.len(), - "received response from NATS" + "received response from pubsub" ); if let ResultOk(tunnel_msg) = versioned::TunnelMessage::deserialize(&msg.payload) { match tunnel_msg.body { @@ -235,11 +235,11 @@ impl PegboardGateway { } } _ => { - tracing::warn!("received non-response message from NATS"); + tracing::warn!("received non-response message from pubsub"); } } } else { - tracing::error!("failed to deserialize response from NATS"); + tracing::error!("failed to deserialize response from pubsub"); } } tracing::info!("response handler task ended"); @@ -248,7 +248,7 @@ impl PegboardGateway { // Publish request self.ctx .ups()? - .publish(&topic, &serialized) + .request_with_timeout(&topic, &serialized, UPS_REQ_TIMEOUT) .await .map_err(|e| anyhow!("failed to publish request: {}", e))?; @@ -264,7 +264,7 @@ impl PegboardGateway { .map_err(|e| anyhow!("failed to serialize finish message: {}", e))?; self.ctx .ups()? 
- .publish(&topic, &finish_serialized) + .request_with_timeout(&topic, &finish_serialized, UPS_REQ_TIMEOUT) .await .map_err(|e| anyhow!("failed to publish finish message: {}", e))?; @@ -320,7 +320,7 @@ impl PegboardGateway { } } - // Build NATS topic + // Build pubsub topic let tunnel_subject = TunnelHttpRunnerSubject::new(self.runner_id, &self.port_name); let topic = tunnel_subject.to_string(); @@ -349,7 +349,10 @@ impl PegboardGateway { Result::Ok(u) => u, Err(err) => return Err((client_ws, err.into())), }; - if let Err(err) = ups.publish(&topic, &serialized).await { + if let Err(err) = ups + .request_with_timeout(&topic, &serialized, UPS_REQ_TIMEOUT) + .await + { return Err((client_ws, err.into())); } @@ -378,12 +381,12 @@ impl PegboardGateway { tokio::spawn(async move { while let ResultOk(NextOutput::Message(msg)) = subscriber.next().await { // Ack message - //match msg.reply(&[]).await { - // Result::Ok(_) => {} - // Err(err) => { - // tracing::warn!(?err, "failed to ack gateway websocket message") - // } - //}; + match msg.reply(&[]).await { + Result::Ok(_) => {} + Err(err) => { + tracing::warn!(?err, "failed to ack gateway websocket message") + } + }; if let ResultOk(tunnel_msg) = versioned::TunnelMessage::deserialize(&msg.payload) { match tunnel_msg.body { @@ -424,7 +427,10 @@ impl PegboardGateway { Result::Ok(s) => s, Err(_) => break, }; - if let Err(err) = ups.publish(&topic, &serialized).await { + if let Err(err) = ups + .request_with_timeout(&topic, &serialized, UPS_REQ_TIMEOUT) + .await + { if is_tunnel_closed_error(&err) { tracing::warn!("tunnel closed sending binary message"); close_reason = Some("Tunnel closed".to_string()); @@ -448,7 +454,10 @@ impl PegboardGateway { Result::Ok(s) => s, Err(_) => break, }; - if let Err(err) = ups.publish(&topic, &serialized).await { + if let Err(err) = ups + .request_with_timeout(&topic, &serialized, UPS_REQ_TIMEOUT) + .await + { if is_tunnel_closed_error(&err) { tracing::warn!("tunnel closed sending text 
message"); close_reason = Some("Tunnel closed".to_string()); @@ -479,7 +488,10 @@ impl PegboardGateway { Err(_) => Vec::new(), }; - if let Err(err) = ups.publish(&topic, &serialized).await { + if let Err(err) = ups + .request_with_timeout(&topic, &serialized, UPS_REQ_TIMEOUT) + .await + { if is_tunnel_closed_error(&err) { tracing::warn!("tunnel closed sending close message"); } else { diff --git a/packages/core/pegboard-tunnel/src/lib.rs b/packages/core/pegboard-tunnel/src/lib.rs index 07a126cb14..0238aed4c5 100644 --- a/packages/core/pegboard-tunnel/src/lib.rs +++ b/packages/core/pegboard-tunnel/src/lib.rs @@ -118,10 +118,10 @@ impl CustomServeTrait for PegboardTunnelCustomServe { let connection_id = Id::nil(); - // Subscribe to NATS topic for this runner before accepting the client websocket so + // Subscribe to pubsub topic for this runner before accepting the client websocket so // that failures can be retried by the proxy. let topic = TunnelHttpRunnerSubject::new(runner_id, &port_name).to_string(); - info!(?topic, ?runner_id, "subscribing to NATS topic"); + info!(%topic, ?runner_id, "subscribing to pubsub topic"); let mut sub = match ups.subscribe(&topic).await { Result::Ok(s) => s, @@ -153,43 +153,43 @@ impl CustomServeTrait for PegboardTunnelCustomServe { .insert(connection_id, connection.clone()); // Handle bidirectional message forwarding - let ws_write_nats_to_ws = ws_write.clone(); + let ws_write_pubsub_to_ws = ws_write.clone(); let connections_clone = connections.clone(); let ups_clone = ups.clone(); - // Task for forwarding NATS -> WebSocket - let nats_to_ws = tokio::spawn(async move { - info!("starting NATS to WebSocket forwarding task"); + // Task for forwarding pubsub -> WebSocket + let pubsub_to_ws = tokio::spawn(async move { + info!("starting pubsub to WebSocket forwarding task"); while let ::std::result::Result::Ok(NextOutput::Message(msg)) = sub.next().await { // Ack message - //match msg.reply(&[]).await { - // Result::Ok(_) => {} - // Err(err) 
=> { - // tracing::warn!(?err, "failed to ack gateway request response message") - // } - //}; + match msg.reply(&[]).await { + Result::Ok(_) => {} + Err(err) => { + tracing::warn!(?err, "failed to ack gateway request response message") + } + }; info!( payload_len = msg.payload.len(), - "received message from NATS, forwarding to WebSocket" + "received message from pubsub, forwarding to WebSocket" ); // Forward raw message to WebSocket let ws_msg = WsMessage::Binary(msg.payload.to_vec().into()); { - let mut stream = ws_write_nats_to_ws.lock().await; + let mut stream = ws_write_pubsub_to_ws.lock().await; if let Err(e) = stream.send(ws_msg).await { error!(?e, "failed to send message to WebSocket"); break; } } } - info!("NATS to WebSocket forwarding task ended"); + info!("pubsub to WebSocket forwarding task ended"); }); - // Task for forwarding WebSocket -> NATS - let ws_write_ws_to_nats = ws_write.clone(); - let ws_to_nats = tokio::spawn(async move { - info!("starting WebSocket to NATS forwarding task"); + // Task for forwarding WebSocket -> pubsub + let ws_write_ws_to_pubsub = ws_write.clone(); + let ws_to_pubsub = tokio::spawn(async move { + info!("starting WebSocket to pubsub forwarding task"); while let Some(msg) = ws_read.next().await { match msg { ::std::result::Result::Ok(WsMessage::Binary(data)) => { @@ -203,7 +203,7 @@ impl CustomServeTrait for PegboardTunnelCustomServe { // Handle different message types match &tunnel_msg.body { MessageBody::ToClientResponseStart(resp) => { - info!(?resp.request_id, status = resp.status, "forwarding HTTP response to NATS"); + info!(?resp.request_id, status = resp.status, "forwarding HTTP response to pubsub"); let response_topic = TunnelHttpResponseSubject::new( runner_id, &port_name, @@ -211,10 +211,15 @@ impl CustomServeTrait for PegboardTunnelCustomServe { ) .to_string(); - info!(?response_topic, ?resp.request_id, "publishing HTTP response to NATS"); + info!(%response_topic, ?resp.request_id, "publishing HTTP response to 
pubsub"); - if let Err(e) = - ups_clone.publish(&response_topic, &data.to_vec()).await + if let Err(e) = ups_clone + .request_with_timeout( + &response_topic, + &data.to_vec(), + UPS_REQ_TIMEOUT, + ) + .await { let err_any: anyhow::Error = e.into(); if is_tunnel_closed_error(&err_any) { @@ -223,19 +228,19 @@ impl CustomServeTrait for PegboardTunnelCustomServe { ); // Close client websocket with reason send_tunnel_closed_close_hyper( - &ws_write_ws_to_nats, + &ws_write_ws_to_pubsub, ) .await; break; } else { - error!(?err_any, ?resp.request_id, "failed to publish HTTP response to NATS"); + error!(?err_any, ?resp.request_id, "failed to publish HTTP response to pubsub"); } } else { - info!(?resp.request_id, "successfully published HTTP response to NATS"); + info!(?resp.request_id, "successfully published HTTP response to pubsub"); } } MessageBody::ToClientWebSocketMessage(ws_msg) => { - info!(?ws_msg.web_socket_id, "forwarding WebSocket message to NATS"); + info!(?ws_msg.web_socket_id, "forwarding WebSocket message to pubsub"); // Forward WebSocket messages to the topic that pegboard-gateway subscribes to let ws_topic = TunnelHttpWebSocketSubject::new( runner_id, @@ -244,10 +249,15 @@ impl CustomServeTrait for PegboardTunnelCustomServe { ) .to_string(); - info!(?ws_topic, ?ws_msg.web_socket_id, "publishing WebSocket message to NATS"); + info!(%ws_topic, ?ws_msg.web_socket_id, "publishing WebSocket message to pubsub"); - if let Err(e) = - ups_clone.publish(&ws_topic, &data.to_vec()).await + if let Err(e) = ups_clone + .request_with_timeout( + &ws_topic, + &data.to_vec(), + UPS_REQ_TIMEOUT, + ) + .await { let err_any: anyhow::Error = e.into(); if is_tunnel_closed_error(&err_any) { @@ -256,19 +266,19 @@ impl CustomServeTrait for PegboardTunnelCustomServe { ); // Close client websocket with reason send_tunnel_closed_close_hyper( - &ws_write_ws_to_nats, + &ws_write_ws_to_pubsub, ) .await; break; } else { - error!(?err_any, ?ws_msg.web_socket_id, "failed to publish 
WebSocket message to NATS"); + error!(?err_any, ?ws_msg.web_socket_id, "failed to publish WebSocket message to pubsub"); } } else { - info!(?ws_msg.web_socket_id, "successfully published WebSocket message to NATS"); + info!(?ws_msg.web_socket_id, "successfully published WebSocket message to pubsub"); } } MessageBody::ToClientWebSocketOpen(ws_open) => { - info!(?ws_open.web_socket_id, "forwarding WebSocket open to NATS"); + info!(?ws_open.web_socket_id, "forwarding WebSocket open to pubsub"); let ws_topic = TunnelHttpWebSocketSubject::new( runner_id, &port_name, @@ -276,8 +286,13 @@ impl CustomServeTrait for PegboardTunnelCustomServe { ) .to_string(); - if let Err(e) = - ups_clone.publish(&ws_topic, &data.to_vec()).await + if let Err(e) = ups_clone + .request_with_timeout( + &ws_topic, + &data.to_vec(), + UPS_REQ_TIMEOUT, + ) + .await { let err_any: anyhow::Error = e.into(); if is_tunnel_closed_error(&err_any) { @@ -286,19 +301,19 @@ impl CustomServeTrait for PegboardTunnelCustomServe { ); // Close client websocket with reason send_tunnel_closed_close_hyper( - &ws_write_ws_to_nats, + &ws_write_ws_to_pubsub, ) .await; break; } else { - error!(?err_any, ?ws_open.web_socket_id, "failed to publish WebSocket open to NATS"); + error!(?err_any, ?ws_open.web_socket_id, "failed to publish WebSocket open to pubsub"); } } else { - info!(?ws_open.web_socket_id, "successfully published WebSocket open to NATS"); + info!(?ws_open.web_socket_id, "successfully published WebSocket open to pubsub"); } } MessageBody::ToClientWebSocketClose(ws_close) => { - info!(?ws_close.web_socket_id, "forwarding WebSocket close to NATS"); + info!(?ws_close.web_socket_id, "forwarding WebSocket close to pubsub"); let ws_topic = TunnelHttpWebSocketSubject::new( runner_id, &port_name, @@ -306,8 +321,13 @@ impl CustomServeTrait for PegboardTunnelCustomServe { ) .to_string(); - if let Err(e) = - ups_clone.publish(&ws_topic, &data.to_vec()).await + if let Err(e) = ups_clone + .request_with_timeout( + 
&ws_topic, + &data.to_vec(), + UPS_REQ_TIMEOUT, + ) + .await { let err_any: anyhow::Error = e.into(); if is_tunnel_closed_error(&err_any) { @@ -316,21 +336,21 @@ impl CustomServeTrait for PegboardTunnelCustomServe { ); // Close client websocket with reason send_tunnel_closed_close_hyper( - &ws_write_ws_to_nats, + &ws_write_ws_to_pubsub, ) .await; break; } else { - error!(?err_any, ?ws_close.web_socket_id, "failed to publish WebSocket close to NATS"); + error!(?err_any, ?ws_close.web_socket_id, "failed to publish WebSocket close to pubsub"); } } else { - info!(?ws_close.web_socket_id, "successfully published WebSocket close to NATS"); + info!(?ws_close.web_socket_id, "successfully published WebSocket close to pubsub"); } } _ => { - // For other message types, we might not need to forward to NATS + // For other message types, we might not need to forward to pubsub info!( - "Received non-response message from WebSocket, skipping NATS forward" + "Received non-response message from WebSocket, skipping pubsub forward" ); continue; } @@ -354,7 +374,7 @@ impl CustomServeTrait for PegboardTunnelCustomServe { } } } - info!("WebSocket to NATS forwarding task ended"); + info!("WebSocket to pubsub forwarding task ended"); // Clean up connection connections_clone.write().await.remove(&connection_id); @@ -362,11 +382,11 @@ impl CustomServeTrait for PegboardTunnelCustomServe { // Wait for either task to complete tokio::select! 
{ - _ = nats_to_ws => { - info!("NATS to WebSocket task completed"); + _ = pubsub_to_ws => { + info!("pubsub to WebSocket task completed"); } - _ = ws_to_nats => { - info!("WebSocket to NATS task completed"); + _ = ws_to_pubsub => { + info!("WebSocket to pubsub task completed"); } } @@ -440,9 +460,9 @@ async fn handle_connection( let connection_id = rivet_util::Id::nil(); - // Subscribe to NATS topic for this runner using raw NATS client + // Subscribe to pubsub topic for this runner using raw pubsub client let topic = TunnelHttpRunnerSubject::new(runner_id, &port_name).to_string(); - info!(?topic, ?runner_id, "subscribing to NATS topic"); + info!(%topic, ?runner_id, "subscribing to pubsub topic"); // Get UPS (UniversalPubSub) client let ups = ctx.pools().ups()?; @@ -468,16 +488,16 @@ async fn handle_connection( let connections_clone = connections.clone(); let ups_clone = ups.clone(); - // Task for forwarding NATS -> WebSocket - let nats_to_ws = tokio::spawn(async move { + // Task for forwarding pubsub -> WebSocket + let pubsub_to_ws = tokio::spawn(async move { while let ::std::result::Result::Ok(NextOutput::Message(msg)) = sub.next().await { // Ack message - //match msg.reply(&[]).await { - // Result::Ok(_) => {} - // Err(err) => { - // tracing::warn!(?err, "failed to ack gateway request response message") - // } - //}; + match msg.reply(&[]).await { + Result::Ok(_) => {} + Err(err) => { + tracing::warn!(?err, "failed to ack gateway request response message") + } + }; // Forward raw message to WebSocket let ws_msg = @@ -492,9 +512,9 @@ async fn handle_connection( } }); - // Task for forwarding WebSocket -> NATS - let ws_write_ws_to_nats = ws_write.clone(); - let ws_to_nats = tokio::spawn(async move { + // Task for forwarding WebSocket -> pubsub + let ws_write_ws_to_pubsub = ws_write.clone(); + let ws_to_pubsub = tokio::spawn(async move { while let Some(msg) = ws_read.next().await { match msg { 
::std::result::Result::Ok(tokio_tungstenite::tungstenite::Message::Binary( @@ -513,8 +533,13 @@ async fn handle_connection( ) .to_string(); - if let Err(e) = - ups_clone.publish(&response_topic, &data.to_vec()).await + if let Err(e) = ups_clone + .request_with_timeout( + &response_topic, + &data.to_vec(), + UPS_REQ_TIMEOUT, + ) + .await { let err_any: anyhow::Error = e.into(); if is_tunnel_closed_error(&err_any) { @@ -522,11 +547,11 @@ async fn handle_connection( "tunnel closed while publishing HTTP response; closing client websocket" ); // Close client websocket with reason - send_tunnel_closed_close_tokio(&ws_write_ws_to_nats) + send_tunnel_closed_close_tokio(&ws_write_ws_to_pubsub) .await; break; } else { - error!(?err_any, ?resp.request_id, "failed to publish HTTP response to NATS"); + error!(?err_any, ?resp.request_id, "failed to publish HTTP response to pubsub"); } } } @@ -538,8 +563,13 @@ async fn handle_connection( ) .to_string(); - if let Err(e) = - ups_clone.publish(&ws_topic, &data.to_vec()).await + if let Err(e) = ups_clone + .request_with_timeout( + &ws_topic, + &data.to_vec(), + UPS_REQ_TIMEOUT, + ) + .await { let err_any: anyhow::Error = e.into(); if is_tunnel_closed_error(&err_any) { @@ -547,11 +577,11 @@ async fn handle_connection( "tunnel closed while publishing WebSocket message; closing client websocket" ); // Close client websocket with reason - send_tunnel_closed_close_tokio(&ws_write_ws_to_nats) + send_tunnel_closed_close_tokio(&ws_write_ws_to_pubsub) .await; break; } else { - error!(?err_any, ?ws_msg.web_socket_id, "failed to publish WebSocket message to NATS"); + error!(?err_any, ?ws_msg.web_socket_id, "failed to publish WebSocket message to pubsub"); } } } @@ -563,8 +593,13 @@ async fn handle_connection( ) .to_string(); - if let Err(e) = - ups_clone.publish(&ws_topic, &data.to_vec()).await + if let Err(e) = ups_clone + .request_with_timeout( + &ws_topic, + &data.to_vec(), + UPS_REQ_TIMEOUT, + ) + .await { let err_any: anyhow::Error = 
e.into(); if is_tunnel_closed_error(&err_any) { @@ -572,11 +607,11 @@ async fn handle_connection( "tunnel closed while publishing WebSocket open; closing client websocket" ); // Close client websocket with reason - send_tunnel_closed_close_tokio(&ws_write_ws_to_nats) + send_tunnel_closed_close_tokio(&ws_write_ws_to_pubsub) .await; break; } else { - error!(?err_any, ?ws_open.web_socket_id, "failed to publish WebSocket open to NATS"); + error!(?err_any, ?ws_open.web_socket_id, "failed to publish WebSocket open to pubsub"); } } } @@ -588,8 +623,13 @@ async fn handle_connection( ) .to_string(); - if let Err(e) = - ups_clone.publish(&ws_topic, &data.to_vec()).await + if let Err(e) = ups_clone + .request_with_timeout( + &ws_topic, + &data.to_vec(), + UPS_REQ_TIMEOUT, + ) + .await { let err_any: anyhow::Error = e.into(); if is_tunnel_closed_error(&err_any) { @@ -597,18 +637,18 @@ async fn handle_connection( "tunnel closed while publishing WebSocket close; closing client websocket" ); // Close client websocket with reason - send_tunnel_closed_close_tokio(&ws_write_ws_to_nats) + send_tunnel_closed_close_tokio(&ws_write_ws_to_pubsub) .await; break; } else { - error!(?err_any, ?ws_close.web_socket_id, "failed to publish WebSocket close to NATS"); + error!(?err_any, ?ws_close.web_socket_id, "failed to publish WebSocket close to pubsub"); } } } _ => { - // For other message types, we might not need to forward to NATS + // For other message types, we might not need to forward to pubsub info!( - "Received non-response message from WebSocket, skipping NATS forward" + "Received non-response message from WebSocket, skipping pubsub forward" ); continue; } @@ -639,11 +679,11 @@ async fn handle_connection( // Wait for either task to complete tokio::select! 
{ - _ = nats_to_ws => { - info!("NATS to WebSocket task completed"); + _ = pubsub_to_ws => { + info!("pubsub to WebSocket task completed"); } - _ = ws_to_nats => { - info!("WebSocket to NATS task completed"); + _ = ws_to_pubsub => { + info!("WebSocket to pubsub task completed"); } } diff --git a/packages/core/pegboard-tunnel/tests/integration.rs b/packages/core/pegboard-tunnel/tests/integration.rs index 71b534874d..2c20bfe638 100644 --- a/packages/core/pegboard-tunnel/tests/integration.rs +++ b/packages/core/pegboard-tunnel/tests/integration.rs @@ -50,14 +50,14 @@ async fn test_tunnel_bidirectional_forwarding() -> Result<()> { let runner_id = Id::nil(); let port_name = "default"; - // Give tunnel time to set up NATS subscription after WebSocket connection + // Give tunnel time to set up pubsub subscription after WebSocket connection sleep(Duration::from_secs(1)).await; - // Test 1: NATS to WebSocket forwarding - test_nats_to_websocket(&ups, &mut ws_stream, runner_id, port_name).await?; + // Test 1: pubsub to WebSocket forwarding + test_pubsub_to_websocket(&ups, &mut ws_stream, runner_id, port_name).await?; - // Test 2: WebSocket to NATS forwarding - test_websocket_to_nats(&ups, &mut ws_stream, runner_id).await?; + // Test 2: WebSocket to pubsub forwarding + test_websocket_to_pubsub(&ups, &mut ws_stream, runner_id).await?; // Clean up tunnel_handle.abort(); @@ -65,7 +65,7 @@ async fn test_tunnel_bidirectional_forwarding() -> Result<()> { Ok(()) } -async fn test_nats_to_websocket( +async fn test_pubsub_to_websocket( ups: &PubSub, ws_stream: &mut WebSocketStream>, runner_id: Id, @@ -93,7 +93,7 @@ async fn test_nats_to_websocket( // Serialize the message let serialized = versioned::TunnelMessage::serialize(versioned::TunnelMessage::V1(message))?; - // Publish to NATS topic using proper subject + // Publish to pubsub topic using proper subject let topic = TunnelHttpRunnerSubject::new(&runner_id.to_string(), port_name).to_string(); ups.request(&topic, 
&serialized).await?; @@ -111,7 +111,7 @@ async fn test_nats_to_websocket( assert_eq!(req.request_id, request_id); assert_eq!(req.method, "GET"); assert_eq!(req.path, "/test"); - println!("✓ NATS to WebSocket forwarding successful"); + println!("✓ pubsub to WebSocket forwarding successful"); } _ => bail!("Unexpected message type received"), } @@ -122,7 +122,7 @@ async fn test_nats_to_websocket( Ok(()) } -async fn test_websocket_to_nats( +async fn test_websocket_to_pubsub( ups: &PubSub, ws_stream: &mut WebSocketStream>, runner_id: Id, @@ -153,7 +153,7 @@ async fn test_websocket_to_nats( let serialized = versioned::TunnelMessage::serialize(versioned::TunnelMessage::V1(message))?; ws_stream.send(WsMessage::Binary(serialized.into())).await?; - // Wait for message on NATS + // Wait for message on pubsub let received = timeout(Duration::from_secs(10), subscriber.next()).await??; match received { @@ -164,12 +164,12 @@ async fn test_websocket_to_nats( MessageBody::ToClientResponseStart(resp) => { assert_eq!(resp.request_id, request_id); assert_eq!(resp.status, 200); - println!("✓ WebSocket to NATS forwarding successful"); + println!("✓ WebSocket to pubsub forwarding successful"); } _ => bail!("Unexpected message type received"), } } - _ => bail!("Expected message from NATS"), + _ => bail!("Expected message from subscriber"), } Ok(()) diff --git a/packages/services/pegboard/src/pubsub_subjects.rs b/packages/services/pegboard/src/pubsub_subjects.rs index c3f3e0296c..13d39b1704 100644 --- a/packages/services/pegboard/src/pubsub_subjects.rs +++ b/packages/services/pegboard/src/pubsub_subjects.rs @@ -18,7 +18,7 @@ impl std::fmt::Display for TunnelHttpRunnerSubject<'_> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( f, - "pegboard.tunnel_http.runner.{}.{}", + "pegboard.tunnel.http.runner.{}.{}", self.runner_id, self.port_name ) }