diff --git a/Cargo.lock b/Cargo.lock index b6f660bc76..564b06bd1f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4594,7 +4594,6 @@ dependencies = [ "anyhow", "bare_gen", "base64 0.22.1", - "gasoline", "indoc", "prettyplease", "rivet-util", @@ -4619,6 +4618,23 @@ dependencies = [ "utoipa", ] +[[package]] +name = "rivet-ups-protocol" +version = "0.0.1" +dependencies = [ + "anyhow", + "bare_gen", + "base64 0.22.1", + "indoc", + "prettyplease", + "rivet-util", + "serde", + "serde_bare", + "serde_json", + "syn 2.0.104", + "versioned-data-util", +] + [[package]] name = "rivet-util" version = "0.0.1" @@ -6318,6 +6334,7 @@ dependencies = [ "rivet-env", "rivet-error", "rivet-test-deps-docker", + "rivet-ups-protocol", "serde", "serde_json", "sha2", @@ -6327,6 +6344,7 @@ dependencies = [ "tracing", "tracing-subscriber", "uuid", + "versioned-data-util", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 911ec31c03..08114056cb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [workspace] resolver = "2" -members = ["packages/common/api-builder","packages/common/api-client","packages/common/api-types","packages/common/api-util","packages/common/cache/build","packages/common/cache/result","packages/common/clickhouse-inserter","packages/common/clickhouse-user-query","packages/common/config","packages/common/env","packages/common/error/core","packages/common/error/macros","packages/common/gasoline/core","packages/common/gasoline/macros","packages/common/logs","packages/common/metrics","packages/common/pools","packages/common/runtime","packages/common/service-manager","packages/common/telemetry","packages/common/test-deps","packages/common/test-deps-docker","packages/common/types","packages/common/udb-util","packages/common/universaldb","packages/common/universalpubsub","packages/common/util/core","packages/common/util/id","packages/common/versioned-data-util","packages/core/actor-kv","packages/core/api-peer","packages/core/api-public","packages/core/bootstrap","packages/core/dump-openapi","packages/core/guard/core","packages/core/guard/server","packages/core/pegboard-gateway","packages/core/pegboard-runner-ws","packages/core/pegboard-tunnel","packages/core/workflow-worker","packages/infra/engine","packages/services/epoxy","packages/services/namespace","packages/services/pegboard","sdks/rust/api-full","sdks/rust/bare_gen","sdks/rust/epoxy-protocol","sdks/rust/key-data","sdks/rust/runner-protocol","sdks/rust/tunnel-protocol"] +members = 
["packages/common/api-builder","packages/common/api-client","packages/common/api-types","packages/common/api-util","packages/common/cache/build","packages/common/cache/result","packages/common/clickhouse-inserter","packages/common/clickhouse-user-query","packages/common/config","packages/common/env","packages/common/error/core","packages/common/error/macros","packages/common/gasoline/core","packages/common/gasoline/macros","packages/common/logs","packages/common/metrics","packages/common/pools","packages/common/runtime","packages/common/service-manager","packages/common/telemetry","packages/common/test-deps","packages/common/test-deps-docker","packages/common/types","packages/common/udb-util","packages/common/universaldb","packages/common/universalpubsub","packages/common/util/core","packages/common/util/id","packages/common/versioned-data-util","packages/core/actor-kv","packages/core/api-peer","packages/core/api-public","packages/core/bootstrap","packages/core/dump-openapi","packages/core/guard/core","packages/core/guard/server","packages/core/pegboard-gateway","packages/core/pegboard-runner-ws","packages/core/pegboard-tunnel","packages/core/workflow-worker","packages/infra/engine","packages/services/epoxy","packages/services/namespace","packages/services/pegboard","sdks/rust/api-full","sdks/rust/bare_gen","sdks/rust/epoxy-protocol","sdks/rust/key-data","sdks/rust/runner-protocol","sdks/rust/tunnel-protocol","sdks/rust/ups-protocol"] [workspace.package] version = "0.0.1" @@ -398,6 +398,9 @@ path = "sdks/rust/runner-protocol" [workspace.dependencies.rivet-tunnel-protocol] path = "sdks/rust/tunnel-protocol" +[workspace.dependencies.rivet-ups-protocol] +path = "sdks/rust/ups-protocol" + [profile.dev] overflow-checks = false debug = false diff --git a/packages/common/gasoline/core/src/ctx/message.rs b/packages/common/gasoline/core/src/ctx/message.rs index d8228c100d..6af2b3e50e 100644 --- a/packages/common/gasoline/core/src/ctx/message.rs +++ b/packages/common/gasoline/core/src/ctx/message.rs @@ -153,7 +153,15 @@ impl MessageCtx { message_len = message_buf.len(), "publishing message to pubsub" ); - if let Err(err) = self.pubsub.publish(&subject, &(*message_buf)).await { + if let Err(err) = self + .pubsub + .publish( + &subject, + &(*message_buf), + universalpubsub::PublishOpts::broadcast(), + ) + .await + { tracing::warn!(?err, "publish message failed, trying again"); continue; } diff --git a/packages/common/gasoline/core/src/db/kv/mod.rs b/packages/common/gasoline/core/src/db/kv/mod.rs index 247b538e1b..0321d3453c 100644 --- a/packages/common/gasoline/core/src/db/kv/mod.rs +++ b/packages/common/gasoline/core/src/db/kv/mod.rs @@ -62,7 +62,14 @@ impl DatabaseKv { let spawn_res = tokio::task::Builder::new().name("wake").spawn( async move { // Fail gracefully - if let Err(err) = pubsub.publish(WORKER_WAKE_SUBJECT, &Vec::new()).await { + if let Err(err) = pubsub + .publish( + WORKER_WAKE_SUBJECT, + &Vec::new(), + universalpubsub::PublishOpts::broadcast(), + ) + .await + { tracing::warn!(?err, "failed to publish wake message"); } } diff --git a/packages/common/universalpubsub/Cargo.toml b/packages/common/universalpubsub/Cargo.toml index 645a3e1785..cd36f035b5 100644 --- a/packages/common/universalpubsub/Cargo.toml +++ b/packages/common/universalpubsub/Cargo.toml @@ -14,7 +14,9 @@ deadpool-postgres.workspace = true futures-util.workspace = true moka.workspace = true rivet-error.workspace = true +rivet-ups-protocol.workspace = true serde_json.workspace = true +versioned-data-util.workspace = true 
 serde.workspace = true
 sha2.workspace = true
 tokio-postgres.workspace = true
diff --git a/packages/common/universalpubsub/src/chunking.rs b/packages/common/universalpubsub/src/chunking.rs
new file mode 100644
index 0000000000..8bf846b7a2
--- /dev/null
+++ b/packages/common/universalpubsub/src/chunking.rs
@@ -0,0 +1,226 @@
+use std::collections::HashMap;
+use std::time::{Duration, Instant};
+
+use anyhow::*;
+use rivet_ups_protocol::versioned::UpsMessage;
+use rivet_ups_protocol::{MessageBody, MessageChunk, MessageStart, PROTOCOL_VERSION};
+use versioned_data_util::OwnedVersionedData;
+
+pub const CHUNK_BUFFER_GC_INTERVAL: Duration = Duration::from_secs(60);
+pub const CHUNK_BUFFER_MAX_AGE: Duration = Duration::from_secs(300);
+
+#[derive(Debug)]
+pub struct ChunkBuffer {
+	pub message_id: [u8; 16],
+	pub received_chunks: u32,
+	pub last_chunk_ts: Instant,
+	pub buffer: Vec<u8>,
+	pub chunk_count: u32,
+	pub reply_subject: Option<String>,
+}
+
+pub struct ChunkTracker {
+	chunks_in_process: HashMap<[u8; 16], ChunkBuffer>,
+}
+
+impl ChunkTracker {
+	pub fn new() -> Self {
+		Self {
+			chunks_in_process: HashMap::new(),
+		}
+	}
+
+	pub fn process_chunk(
+		&mut self,
+		raw_message: &[u8],
+	) -> Result<Option<(Vec<u8>, Option<String>)>> {
+		let message = UpsMessage::deserialize_with_embedded_version(raw_message)?;
+
+		match message.body {
+			MessageBody::MessageStart(msg) => {
+				// If only one chunk, return immediately
+				if msg.chunk_count == 1 {
+					return Ok(Some((msg.payload, msg.reply_subject)));
+				}
+
+				// Start of a multi-chunk message
+				let buffer = ChunkBuffer {
+					message_id: msg.message_id,
+					received_chunks: 1,
+					last_chunk_ts: Instant::now(),
+					buffer: msg.payload,
+					chunk_count: msg.chunk_count,
+					reply_subject: msg.reply_subject,
+				};
+				self.chunks_in_process.insert(msg.message_id, buffer);
+				Ok(None)
+			}
+			MessageBody::MessageChunk(msg) => {
+				// Find the matching buffer using message_id
+				let buffer = self.chunks_in_process.get_mut(&msg.message_id);
+
+				let Some(buffer) = buffer else {
+					bail!(
+						"received chunk {} for message {:?} but no matching buffer found",
+						msg.chunk_index,
+						msg.message_id
+					);
+				};
+
+				// Validate chunk order
+				if buffer.received_chunks != msg.chunk_index {
+					bail!(
+						"received chunk {} but expected chunk {} for message {:?}",
+						msg.chunk_index,
+						buffer.received_chunks,
+						msg.message_id
+					);
+				}
+
+				// Update buffer
+				buffer.buffer.extend_from_slice(&msg.payload);
+				buffer.received_chunks += 1;
+				buffer.last_chunk_ts = Instant::now();
+				let is_complete = buffer.received_chunks == buffer.chunk_count;
+
+				if is_complete {
+					let completed_buffer = self.chunks_in_process.remove(&msg.message_id).unwrap();
+					Ok(Some((
+						completed_buffer.buffer,
+						completed_buffer.reply_subject,
+					)))
+				} else {
+					Ok(None)
+				}
+			}
+		}
+	}
+
+	pub fn gc(&mut self) {
+		let now = Instant::now();
+		let size_before = self.chunks_in_process.len();
+		self.chunks_in_process
+			.retain(|_, buffer| now.duration_since(buffer.last_chunk_ts) < CHUNK_BUFFER_MAX_AGE);
+		let size_after = self.chunks_in_process.len();
+
+		tracing::debug!(
+			?size_before,
+			?size_after,
+			"performed chunk buffer garbage collection"
+		);
+	}
+}
+
+/// Splits a payload into chunks that fit within message size limits.
+///
+/// This function handles chunking by accounting for different overhead
+/// between the first chunk (MessageStart) and subsequent chunks (MessageChunk).
+///
+/// The first chunk carries additional metadata like the reply_subject and chunk_count,
+/// which means it has more protocol overhead and less room for payload data.
+/// Subsequent chunks only carry a chunk_index, allowing them to fit more payload.
+///
+/// This optimization ensures:
+/// - Reply subject is only transmitted once (in MessageStart)
+/// - Maximum payload utilization in each chunk
+/// - Efficient bandwidth usage for multi-chunk messages
+///
+/// # Returns
+/// A vector of payload chunks, where each chunk is sized to fit within the message limit
+/// after accounting for protocol overhead.
+pub fn split_payload_into_chunks(
+	payload: &[u8],
+	max_message_size: usize,
+	message_id: [u8; 16],
+	reply_subject: Option<&str>,
+) -> Result<Vec<Vec<u8>>> {
+	// Calculate overhead for MessageStart (first chunk)
+	let start_message = MessageStart {
+		message_id,
+		chunk_count: 1,
+		reply_subject: reply_subject.map(|s| s.to_string()),
+		payload: vec![],
+	};
+	let start_ups_message = rivet_ups_protocol::UpsMessage {
+		body: MessageBody::MessageStart(start_message),
+	};
+	let start_overhead = UpsMessage::latest(start_ups_message)
+		.serialize_with_embedded_version(PROTOCOL_VERSION)?
+		.len();
+
+	// Calculate overhead for MessageChunk (subsequent chunks)
+	let chunk_message = MessageChunk {
+		message_id,
+		chunk_index: 0,
+		payload: vec![],
+	};
+	let chunk_ups_message = rivet_ups_protocol::UpsMessage {
+		body: MessageBody::MessageChunk(chunk_message),
+	};
+	let chunk_overhead = UpsMessage::latest(chunk_ups_message)
+		.serialize_with_embedded_version(PROTOCOL_VERSION)?
+		.len();
+
+	// Calculate max payload sizes
+	let first_chunk_max_payload = max_message_size.saturating_sub(start_overhead);
+	let other_chunk_max_payload = max_message_size.saturating_sub(chunk_overhead);
+
+	if first_chunk_max_payload == 0 || other_chunk_max_payload == 0 {
+		bail!("message overhead exceeds max message size");
+	}
+
+	// Calculate how many chunks we need
+	if payload.len() <= first_chunk_max_payload {
+		// Single chunk - all data fits in first message
+		return Ok(vec![payload.to_vec()]);
+	}
+
+	// Multi-chunk: first chunk + remaining chunks
+	let remaining_after_first = payload.len() - first_chunk_max_payload;
+	let additional_chunks =
+		(remaining_after_first + other_chunk_max_payload - 1) / other_chunk_max_payload;
+
+	let mut chunks = Vec::new();
+
+	// First chunk (smaller due to reply_subject overhead)
+	chunks.push(payload[..first_chunk_max_payload].to_vec());
+
+	// Subsequent chunks
+	let mut offset = first_chunk_max_payload;
+	for _ in 0..additional_chunks {
+		let end = std::cmp::min(offset + other_chunk_max_payload, payload.len());
+		chunks.push(payload[offset..end].to_vec());
+		offset = end;
+	}
+
+	Ok(chunks)
+}
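[Editor's note: taken together, `split_payload_into_chunks`, `encode_chunk` (next hunk), and `ChunkTracker` form a round trip. A minimal sketch of that flow, outside the diff; the 6 KB limit here merely stands in for a driver's `max_message_size`:]

```rust
use universalpubsub::chunking::{ChunkTracker, encode_chunk, split_payload_into_chunks};

fn chunk_roundtrip() -> anyhow::Result<()> {
	let payload = vec![42u8; 20_000];
	let message_id = *uuid::Uuid::new_v4().as_bytes();

	// Split as a publisher would, using a small driver limit (e.g. Postgres).
	let chunks = split_payload_into_chunks(&payload, 6_000, message_id, None)?;
	let chunk_count = chunks.len() as u32;

	// Reassemble as a subscriber would.
	let mut tracker = ChunkTracker::new();
	let mut completed = None;
	for (idx, chunk) in chunks.into_iter().enumerate() {
		let encoded = encode_chunk(chunk, idx as u32, chunk_count, message_id, None)?;
		completed = tracker.process_chunk(&encoded)?;
	}

	// The final chunk completes the message.
	let (reassembled, reply_subject) = completed.expect("all chunks delivered in order");
	assert_eq!(reassembled, payload);
	assert!(reply_subject.is_none());
	Ok(())
}
```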
+
+/// Encodes a chunk to the resulting BARE message.
+pub fn encode_chunk(
+	payload: Vec<u8>,
+	chunk_idx: u32,
+	chunk_count: u32,
+	message_id: [u8; 16],
+	reply_subject: Option<String>,
+) -> Result<Vec<u8>> {
+	let body = if chunk_idx == 0 {
+		// First chunk - MessageStart
+		MessageBody::MessageStart(MessageStart {
+			message_id,
+			chunk_count,
+			reply_subject,
+			payload,
+		})
+	} else {
+		// Subsequent chunks - MessageChunk
+		MessageBody::MessageChunk(MessageChunk {
+			message_id,
+			chunk_index: chunk_idx,
+			payload,
+		})
+	};
+
+	let ups_message = rivet_ups_protocol::UpsMessage { body };
+	UpsMessage::latest(ups_message).serialize_with_embedded_version(PROTOCOL_VERSION)
+}
diff --git a/packages/common/universalpubsub/src/driver/memory/mod.rs b/packages/common/universalpubsub/src/driver/memory/mod.rs
index 3e7ee8ca20..507fe5cf69 100644
--- a/packages/common/universalpubsub/src/driver/memory/mod.rs
+++ b/packages/common/universalpubsub/src/driver/memory/mod.rs
@@ -1,28 +1,22 @@
 use std::collections::HashMap;
 use std::sync::Arc;
-use std::time::Duration;
 
 use anyhow::*;
 use async_trait::async_trait;
-use tokio::sync::{mpsc, RwLock};
-use uuid::Uuid;
+use tokio::sync::{RwLock, mpsc};
 
 use crate::driver::{PubSubDriver, SubscriberDriver, SubscriberDriverHandle};
-use crate::pubsub::{Message, NextOutput, Response};
+use crate::pubsub::DriverOutput;
 
-type Subscribers = Arc<RwLock<HashMap<String, Vec<mpsc::UnboundedSender<MemoryMessage>>>>>;
+type Subscribers = Arc<RwLock<HashMap<String, Vec<mpsc::UnboundedSender<Vec<u8>>>>>>;
 
-#[derive(Clone, Debug)]
-struct MemoryMessage {
-	payload: Vec<u8>,
-	reply_to: Option<String>,
-}
+/// This is arbitrary.
+const MEMORY_MAX_MESSAGE_SIZE: usize = 10 * 1024 * 1024; // 10MiB
 
 #[derive(Clone)]
 pub struct MemoryDriver {
 	channel: String,
 	subscribers: Subscribers,
-	// No pending requests tracking needed for simple in-memory behavior
 }
 
 impl MemoryDriver {
@@ -51,22 +45,18 @@ impl PubSubDriver for MemoryDriver {
 			.push(tx);
 
 		Ok(Box::new(MemorySubscriber {
-			driver: self.clone(),
+			subject: subject_with_channel,
 			rx,
 		}))
 	}
 
-	async fn publish(&self, subject: &str, message: &[u8]) -> Result<()> {
+	async fn publish(&self, subject: &str, payload: &[u8]) -> Result<()> {
 		let subject_with_channel = self.subject_with_channel(subject);
 		let subscribers = self.subscribers.read().await;
 		if let Some(subs) = subscribers.get(&subject_with_channel) {
-			let msg = MemoryMessage {
-				payload: message.to_vec(),
-				reply_to: None,
-			};
 			for tx in subs {
-				let _ = tx.send(msg.clone());
+				let _ = tx.send(payload.to_vec());
 			}
 		}
 
@@ -77,124 +67,25 @@ impl PubSubDriver for MemoryDriver {
 		Ok(())
 	}
 
-	async fn request(
-		&self,
-		subject: &str,
-		payload: &[u8],
-		timeout: Option<Duration>,
-	) -> Result<Response> {
-		let subject_with_channel = self.subject_with_channel(subject);
-
-		// Check if there are any subscribers for this subject
-		{
-			let subscribers = self.subscribers.read().await;
-			if !subscribers.contains_key(&subject_with_channel)
-				|| subscribers
-					.get(&subject_with_channel)
-					.map_or(true, |subs| subs.is_empty())
-			{
-				// HACK: There is no native NoResponders error, so we return
-				// RequestTimeout. This is equivalent since the request would time out
-				// if there are no responders.
-				return Err(crate::errors::Ups::RequestTimeout.build().into());
-			}
-		}
-
-		// Create a unique reply subject for this request
-		let reply_subject = format!("_INBOX.{}", Uuid::new_v4());
-		let reply_subject_with_channel = self.subject_with_channel(&reply_subject);
-
-		// Create a oneshot channel for the response
-		let (tx, rx) = tokio::sync::oneshot::channel();
-
-		// Subscribe to the reply subject to receive the response
-		{
-			let mut subscribers = self.subscribers.write().await;
-			let (reply_tx, mut reply_rx) = mpsc::unbounded_channel();
-			subscribers
-				.entry(reply_subject_with_channel.clone())
-				.or_default()
-				.push(reply_tx);
-
-			// Spawn a task to wait for the reply
-			tokio::spawn(async move {
-				if let Some(msg) = reply_rx.recv().await {
-					let _ = tx.send(Response {
-						payload: msg.payload,
-					});
-				}
-			});
-		}
-
-		// Send the request message with reply_to field to first subscriber (if any)
-		{
-			let subscribers = self.subscribers.read().await;
-			if let Some(subs) = subscribers.get(&subject_with_channel) {
-				if let Some(tx) = subs.first() {
-					let msg = MemoryMessage {
-						payload: payload.to_vec(),
-						reply_to: Some(reply_subject.clone()),
-					};
-					let _ = tx.send(msg);
-				}
-			}
-		}
-
-		// Wait for response with optional timeout
-		let response = if let Some(timeout_duration) = timeout {
-			match tokio::time::timeout(timeout_duration, rx).await {
-				std::result::Result::Ok(result) => match result {
-					std::result::Result::Ok(response) => response,
-					std::result::Result::Err(_) => {
-						// Channel error - shouldn't happen in normal operation
-						return Err(crate::errors::Ups::RequestTimeout.build().into());
-					}
-				},
-				std::result::Result::Err(_) => {
-					// Timeout elapsed
-					// Clean up the reply subscription
-					let subscribers = self.subscribers.clone();
-					let reply_subject = reply_subject_with_channel.clone();
-					tokio::spawn(async move {
-						let mut subs = subscribers.write().await;
-						subs.remove(&reply_subject);
-					});
-					return Err(crate::errors::Ups::RequestTimeout.build().into());
-				}
-			}
-		} else {
-			match rx.await {
-				std::result::Result::Ok(response) => response,
-				std::result::Result::Err(_) => {
-					// Channel closed without response - shouldn't happen
-					return Err(anyhow!("Request failed: no response received"));
-				}
-			}
-		};
-
-		Ok(response)
-	}
-
-	async fn send_request_reply(&self, reply: &str, payload: &[u8]) -> Result<()> {
-		self.publish(reply, payload).await
+	fn max_message_size(&self) -> usize {
+		MEMORY_MAX_MESSAGE_SIZE
 	}
 }
 
 pub struct MemorySubscriber {
-	driver: MemoryDriver,
-	rx: mpsc::UnboundedReceiver<MemoryMessage>,
+	subject: String,
+	rx: mpsc::UnboundedReceiver<Vec<u8>>,
 }
 
 #[async_trait]
 impl SubscriberDriver for MemorySubscriber {
-	async fn next(&mut self) -> Result<NextOutput> {
+	async fn next(&mut self) -> Result<DriverOutput> {
 		match self.rx.recv().await {
-			Some(msg) => Ok(NextOutput::Message(Message {
-				driver: Arc::new(self.driver.clone()),
-				payload: msg.payload,
-				reply: msg.reply_to,
-			})),
-			None => Ok(NextOutput::Unsubscribed),
+			Some(payload) => Ok(DriverOutput::Message {
+				subject: self.subject.clone(),
+				payload,
+			}),
+			None => Ok(DriverOutput::Unsubscribed),
 		}
 	}
 }
diff --git a/packages/common/universalpubsub/src/driver/mod.rs b/packages/common/universalpubsub/src/driver/mod.rs
index fce5a24678..a84f8eeede 100644
--- a/packages/common/universalpubsub/src/driver/mod.rs
+++ b/packages/common/universalpubsub/src/driver/mod.rs
@@ -1,5 +1,4 @@
 use std::sync::Arc;
-use std::time::Duration;
 
 use anyhow::*;
 use async_trait::async_trait;
@@ -8,27 +7,50 @@ pub mod memory;
 pub mod nats;
 pub mod postgres;
 
-use crate::pubsub::{NextOutput, Response};
-
 pub type PubSubDriverHandle = Arc<dyn PubSubDriver>;
 
+#[derive(Clone, Copy, Debug, Eq, PartialEq)]
+pub enum PublishBehavior {
+	/// Publishes a message to a single subscriber.
+	///
+	/// This does not enforce a single subscriber; it enables in-memory
+	/// optimizations that speed up delivery when only one subscriber is expected.
+	OneSubscriber,
+
+	/// Publishes a message to multiple subscribers.
+	Broadcast,
+}
+
+#[derive(Clone, Copy, Debug)]
+pub struct PublishOpts {
+	pub behavior: PublishBehavior,
+}
+
+impl PublishOpts {
+	pub const fn one() -> Self {
+		Self {
+			behavior: PublishBehavior::OneSubscriber,
+		}
+	}
+
+	pub const fn broadcast() -> Self {
+		Self {
+			behavior: PublishBehavior::Broadcast,
+		}
+	}
+}
+
 #[async_trait]
 pub trait PubSubDriver: Send + Sync {
 	async fn subscribe(&self, subject: &str) -> Result<Box<dyn SubscriberDriver>>;
 	async fn publish(&self, subject: &str, message: &[u8]) -> Result<()>;
-	async fn request(
-		&self,
-		subject: &str,
-		payload: &[u8],
-		timeout: Option<Duration>,
-	) -> Result<Response>;
-	async fn send_request_reply(&self, reply: &str, payload: &[u8]) -> Result<()>;
 	async fn flush(&self) -> Result<()>;
+	fn max_message_size(&self) -> usize;
 }
 
 pub type SubscriberDriverHandle = Box<dyn SubscriberDriver>;
 
 #[async_trait]
 pub trait SubscriberDriver: Send + Sync {
-	async fn next(&mut self) -> Result<NextOutput>;
+	async fn next(&mut self) -> Result<DriverOutput>;
 }
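[Editor's note: the `OneSubscriber`/`Broadcast` split is the hook for the in-memory fast path added in `pubsub.rs` later in this patch. A hedged sketch of how a caller chooses between the two; subjects and payloads here are illustrative, not from the diff:]

```rust
use universalpubsub::{PubSub, PublishOpts};

async fn example(pubsub: &PubSub) -> anyhow::Result<()> {
	// Fan-out: every subscriber must see the message, so it always goes
	// through the underlying driver.
	pubsub
		.publish("gasoline.worker.wake", &[], PublishOpts::broadcast())
		.await?;

	// Point-to-point: exactly one consumer is expected, so an in-process
	// subscriber can be served directly from memory, skipping the driver.
	pubsub
		.publish("some.rpc.subject", b"request", PublishOpts::one())
		.await?;
	Ok(())
}
```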
diff --git a/packages/common/universalpubsub/src/driver/nats/mod.rs b/packages/common/universalpubsub/src/driver/nats/mod.rs
index 223e50d7ae..a472331c61 100644
--- a/packages/common/universalpubsub/src/driver/nats/mod.rs
+++ b/packages/common/universalpubsub/src/driver/nats/mod.rs
@@ -1,14 +1,17 @@
-use std::sync::Arc;
-use std::time::Duration;
-
 use anyhow::*;
-use async_nats::{Client, client::RequestErrorKind};
+use async_nats::Client;
 use async_trait::async_trait;
 use futures_util::StreamExt;
 
 use crate::driver::{PubSubDriver, SubscriberDriver, SubscriberDriverHandle};
-use crate::errors;
-use crate::pubsub::{Message, NextOutput, Response};
+use crate::pubsub::DriverOutput;
+
+/// > The size is set to 1 MB by default, but can be increased up to 64 MB if needed (though we recommend keeping the max message size to something more reasonable like 8 MB).
+///
+/// https://docs.nats.io/reference/faq#is-there-a-message-size-limitation-in-nats
+///
+/// Note that where the NATS docs say "MB", the limit is actually measured in MiB.
+pub const NATS_MAX_MESSAGE_SIZE: usize = 1024 * 1024;
 
 #[derive(Clone)]
 pub struct NatsDriver {
@@ -20,11 +23,10 @@ impl NatsDriver {
 		options: async_nats::ConnectOptions,
 		server_addrs: impl async_nats::ToServerAddrs,
 	) -> Result<Self> {
-		// NOTE: async-nats adds ConnectionInfo.no_responders by default
-
 		tracing::debug!("nats connecting");
 		let client = options.connect(server_addrs).await?;
 		tracing::debug!("nats connected");
+
 		Ok(Self { client })
 	}
 }
@@ -33,15 +35,12 @@ impl PubSubDriver for NatsDriver {
 	async fn subscribe(&self, subject: &str) -> Result<SubscriberDriverHandle> {
 		let subscriber = self.client.subscribe(subject.to_string()).await?;
-		Ok(Box::new(NatsSubscriber {
-			driver: self.clone(),
-			subscriber,
-		}))
+		Ok(Box::new(NatsSubscriber { subscriber }))
 	}
 
-	async fn publish(&self, subject: &str, message: &[u8]) -> Result<()> {
+	async fn publish(&self, subject: &str, payload: &[u8]) -> Result<()> {
 		self.client
-			.publish(subject.to_string(), message.to_vec().into())
+			.publish(subject.to_string(), payload.to_vec().into())
 			.await?;
 		Ok(())
 	}
@@ -51,78 +50,24 @@ impl PubSubDriver for NatsDriver {
 		Ok(())
 	}
 
-	async fn request(
-		&self,
-		subject: &str,
-		payload: &[u8],
-		timeout: Option<Duration>,
-	) -> Result<Response> {
-		let request_future = self
-			.client
-			.request(subject.to_string(), payload.to_vec().into());
-
-		let request = if let Some(timeout) = timeout {
-			match tokio::time::timeout(timeout, request_future).await {
-				std::result::Result::Ok(result) => match result {
-					std::result::Result::Ok(msg) => msg,
-					std::result::Result::Err(err) => match err.kind() {
-						RequestErrorKind::NoResponders => {
-							// HACK: There is no native NoResponders error, so we return
-							// RequestTimeout. This is equivalent since the request would time out
-							// if there are no responders.
-							return Err(errors::Ups::RequestTimeout.build().into());
-						}
-						RequestErrorKind::TimedOut => {
-							return Err(errors::Ups::RequestTimeout.build().into());
-						}
-						_ => bail!(err),
-					},
-				},
-				std::result::Result::Err(_) => {
-					return Err(errors::Ups::RequestTimeout.build().into());
-				}
-			}
-		} else {
-			match request_future.await {
-				std::result::Result::Ok(msg) => msg,
-				std::result::Result::Err(err) => match err.kind() {
-					RequestErrorKind::NoResponders => {
-						// HACK: See above
-						return Err(errors::Ups::RequestTimeout.build().into());
-					}
-					RequestErrorKind::TimedOut => {
-						return Err(errors::Ups::RequestTimeout.build().into());
-					}
-					_ => bail!(err),
-				},
-			}
-		};
-
-		Ok(Response {
-			payload: request.payload.to_vec(),
-		})
-	}
-
-	async fn send_request_reply(&self, reply: &str, payload: &[u8]) -> Result<()> {
-		self.publish(reply, payload).await
+	fn max_message_size(&self) -> usize {
+		NATS_MAX_MESSAGE_SIZE
 	}
 }
 
 pub struct NatsSubscriber {
-	driver: NatsDriver,
 	subscriber: async_nats::Subscriber,
 }
 
 #[async_trait]
 impl SubscriberDriver for NatsSubscriber {
-	async fn next(&mut self) -> Result<NextOutput> {
+	async fn next(&mut self) -> Result<DriverOutput> {
 		match self.subscriber.next().await {
-			Some(msg) => Ok(NextOutput::Message(Message {
-				driver: Arc::new(self.driver.clone()),
+			Some(msg) => Ok(DriverOutput::Message {
+				subject: msg.subject.to_string(),
 				payload: msg.payload.to_vec(),
-				reply: msg.reply.map(|r| r.to_string()),
-			})),
-			None => Ok(NextOutput::Unsubscribed),
+			}),
+			None => Ok(DriverOutput::Unsubscribed),
 		}
 	}
 }
diff --git a/packages/common/universalpubsub/src/driver/postgres/mod.rs b/packages/common/universalpubsub/src/driver/postgres/mod.rs
index 10af3cacaa..38959eb33a 100644
--- a/packages/common/universalpubsub/src/driver/postgres/mod.rs
+++ b/packages/common/universalpubsub/src/driver/postgres/mod.rs
@@ -1,62 +1,46 @@
-use std::collections::HashMap;
 use std::hash::{DefaultHasher, Hash, Hasher};
 use std::sync::Arc;
-use std::time::Duration;
 
 use anyhow::*;
 use async_trait::async_trait;
 use base64::Engine;
-use base64::engine::general_purpose::STANDARD as BASE64;
+use base64::engine::general_purpose::STANDARD_NO_PAD as BASE64;
 use deadpool_postgres::{Config, ManagerConfig, Pool, PoolConfig, RecyclingMethod, Runtime};
 use futures_util::future::poll_fn;
 use moka::future::Cache;
-use serde::{Deserialize, Serialize};
-use tokio::sync::RwLock;
 use tokio_postgres::{AsyncMessage, NoTls};
 use tracing::Instrument;
 
 use crate::driver::{PubSubDriver, SubscriberDriver, SubscriberDriverHandle};
-use crate::errors;
-use crate::pubsub::{Message, NextOutput, Response};
+use crate::pubsub::DriverOutput;
 
 #[derive(Clone)]
 struct Subscription {
-	// Channel to send requests to this subscription
-	tx: tokio::sync::broadcast::Sender<(Vec<u8>, Option<String>)>,
+	// Channel to send messages to this subscription
+	tx: tokio::sync::broadcast::Sender<Vec<u8>>,
 }
 
-// Represents a local subscription that can handle request/response
-struct LocalSubscription {
-	// Channel to send requests to this subscription
-	tx: tokio::sync::broadcast::Sender<LocalRequest>,
-}
+/// > In the default configuration it must be shorter than 8000 bytes
+///
+/// https://www.postgresql.org/docs/17/sql-notify.html
+const MAX_NOTIFY_LENGTH: usize = 8000;
 
-// Request sent to a local subscription
-#[derive(Clone)]
-struct LocalRequest {
-	payload: Vec<u8>,
-	reply_tx: tokio::sync::mpsc::Sender<Vec<u8>>,
-}
+/// Base64 encoding ratio
+const BYTES_PER_BLOCK: usize = 3;
+const CHARS_PER_BLOCK: usize = 4;
+
+/// Calculate the max message size if encoded as base64.
+///
+/// We subtract BYTES_PER_BLOCK because a trailing partial block in the
+/// base64-encoded data could otherwise bump the encoded length over the limit.
+pub const POSTGRES_MAX_MESSAGE_SIZE: usize =
+	(MAX_NOTIFY_LENGTH * BYTES_PER_BLOCK) / CHARS_PER_BLOCK - BYTES_PER_BLOCK;
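[Editor's note: to sanity-check the arithmetic: 8000 × 3 / 4 − 3 = 5997 payload bytes, which encode (unpadded) to ⌈5997 / 3⌉ × 4 = 7996 characters, safely under the 8000-byte NOTIFY cap, whereas 6000 bytes would encode to exactly 8000. A compile-time check along these lines could sit next to the constant — a sketch, not part of the diff:]

```rust
/// Unpadded base64 length for `n` bytes: 4 chars per full 3-byte block,
/// plus 2 or 3 chars for a trailing partial block.
const fn base64_len(n: usize) -> usize {
	(n / 3) * 4
		+ match n % 3 {
			0 => 0,
			1 => 2,
			_ => 3,
		}
}

// 5997 bytes -> 7996 chars, which NOTIFY accepts; one more block would not fit.
const _: () = assert!(base64_len(POSTGRES_MAX_MESSAGE_SIZE) < MAX_NOTIFY_LENGTH);
```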
 
 #[derive(Clone)]
 pub struct PostgresDriver {
-	memory_optimization: bool,
 	pool: Arc<Pool>,
 	client: Arc<tokio_postgres::Client>,
 	subscriptions: Cache<String, Subscription>,
-
-	// Maps subject to local subscription on this node for fast path
-	local_subscriptions: Arc<RwLock<HashMap<String, LocalSubscription>>>,
-}
-
-#[derive(Serialize, Deserialize)]
-struct Envelope {
-	// Base64-encoded payload
-	#[serde(rename = "p")]
-	payload: String,
-	#[serde(rename = "r", skip_serializing_if = "Option::is_none")]
-	reply_subject: Option<String>,
 }
 
 impl PostgresDriver {
@@ -92,467 +76,124 @@ impl PostgresDriver {
 			match poll_fn(|cx| conn.poll_message(cx)).await {
 				Some(std::result::Result::Ok(AsyncMessage::Notification(note))) => {
 					if let Some(sub) = subscriptions2.get(note.channel()).await {
-						let env = match serde_json::from_str::<Envelope>(&note.payload()) {
-							std::result::Result::Ok(env) => env,
+						let bytes = match BASE64.decode(note.payload()) {
+							std::result::Result::Ok(b) => b,
 							std::result::Result::Err(err) => {
-								tracing::error!(?err, "failed deserializing envelope");
+								tracing::error!(?err, "failed decoding base64");
 								break;
 							}
 						};
 
-						let payload = match BASE64
-							.decode(env.payload)
-							.context("invalid base64 payload")
-						{
-							std::result::Result::Ok(p) => p,
-							std::result::Result::Err(err) => {
-								tracing::error!(?err, "failed deserializing envelope");
-								break;
-							}
-						};
-
-						let _ = sub.tx.send((payload, env.reply_subject));
+						let _ = sub.tx.send(bytes);
 					}
 				}
-				Some(std::result::Result::Ok(_)) => continue,
+				Some(std::result::Result::Ok(_)) => {
+					// Ignore other async messages
+				}
 				Some(std::result::Result::Err(err)) => {
-					tracing::error!(?err, "ups poll loop failed");
+					tracing::error!(?err, "async postgres error");
+					break;
+				}
+				None => {
+					tracing::debug!("async postgres connection closed");
 					break;
 				}
-				None => break,
 			}
 		}
-
-		tracing::info!("ups poll loop stopped");
+		tracing::debug!("listen connection closed");
 	});
 
 	Ok(Self {
-		memory_optimization,
 		pool: Arc::new(pool),
 		client: Arc::new(client),
 		subscriptions,
-		local_subscriptions: Arc::new(RwLock::new(HashMap::new())),
 	})
 }
+
+	fn hash_subject(&self, subject: &str) -> String {
+		// Postgres channel names are limited to 63 bytes (the identifier
+		// length limit), so hash the subject to ensure it fits
+		let mut hasher = DefaultHasher::new();
+		subject.hash(&mut hasher);
+		format!("ups_{:x}", hasher.finish())
+	}
 }
 
 #[async_trait]
 impl PubSubDriver for PostgresDriver {
-	#[tracing::instrument(skip(self), fields(subject))]
 	async fn subscribe(&self, subject: &str) -> Result<SubscriberDriverHandle> {
-		tracing::debug!(%subject, "starting subscription");
+		let hashed = self.hash_subject(subject);
 
-		// Set up local request handling channel if memory optimization is enabled
-		let local_request_rx = if self.memory_optimization {
-			// Register this subscription in the local map
-			tracing::debug!(
-				%subject,
-				"registering local subscription for memory optimization"
-			);
-			let mut subs = self.local_subscriptions.write().await;
-			let local_rx = subs
-				.entry(subject.to_string())
-				.or_insert_with(|| LocalSubscription {
-					tx: tokio::sync::broadcast::channel(64).0,
-				})
-				.tx
-				.subscribe();
-			tracing::debug!(
-				%subject,
-				"local subscription registered"
-			);
-
-			Some(local_rx)
+		// Check if we already have a subscription for this channel
+		let rx = if let Some(existing_sub) = self.subscriptions.get(&hashed).await {
+			// Reuse the existing broadcast channel
+			existing_sub.tx.subscribe()
 		} else {
-			None
+			// Create a new broadcast channel for this subject
+			let (tx, rx) = tokio::sync::broadcast::channel(1024);
+			let subscription = Subscription { tx: tx.clone() };
+
+			// Register subscription
+			self.subscriptions
+				.insert(hashed.clone(), subscription)
+				.await;
+
+			// Execute LISTEN command on the async client (for receiving notifications)
+			// This only needs to be done once per channel
+			let span = tracing::trace_span!("pg_listen");
+			self.client
+				.execute(&format!("LISTEN \"{hashed}\""), &[])
+				.instrument(span)
+				.await?;
+
+			rx
 		};
 
-		// Convert subject to base64 hash string because Postgres identifiers can only be 63 bytes
-		let mut hasher = DefaultHasher::new();
-		subject.hash(&mut hasher);
-		let subject_hash = BASE64.encode(&hasher.finish().to_be_bytes());
-
-		let rx = self
-			.subscriptions
-			.entry(subject_hash.clone())
-			.or_insert_with(async {
-				Subscription {
-					tx: tokio::sync::broadcast::channel(128).0,
-				}
-			})
-			.await
-			.value()
-			.tx
-			.subscribe();
-
-		let sql = format!("LISTEN {}", quote_ident(&subject_hash));
-		self.client.batch_execute(&sql).await?;
-
-		tracing::debug!(%subject, "subscription established successfully");
 		Ok(Box::new(PostgresSubscriber {
-			driver: self.clone(),
-			rx,
-			local_request_rx,
 			subject: subject.to_string(),
+			rx,
 		}))
 	}
 
-	#[tracing::instrument(skip(self, message), fields(subject, message_len = message.len()))]
-	async fn publish(&self, subject: &str, message: &[u8]) -> Result<()> {
-		tracing::debug!(%subject, message_len = message.len(), "publishing message");
-		// Get a connection from the pool
-		let conn = self
-			.pool
-			.get()
-			.await
-			.context("failed to get connection from pool")?;
-
-		// Convert subject to base64 hash string because Postgres identifiers can only be 63 bytes
-		let mut hasher = DefaultHasher::new();
-		subject.hash(&mut hasher);
-		let subject_hash = BASE64.encode(&hasher.finish().to_be_bytes());
-
-		// Encode payload
-		let env = Envelope {
-			payload: BASE64.encode(message),
-			reply_subject: None,
-		};
-		let payload = serde_json::to_string(&env)?;
-
-		// NOTIFY doesn't support parameterized queries, so we need to escape the payload
-		// Replace single quotes with two single quotes for SQL escaping
-		let escaped_payload = payload.replace('\'', "''");
-		let sql = format!(
-			"NOTIFY {}, '{}'",
-			quote_ident(&subject_hash),
-			escaped_payload
-		);
-		conn.batch_execute(&sql)
-			.instrument(tracing::debug_span!("notify_execute", %subject))
+	async fn publish(&self, subject: &str, payload: &[u8]) -> Result<()> {
+		// Encode payload to base64 and send NOTIFY
+		let encoded = BASE64.encode(payload);
+		let conn = self.pool.get().await?;
+		let hashed = self.hash_subject(subject);
+		let span = tracing::trace_span!("pg_notify");
+		conn.execute(&format!("NOTIFY \"{hashed}\", '{encoded}'"), &[])
+			.instrument(span)
 			.await?;
-		tracing::debug!(%subject, "message published successfully");
-		Ok(())
-	}
 
-	async fn flush(&self) -> Result<()> {
-		// No-op for Postgres
 		Ok(())
 	}
 
-	#[tracing::instrument(skip(self, payload), fields(subject, payload_len = payload.len(), ?timeout))]
-	async fn request(
-		&self,
-		subject: &str,
-		payload: &[u8],
-		timeout: Option<Duration>,
-	) -> Result<Response> {
-		tracing::debug!(
-			%subject,
-			payload_len = payload.len(),
-			?timeout,
-			"starting request"
-		);
-
-		// Memory fast path: check if we have local subscribers first
-		if self.memory_optimization {
-			let subs = self.local_subscriptions.read().await;
-			if let Some(local_sub) = subs.get(subject) {
-				tracing::debug!(
-					%subject,
-					"using memory fast path for request"
-				);
-
-				// Create a channel for the reply
-				let (reply_tx, mut reply_rx) = tokio::sync::mpsc::channel(1);
-
-				// Send the request to the local subscription
-				let request = LocalRequest {
-					payload: payload.to_vec(),
-					reply_tx,
-				};
-
-				// Try to send the request
-				if local_sub.tx.send(request).is_ok() {
-					// Drop early to clear lock
-					drop(subs);
-
-					// Wait for response with optional timeout
-					let response_future = async {
-						match reply_rx.recv().await {
-							Some(response_payload) => Ok(Response {
-								payload: response_payload,
-							}),
-							None => Err(anyhow!("local subscription closed")),
-						}
-					};
-
-					// Apply timeout if specified
-					if let Some(dur) = timeout {
-						return match tokio::time::timeout(dur, response_future).await {
-							std::result::Result::Ok(resp) => resp,
-							std::result::Result::Err(_) => {
-								Err(errors::Ups::RequestTimeout.build().into())
-							}
-						};
-					} else {
-						return response_future.await;
-					}
-				}
-				// If send failed, the subscription might be dead, clean it up later
-				// and fall through to normal path
-			}
-		}
-
-		// Normal path: use database for request/response
-		tracing::debug!(%subject, "using database path for request");
-
-		// Create a temporary reply subject and a dedicated listener connection
-		let reply_subject = format!("_INBOX.{}", uuid::Uuid::new_v4());
-
-		let mut reply_sub = self.subscribe(&reply_subject).await?;
-
-		// Get another connection from pool to publish the request
-		let publish_conn = self
-			.pool
-			.get()
-			.await
-			.context("failed to get connection from pool")?;
-
-		// Convert subject to base64 hash string because Postgres identifiers can only be 63 bytes
-		let mut hasher = DefaultHasher::new();
-		subject.hash(&mut hasher);
-		let subject_hash = BASE64.encode(&hasher.finish().to_be_bytes());
-		let mut hasher = DefaultHasher::new();
-		reply_subject.hash(&mut hasher);
-		let reply_subject_hash = BASE64.encode(&hasher.finish().to_be_bytes());
-
-		// Publish request with reply subject encoded
-		let env = Envelope {
-			payload: BASE64.encode(payload),
-			reply_subject: Some(reply_subject_hash.clone()),
-		};
-		let env_payload = serde_json::to_string(&env)?;
-
-		// NOTIFY doesn't support parameterized queries
-		let escaped_payload = env_payload.replace('\'', "''");
-
-		let notify_sql = format!(
-			"NOTIFY {}, '{}'",
-			quote_ident(&subject_hash),
-			escaped_payload
-		);
-		publish_conn.batch_execute(&notify_sql).await?;
-
-		// Wait for response with optional timeout
-		let response_future = async {
-			Ok(Response {
-				payload: match reply_sub.next().await? {
-					NextOutput::Message(msg) => msg.payload,
-					NextOutput::Unsubscribed => bail!("reply subscription unsubscribed"),
-				},
-			})
-		};
-
-		// Apply timeout if specified
-		if let Some(dur) = timeout {
-			match tokio::time::timeout(dur, response_future).await {
-				std::result::Result::Ok(resp) => resp,
-				std::result::Result::Err(_) => Err(errors::Ups::RequestTimeout.build().into()),
-			}
-		} else {
-			response_future.await
-		}
-	}
-
-	// NOTE: The reply argument here is already a base64 encoded hash
-	async fn send_request_reply(&self, reply: &str, payload: &[u8]) -> Result<()> {
-		// Get a connection from the pool
-		let conn = self
-			.pool
-			.get()
-			.await
-			.context("failed to get connection from pool")?;
-
-		// Publish reply without nested reply
-		let env = Envelope {
-			payload: BASE64.encode(payload),
-			reply_subject: None,
-		};
-		let payload = serde_json::to_string(&env)?;
-		// NOTIFY doesn't support parameterized queries
-		let escaped_payload = payload.replace('\'', "''");
-		let sql = format!("NOTIFY {}, '{}'", quote_ident(reply), escaped_payload);
-		conn.batch_execute(&sql).await?;
-		Ok(())
-	}
-}
-
-// Special driver for handling local replies
-struct LocalReplyDriver {
-	reply_tx: tokio::sync::mpsc::Sender<Vec<u8>>,
-}
-
-#[async_trait]
-impl PubSubDriver for LocalReplyDriver {
-	async fn subscribe(&self, _subject: &str) -> Result<SubscriberDriverHandle> {
-		Err(anyhow!("LocalReplyDriver does not support subscribe"))
-	}
-
-	async fn publish(&self, _subject: &str, _message: &[u8]) -> Result<()> {
-		Err(anyhow!("LocalReplyDriver does not support publish"))
-	}
-
 	async fn flush(&self) -> Result<()> {
 		Ok(())
 	}
 
-	async fn request(
-		&self,
-		_subject: &str,
-		_payload: &[u8],
-		_timeout: Option<Duration>,
-	) -> Result<Response> {
-		Err(anyhow!("LocalReplyDriver does not support request"))
-	}
-
-	async fn send_request_reply(&self, _reply: &str, payload: &[u8]) -> Result<()> {
-		tracing::debug!("sending local request reply");
-
-		// Send the reply through the local channel
-		let _ = self.reply_tx.send(payload.to_vec()).await;
-
-		Ok(())
+	fn max_message_size(&self) -> usize {
+		POSTGRES_MAX_MESSAGE_SIZE
 	}
 }
 
 pub struct PostgresSubscriber {
-	driver: PostgresDriver,
-	rx: tokio::sync::broadcast::Receiver<(Vec<u8>, Option<String>)>,
-	local_request_rx: Option<tokio::sync::broadcast::Receiver<LocalRequest>>,
 	subject: String,
+	rx: tokio::sync::broadcast::Receiver<Vec<u8>>,
 }
 
 #[async_trait]
 impl SubscriberDriver for PostgresSubscriber {
-	#[tracing::instrument(skip(self), fields(subject = %self.subject))]
-	async fn next(&mut self) -> Result<NextOutput> {
-		tracing::debug!("waiting for message");
-
-		// If we have a local request receiver, poll both channels
-		if let Some(ref mut local_rx) = self.local_request_rx {
-			tokio::select! {
-				// Check for local requests (memory fast path)
-				local_req = local_rx.recv() => {
-					match local_req {
-						std::result::Result::Ok(req) => {
-							// Create a synthetic reply subject for local request
-							let reply_subject = format!("_LOCAL.{}", uuid::Uuid::new_v4());
-
-							// Create a wrapper driver that will handle the reply
-							let local_driver = LocalReplyDriver {
-								reply_tx: req.reply_tx,
-							};
-
-							tracing::debug!(len=?req.payload.len(), "received local message");
-
-							// Return the request as a message with the local reply driver
-							Ok(NextOutput::Message(Message {
-								driver: Arc::new(local_driver),
-								payload: req.payload,
-								reply: Some(reply_subject),
-							}))
-						}
-						std::result::Result::Err(_) => {
-							tracing::debug!("no local subscription senders");
-
-							Ok(NextOutput::Unsubscribed)
-						}
-					}
-				}
-				// Check for regular PostgreSQL messages
-				msg = self.rx.recv() => {
-					match msg {
-						std::result::Result::Ok((payload, reply_subject)) => {
-							tracing::debug!(len=?payload.len(), "received message");
-
-							Ok(NextOutput::Message(Message {
-								driver: Arc::new(self.driver.clone()),
-								payload,
-								reply: reply_subject,
-							}))
-						}
-						std::result::Result::Err(_) => {
-							tracing::debug!(?self.subject, "subscription closed");
-
-							Ok(NextOutput::Unsubscribed)
-						}
-					}
-				}
-			}
-		} else {
-			// No memory optimization, just poll regular messages
-			match self.rx.recv().await {
-				std::result::Result::Ok((payload, reply_subject)) => {
-					tracing::debug!(len=?payload.len(), "received message");
-
-					Ok(NextOutput::Message(Message {
-						driver: Arc::new(self.driver.clone()),
-						payload,
-						reply: reply_subject,
-					}))
-				}
-				std::result::Result::Err(_) => {
-					tracing::debug!("subscription closed");
-
-					Ok(NextOutput::Unsubscribed)
-				}
+	async fn next(&mut self) -> Result<DriverOutput> {
+		match self.rx.recv().await {
+			std::result::Result::Ok(payload) => Ok(DriverOutput::Message {
+				subject: self.subject.clone(),
+				payload,
+			}),
+			Err(tokio::sync::broadcast::error::RecvError::Closed) => Ok(DriverOutput::Unsubscribed),
+			Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => {
+				// Try again
+				self.next().await
 			}
 		}
 	}
 }
-
-impl Drop for PostgresSubscriber {
-	fn drop(&mut self) {
-		tracing::debug!(subject = %self.subject, "dropping postgres subscriber");
-
-		let driver = self.driver.clone();
-		let subject = self.subject.clone();
-		let has_local_rx = self.local_request_rx.is_some();
-
-		// Spawn a task to clean up
-		tokio::spawn(async move {
-			// Clean up local subscription registration if memory optimization is enabled
-			if has_local_rx {
-				let mut subs = driver.local_subscriptions.write().await;
-				if let Some(local_sub) = subs.get_mut(&subject) {
-					// If no more subscriptions for this subject, remove the entry
-					if local_sub.tx.receiver_count() == 0 {
-						subs.remove(&subject);
-					}
-				}
-			}
-
-			if let Some(sub) = driver.subscriptions.get(&subject).await {
-				if sub.tx.receiver_count() == 0 {
-					driver.subscriptions.invalidate(&subject).await;
-
-					let mut hasher = DefaultHasher::new();
-					subject.hash(&mut hasher);
-					let subject_hash = BASE64.encode(&hasher.finish().to_be_bytes());
-
-					let sql = format!("UNLISTEN {}", quote_ident(&subject_hash));
-					let unlisten_res = driver.client.batch_execute(&sql).await;
-
-					if let std::result::Result::Err(err) = unlisten_res {
-						tracing::error!(%subject, ?err, "failed to unlisten subject");
-					}
-				}
-			}
-		});
-	}
-}
-
-fn quote_ident(subject: &str) -> String {
-	// Double-quote and escape any embedded quotes for safe identifier usage
-	let escaped = subject.replace('"', "\"\"");
-	format!("\"{}\"", escaped)
-}
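[Editor's note: the `hash_subject` scheme is worth a note: Postgres identifiers are truncated at 63 bytes, so arbitrary-length subjects are mapped to short `ups_<hex>` channel names, and both LISTEN and NOTIFY derive the name with the same function so they always line up. A standalone sketch of the mapping (`channel_for` is a hypothetical mirror of the private method):]

```rust
use std::hash::{DefaultHasher, Hash, Hasher};

/// Mirrors `PostgresDriver::hash_subject`: any subject becomes at most
/// "ups_" + 16 hex digits = 20 characters, well under the 63-byte limit.
fn channel_for(subject: &str) -> String {
	let mut hasher = DefaultHasher::new();
	subject.hash(&mut hasher);
	format!("ups_{:x}", hasher.finish())
}

fn main() {
	let long_subject = "pegboard.tunnel.http.response.some-very-long-request-id";
	// Prints something like "ups_9f2c4e71a0b3d812" (exact value depends on the hasher).
	println!("{}", channel_for(long_subject));
}
```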
diff --git a/packages/common/universalpubsub/src/lib.rs b/packages/common/universalpubsub/src/lib.rs
index f61c3e913d..8ace233bc8 100644
--- a/packages/common/universalpubsub/src/lib.rs
+++ b/packages/common/universalpubsub/src/lib.rs
@@ -1,3 +1,4 @@
+pub mod chunking;
 pub mod driver;
 pub mod errors;
 pub mod pubsub;
diff --git a/packages/common/universalpubsub/src/pubsub.rs b/packages/common/universalpubsub/src/pubsub.rs
index a481e79e50..adc69f2746 100644
--- a/packages/common/universalpubsub/src/pubsub.rs
+++ b/packages/common/universalpubsub/src/pubsub.rs
@@ -1,26 +1,183 @@
+use std::collections::HashMap;
+use std::ops::Deref;
+use std::sync::{Arc, Mutex};
 use std::time::Duration;
 
 use anyhow::*;
+use tokio::sync::broadcast;
+use tokio::sync::{RwLock, oneshot};
+use uuid::Uuid;
 
-use crate::driver::{PubSubDriverHandle, SubscriberDriverHandle};
+use crate::chunking::{ChunkTracker, encode_chunk, split_payload_into_chunks};
+use crate::driver::{PubSubDriverHandle, PublishOpts, SubscriberDriverHandle};
 
-#[derive(Clone)]
-pub struct PubSub {
+pub struct PubSubInner {
 	driver: PubSubDriverHandle,
+	chunk_tracker: Mutex<ChunkTracker>,
+	reply_subscribers: RwLock<HashMap<String, oneshot::Sender<Vec<u8>>>>,
+	// Local in-memory subscribers by subject (shared across all drivers)
+	local_subscribers: RwLock<HashMap<String, broadcast::Sender<Vec<u8>>>>,
+	// Enables/disables local fast-path across all drivers
+	memory_optimization: bool,
+}
+
+#[derive(Clone)]
+pub struct PubSub(Arc<PubSubInner>);
+
+impl Deref for PubSub {
+	type Target = PubSubInner;
+
+	fn deref(&self) -> &Self::Target {
+		&self.0
+	}
 }
 
 impl PubSub {
 	pub fn new(driver: PubSubDriverHandle) -> Self {
-		Self { driver }
+		Self::new_with_memory_optimization(driver, true)
+	}
+
+	pub fn new_with_memory_optimization(
+		driver: PubSubDriverHandle,
+		memory_optimization: bool,
+	) -> Self {
+		let inner = Arc::new(PubSubInner {
+			driver,
+			chunk_tracker: Mutex::new(ChunkTracker::new()),
+			reply_subscribers: RwLock::new(HashMap::new()),
+			local_subscribers: RwLock::new(HashMap::new()),
+			memory_optimization,
+		});
+
+		// Spawn GC task for chunk buffers
+		let gc_inner = Arc::downgrade(&inner);
+		tokio::spawn(async move {
+			let mut interval = tokio::time::interval(crate::chunking::CHUNK_BUFFER_GC_INTERVAL);
+			loop {
+				interval.tick().await;
+				if let Some(inner) = gc_inner.upgrade() {
+					inner.chunk_tracker.lock().unwrap().gc();
+				} else {
+					break;
+				}
+			}
+		});
+
+		Self(inner)
 	}
 
 	pub async fn subscribe(&self, subject: &str) -> Result<Subscriber> {
-		let subscriber_driver = self.driver.subscribe(subject).await?;
-		Ok(Subscriber::new(subscriber_driver))
+		// Underlying driver subscription
+		let driver = self.driver.subscribe(subject).await?;
+
+		if !self.memory_optimization {
+			return Ok(Subscriber::new(driver, self.clone()));
+		}
+
+		// Ensure a local broadcast channel exists for this subject
+		let rx = {
+			// Try read first for fast path
+			if let Some(sender) = self.local_subscribers.read().await.get(subject) {
+				sender.subscribe()
+			} else {
+				// Create and insert
+				let (tx, rx) = broadcast::channel(1024);
+				let mut map = self.local_subscribers.write().await;
+				// Double-check in case of race
+				let rx = if let Some(existing) = map.get(subject) {
+					existing.subscribe()
+				} else {
+					map.insert(subject.to_string(), tx);
+					rx
+				};
+				rx
+			}
+		};
+
+		// Wrap the driver
+		let optimized_driver: SubscriberDriverHandle = Box::new(LocalOptimizedSubscriberDriver {
+			subject: subject.to_string(),
+			driver,
+			local_rx: rx,
+		});
+
+		Ok(Subscriber::new(optimized_driver, self.clone()))
+	}
+
+	pub async fn publish(&self, subject: &str, payload: &[u8], opts: PublishOpts) -> Result<()> {
+		let message_id = *Uuid::new_v4().as_bytes();
+		let chunks =
+			split_payload_into_chunks(payload, self.driver.max_message_size(), message_id, None)?;
+		let chunk_count = chunks.len() as u32;
+
+		let use_local = self
+			.should_use_local_subscriber(subject, opts.behavior)
+			.await;
+
+		for (chunk_idx, chunk_payload) in chunks.into_iter().enumerate() {
+			let encoded = encode_chunk(
+				chunk_payload,
+				chunk_idx as u32,
+				chunk_count,
+				message_id,
+				None,
+			)?;
+
+			if use_local {
+				if let Some(sender) = self.local_subscribers.read().await.get(subject) {
+					let _ = sender.send(encoded);
+				} else {
+					tracing::warn!(%subject, "local subscriber disappeared");
+					break;
+				}
+			} else {
+				self.driver.publish(subject, &encoded).await?;
+			}
+		}
+		Ok(())
 	}
 
-	pub async fn publish(&self, subject: &str, payload: &[u8]) -> Result<()> {
-		self.driver.publish(subject, payload).await
+	pub async fn publish_with_reply(
+		&self,
+		subject: &str,
+		payload: &[u8],
+		reply_subject: &str,
+		opts: PublishOpts,
+	) -> Result<()> {
+		let message_id = *Uuid::new_v4().as_bytes();
+		let chunks = split_payload_into_chunks(
+			payload,
+			self.driver.max_message_size(),
+			message_id,
+			Some(reply_subject),
+		)?;
+		let chunk_count = chunks.len() as u32;
+
+		let use_local = self
+			.should_use_local_subscriber(subject, opts.behavior)
+			.await;
+
+		for (chunk_idx, chunk_payload) in chunks.into_iter().enumerate() {
+			let encoded = encode_chunk(
+				chunk_payload,
+				chunk_idx as u32,
+				chunk_count,
+				message_id,
+				Some(reply_subject.to_string()),
+			)?;
+
+			if use_local {
+				if let Some(sender) = self.local_subscribers.read().await.get(subject) {
+					let _ = sender.send(encoded);
+				} else {
+					tracing::warn!(%subject, "local subscriber disappeared");
+					break;
+				}
+			} else {
+				self.driver.publish(subject, &encoded).await?;
+			}
+		}
+		Ok(())
 	}
 
 	pub async fn flush(&self) -> Result<()> {
@@ -28,7 +185,8 @@
 	}
 
 	pub async fn request(&self, subject: &str, payload: &[u8]) -> Result<Response> {
-		self.driver.request(subject, payload, None).await
+		self.request_with_timeout(subject, payload, Duration::from_secs(30))
+			.await
 	}
 
 	pub async fn request_with_timeout(
@@ -37,31 +195,142 @@
 		payload: &[u8],
 		timeout: Duration,
 	) -> Result<Response> {
-		self.driver.request(subject, payload, Some(timeout)).await
+		// Create a unique reply subject for this request
+		let reply_subject = format!("_INBOX.{}", Uuid::new_v4());
+
+		// Create a oneshot channel for the response
+		let (tx, rx) = oneshot::channel();
+
+		// Register the reply handler
+		{
+			let mut subscribers = self.reply_subscribers.write().await;
+			subscribers.insert(reply_subject.clone(), tx);
+		}
+
+		// Subscribe to the reply subject (use local-aware subscribe)
+		let mut reply_subscriber = self.subscribe(&reply_subject).await?;
+
+		// Send the request with the reply subject, using local fast-path
+		self.publish_with_reply(subject, payload, &reply_subject, PublishOpts::one())
+			.await?;
+
+		// Spawn a task to wait for the reply
+		let inner = self.0.clone();
+		let reply_subject_clone = reply_subject.clone();
+		tokio::spawn(async move {
+			loop {
+				match reply_subscriber.next().await {
+					std::result::Result::Ok(NextOutput::Message(msg)) => {
+						// Already decoded; forward payload
+						let mut subscribers = inner.reply_subscribers.write().await;
+						if let Some(tx) = subscribers.remove(&reply_subject_clone) {
+							let _ = tx.send(msg.payload);
+						}
+						break;
+					}
+					std::result::Result::Ok(NextOutput::Unsubscribed)
+					| std::result::Result::Err(_) => break,
+				}
+			}
+		});
+
+		// Wait for response with timeout
+		let response = match tokio::time::timeout(timeout, rx).await {
+			std::result::Result::Ok(std::result::Result::Ok(payload)) => Response { payload },
+			std::result::Result::Ok(std::result::Result::Err(_)) => {
+				// Clean up the reply subscription
+				self.reply_subscribers.write().await.remove(&reply_subject);
+				return Err(crate::errors::Ups::RequestTimeout.build().into());
+			}
+			std::result::Result::Err(_) => {
+				// Timeout elapsed
+				self.reply_subscribers.write().await.remove(&reply_subject);
+				return Err(crate::errors::Ups::RequestTimeout.build().into());
+			}
+		};
+
+		Ok(response)
+	}
+
+	async fn should_use_local_subscriber(
+		&self,
+		subject: &str,
+		behavior: crate::driver::PublishBehavior,
+	) -> bool {
+		// Local fast-path for one-subscriber behavior:
+		// - When memory_optimization is enabled and behavior == OneSubscriber, deliver directly
+		//   to any in-process subscribers via the subject's broadcast channel and skip calling
+		//   the underlying driver (avoids network hops and driver overhead).
+		// - For Broadcast, always publish via the driver so remote subscribers (and other
+		//   processes) receive the message; local subscribers will also receive via the driver.
+		// - If there are no local receivers at the time of publish (or the channel disappears),
+		//   fall back to the driver publish path.
+
+		if !self.memory_optimization {
+			return false;
+		}
+		if !matches!(behavior, crate::driver::PublishBehavior::OneSubscriber) {
+			return false;
+		}
+		if let Some(sender) = self.local_subscribers.read().await.get(subject) {
+			sender.receiver_count() > 0
+		} else {
+			false
+		}
 	}
 }
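[Editor's note: end to end, the request path now rides on plain publish/subscribe: the requester subscribes to a generated `_INBOX.*` subject, publishes via `publish_with_reply`, and the responder answers through `Message::reply`. A hedged usage sketch — the subject name is illustrative:]

```rust
use universalpubsub::{NextOutput, PubSub};

async fn echo_example(pubsub: &PubSub) -> anyhow::Result<()> {
	// Responder: the reply lands on the requester's generated _INBOX subject.
	let mut sub = pubsub.subscribe("example.echo").await?;
	let responder = tokio::spawn(async move {
		if let Ok(NextOutput::Message(msg)) = sub.next().await {
			let _ = msg.reply(&msg.payload).await;
		}
	});

	// Requester: resolves with the reply, or fails with Ups::RequestTimeout
	// after the 30-second default used by `request`.
	let response = pubsub.request("example.echo", b"ping").await?;
	assert_eq!(response.payload, b"ping");

	responder.await?;
	Ok(())
}
```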
 
 pub struct Subscriber {
 	driver: SubscriberDriverHandle,
+	pubsub: PubSub,
 }
 
 impl Subscriber {
-	pub fn new(driver: SubscriberDriverHandle) -> Self {
-		Self { driver }
+	pub fn new(driver: SubscriberDriverHandle, pubsub: PubSub) -> Self {
+		Self { driver, pubsub }
 	}
 
 	pub async fn next(&mut self) -> Result<NextOutput> {
-		self.driver.next().await
+		loop {
+			match self.driver.next().await? {
+				DriverOutput::Message { subject, payload } => {
+					// Process chunks
+					let mut tracker = self.pubsub.chunk_tracker.lock().unwrap();
+					match tracker.process_chunk(&payload) {
+						std::result::Result::Ok(Some((payload, reply_subject))) => {
+							return Ok(NextOutput::Message(Message {
+								pubsub: self.pubsub.clone(),
+								payload,
+								reply: reply_subject,
+							}));
+						}
+						std::result::Result::Ok(None) => continue, // Waiting for more chunks
+						std::result::Result::Err(e) => {
+							tracing::warn!(?e, "failed to process chunk");
+							continue;
+						}
+					}
+				}
+				DriverOutput::Unsubscribed => return Ok(NextOutput::Unsubscribed),
+			}
+		}
 	}
 }
 
+// Output from drivers (raw binary messages)
+pub enum DriverOutput {
+	Message { subject: String, payload: Vec<u8> },
+	Unsubscribed,
+}
+
+// Output from subscriber (after chunking/decoding)
 pub enum NextOutput {
 	Message(Message),
 	Unsubscribed,
 }
 
 pub struct Message {
-	pub driver: PubSubDriverHandle,
+	pub pubsub: PubSub,
 	pub payload: Vec<u8>,
 	pub reply: Option<String>,
 }
@@ -69,13 +338,55 @@
 impl Message {
 	pub async fn reply(&self, payload: &[u8]) -> Result<()> {
 		if let Some(ref reply_subject) = self.reply {
-			self.driver.send_request_reply(reply_subject, payload).await
-		} else {
-			Ok(())
+			// Send the reply using chunking; replies expect exactly one
+			// subscriber, so use the local fast-path
+			self.pubsub
+				.publish(reply_subject, payload, PublishOpts::one())
+				.await?;
 		}
+		Ok(())
 	}
 }
 
 pub struct Response {
 	pub payload: Vec<u8>,
 }
+
+/// Internal composite subscriber that merges driver messages with local in-memory messages
+struct LocalOptimizedSubscriberDriver {
+	subject: String,
+	driver: SubscriberDriverHandle,
+	local_rx: broadcast::Receiver<Vec<u8>>,
+}
+
+#[async_trait::async_trait]
+impl crate::driver::SubscriberDriver for LocalOptimizedSubscriberDriver {
+	async fn next(&mut self) -> Result<DriverOutput> {
+		loop {
+			tokio::select! {
+				biased;
+				// Prefer local messages to reduce latency
+				res = self.local_rx.recv() => {
+					match res {
+						std::result::Result::Ok(payload) => {
+							return Ok(DriverOutput::Message { subject: self.subject.clone(), payload });
+						}
+						std::result::Result::Err(broadcast::error::RecvError::Lagged(_)) => {
+							// Skip lagged and continue
+							continue;
+						}
+						std::result::Result::Err(broadcast::error::RecvError::Closed) => {
+							// Local channel closed; fall back to the driver alone. Return
+							// its next output directly instead of looping, otherwise the
+							// closed receiver would resolve immediately again and busy-loop.
+							return self.driver.next().await;
+						}
+					}
+				}
+				res = self.driver.next() => {
+					return res;
+				}
+			}
+		}
	}
+}
diff --git a/packages/common/universalpubsub/tests/integration.rs b/packages/common/universalpubsub/tests/integration.rs
index 8541cfbbfe..af38066eeb 100644
--- a/packages/common/universalpubsub/tests/integration.rs
+++ b/packages/common/universalpubsub/tests/integration.rs
@@ -6,7 +6,7 @@ use std::{
 	sync::Arc,
 	time::{Duration, Instant},
 };
-use universalpubsub::{NextOutput, PubSub};
+use universalpubsub::{NextOutput, PubSub, PublishOpts};
 use uuid::Uuid;
 
 fn setup_logging() {
@@ -18,7 +18,7 @@ fn setup_logging() {
 }
 
 #[tokio::test]
-async fn test_nats_driver() {
+async fn test_nats_driver_with_memory() {
 	setup_logging();
 
 	let test_id = Uuid::new_v4();
@@ -46,7 +46,41 @@ async fn test_nats_driver() {
 	)
 	.await
 	.unwrap();
-	let pubsub = PubSub::new(Arc::new(driver));
+	let pubsub = PubSub::new_with_memory_optimization(Arc::new(driver), true);
+
+	test_inner(&pubsub).await;
+}
+
+#[tokio::test]
+async fn test_nats_driver_without_memory() {
+	setup_logging();
+
+	let test_id = Uuid::new_v4();
+	let (pubsub_config, docker_config) = TestPubSub::Nats.config(test_id, 1).await.unwrap();
+	let mut docker = docker_config.unwrap();
+	docker.start().await.unwrap();
+	tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
+
+	let rivet_config::config::PubSub::Nats(nats) = pubsub_config else {
+		unreachable!();
+	};
+
+	use std::str::FromStr;
+	let server_addrs = nats
+		.addresses
+		.iter()
+		.map(|addr| format!("nats://{addr}"))
+		.map(|url| async_nats::ServerAddr::from_str(url.as_ref()))
+		.collect::<Result<Vec<_>, _>>()
+		.unwrap();
+
+	let driver = universalpubsub::driver::nats::NatsDriver::connect(
+		async_nats::ConnectOptions::new(),
+		&server_addrs[..],
+	)
+	.await
+	.unwrap();
+	let pubsub = PubSub::new_with_memory_optimization(Arc::new(driver), false);
 
 	test_inner(&pubsub).await;
 }
@@ -69,7 +103,7 @@ async fn test_postgres_driver_with_memory() {
 	let driver = universalpubsub::driver::postgres::PostgresDriver::connect(url, true)
 		.await
 		.unwrap();
-	let pubsub = PubSub::new(Arc::new(driver));
+	let pubsub = PubSub::new_with_memory_optimization(Arc::new(driver), true);
 
 	test_inner(&pubsub).await;
 }
@@ -92,7 +126,7 @@ async fn test_postgres_driver_without_memory() {
 	let driver = universalpubsub::driver::postgres::PostgresDriver::connect(url, false)
 		.await
 		.unwrap();
-	let pubsub = PubSub::new(Arc::new(driver));
+	let pubsub = PubSub::new_with_memory_optimization(Arc::new(driver), false);
 
 	test_inner(&pubsub).await;
 }
@@ -137,6 +171,10 @@ async fn test_inner(pubsub: &PubSub) {
 	let start = Instant::now();
 	test_request_timeout(&pubsub).await.unwrap();
 	tracing::info!(duration_ms = ?start.elapsed().as_millis(), "test_request_timeout completed");
+
+	let start = Instant::now();
+	test_large_payloads(&pubsub).await.unwrap();
+	tracing::info!(duration_ms = ?start.elapsed().as_millis(), "test_large_payloads completed");
 }
 
 async fn test_basic_pub_sub(pubsub: &PubSub) -> Result<()> {
@@
-147,7 +185,9 @@ async fn test_basic_pub_sub(pubsub: &PubSub) -> Result<()> { // Publish a message let message = b"Hello, World!"; - pubsub.publish("test.subject", message).await?; + pubsub + .publish("test.subject", message, PublishOpts::broadcast()) + .await?; pubsub.flush().await?; // Receive the message @@ -172,7 +212,9 @@ async fn test_multiple_subscribers(pubsub: &PubSub) -> Result<()> { // Publish a message let message = b"Broadcast message"; - pubsub.publish("test.multi", message).await?; + pubsub + .publish("test.multi", message, PublishOpts::broadcast()) + .await?; pubsub.flush().await?; // Both subscribers should receive the message @@ -205,7 +247,9 @@ async fn test_unsubscribe(pubsub: &PubSub) -> Result<()> { // Publish a message let message = b"First message"; - pubsub.publish("test.unsub", message).await?; + pubsub + .publish("test.unsub", message, PublishOpts::broadcast()) + .await?; pubsub.flush().await?; // Receive the first message @@ -226,7 +270,9 @@ async fn test_unsubscribe(pubsub: &PubSub) -> Result<()> { // Publish another message let message2 = b"Second message"; - pubsub.publish("test.unsub", message2).await?; + pubsub + .publish("test.unsub", message2, PublishOpts::broadcast()) + .await?; pubsub.flush().await?; // New subscriber should receive the message @@ -328,3 +374,68 @@ async fn test_request_timeout(pubsub: &PubSub) -> Result<()> { Ok(()) } + +async fn test_large_payloads(pubsub: &PubSub) -> Result<()> { + tracing::info!("testing large payloads with chunking"); + + // Use a base size that works with all drivers + // Postgres has the smallest limit at 8KB + let base_size = 5000; // Use 5KB as base to ensure we're under the limit + + // Test 1x max size + test_payload_size(pubsub, base_size, "1x").await?; + + // Test 2x max size + test_payload_size(pubsub, base_size * 2, "2x").await?; + + // Test 2.5x max size + test_payload_size(pubsub, (base_size as f64 * 2.5) as usize, "2.5x").await?; + + Ok(()) +} + +async fn test_payload_size(pubsub: &PubSub, size: usize, label: &str) -> Result<()> { + tracing::info!(size, label, "testing payload size"); + + // Create a payload of the specified size + let mut payload = vec![0u8; size]; + // Fill with a pattern to verify integrity + for (i, byte) in payload.iter_mut().enumerate() { + *byte = (i % 256) as u8; + } + + // Subscribe to a test subject + let mut subscriber = pubsub.subscribe(&format!("test.large.{}", label)).await?; + + // Publish the large message + pubsub + .publish( + &format!("test.large.{}", label), + &payload, + PublishOpts::broadcast(), + ) + .await?; + pubsub.flush().await?; + + // Receive and verify the message + match subscriber.next().await? 
+		NextOutput::Message(msg) => {
+			assert_eq!(
+				msg.payload.len(),
+				size,
+				"payload size mismatch for {}",
+				label
+			);
+			assert_eq!(
+				msg.payload, payload,
+				"payload content mismatch for {}",
+				label
+			);
+		}
+		NextOutput::Unsubscribed => {
+			panic!("unexpected unsubscribe for {}", label);
+		}
+	}
+
+	Ok(())
+}
diff --git a/packages/core/pegboard-gateway/src/lib.rs b/packages/core/pegboard-gateway/src/lib.rs
index f7297a0087..09a960fe88 100644
--- a/packages/core/pegboard-gateway/src/lib.rs
+++ b/packages/core/pegboard-gateway/src/lib.rs
@@ -4,7 +4,7 @@ use bytes::Bytes;
 use futures_util::{SinkExt, StreamExt};
 use gas::prelude::*;
 use http_body_util::{BodyExt, Full};
-use hyper::{body::Incoming as BodyIncoming, Request, Response, StatusCode};
+use hyper::{Request, Response, StatusCode, body::Incoming as BodyIncoming};
 use hyper_tungstenite::HyperWebsocket;
 use pegboard::pubsub_subjects::{
 	TunnelHttpResponseSubject, TunnelHttpRunnerSubject, TunnelHttpWebSocketSubject,
 };
@@ -16,21 +16,22 @@ use rivet_guard_core::{
 	request_context::RequestContext,
 };
 use rivet_tunnel_protocol::{
-	versioned, MessageBody, StreamFinishReason, ToServerRequestFinish, ToServerRequestStart,
+	MessageBody, StreamFinishReason, ToServerRequestFinish, ToServerRequestStart,
 	ToServerWebSocketClose, ToServerWebSocketMessage, ToServerWebSocketOpen, TunnelMessage,
+	versioned,
 };
 use rivet_util::serde::HashableMap;
 use std::result::Result::Ok as ResultOk;
 use std::{
 	collections::HashMap,
 	sync::{
-		atomic::{AtomicU64, Ordering},
 		Arc,
+		atomic::{AtomicU64, Ordering},
 	},
 	time::Duration,
 };
 use tokio::{
-	sync::{oneshot, Mutex},
+	sync::{Mutex, oneshot},
 	time::timeout,
 };
 use tokio_tungstenite::tungstenite::Message;
diff --git a/packages/core/pegboard-tunnel/src/lib.rs b/packages/core/pegboard-tunnel/src/lib.rs
index 1a4a5df1bb..37ddc97cb8 100644
--- a/packages/core/pegboard-tunnel/src/lib.rs
+++ b/packages/core/pegboard-tunnel/src/lib.rs
@@ -9,10 +9,10 @@ use gas::prelude::*;
 use http_body_util::Full;
 use hyper::body::{Bytes, Incoming as BodyIncoming};
 use hyper::{Request, Response, StatusCode};
-use hyper_tungstenite::tungstenite::protocol::frame::coding::CloseCode as WsCloseCode;
-use hyper_tungstenite::tungstenite::protocol::frame::CloseFrame as WsCloseFrame;
 use hyper_tungstenite::tungstenite::Utf8Bytes as WsUtf8Bytes;
-use hyper_tungstenite::{tungstenite::Message as WsMessage, HyperWebsocket};
+use hyper_tungstenite::tungstenite::protocol::frame::CloseFrame as WsCloseFrame;
+use hyper_tungstenite::tungstenite::protocol::frame::coding::CloseCode as WsCloseCode;
+use hyper_tungstenite::{HyperWebsocket, tungstenite::Message as WsMessage};
 use pegboard::pubsub_subjects::{
 	TunnelHttpResponseSubject, TunnelHttpRunnerSubject, TunnelHttpWebSocketSubject,
 };
@@ -20,7 +20,7 @@ use rivet_guard_core::custom_serve::CustomServeTrait;
 use rivet_guard_core::proxy_service::ResponseBody;
 use rivet_guard_core::request_context::RequestContext;
 use rivet_pools::Pools;
-use rivet_tunnel_protocol::{versioned, MessageBody, TunnelMessage};
+use rivet_tunnel_protocol::{MessageBody, TunnelMessage, versioned};
 use rivet_util::Id;
 use std::net::SocketAddr;
 use tokio::net::TcpListener;
diff --git a/packages/services/pegboard/src/workflows/actor/actor_keys.rs b/packages/services/pegboard/src/workflows/actor/actor_keys.rs
index ad01e160c5..1ff028160a 100644
--- a/packages/services/pegboard/src/workflows/actor/actor_keys.rs
+++ b/packages/services/pegboard/src/workflows/actor/actor_keys.rs
@@ -6,7 +6,7 @@ use futures_util::TryStreamExt;
 use gas::prelude::*;
 use rivet_key_data::converted::ActorByKeyKeyData;
 use udb_util::prelude::*;
-use universaldb::{self as udb, options::StreamingMode, FdbBindingError};
+use universaldb::{self as udb, FdbBindingError, options::StreamingMode};
 
 use crate::keys;
diff --git a/sdks/rust/tunnel-protocol/Cargo.toml b/sdks/rust/tunnel-protocol/Cargo.toml
index 9591823bf6..f7bc19f63d 100644
--- a/sdks/rust/tunnel-protocol/Cargo.toml
+++ b/sdks/rust/tunnel-protocol/Cargo.toml
@@ -8,7 +8,6 @@ edition.workspace = true
 [dependencies]
 anyhow.workspace = true
 base64.workspace = true
-gas.workspace = true
 rivet-util.workspace = true
 serde_bare.workspace = true
 serde.workspace = true
diff --git a/sdks/rust/ups-protocol/Cargo.toml b/sdks/rust/ups-protocol/Cargo.toml
new file mode 100644
index 0000000000..17cdbee161
--- /dev/null
+++ b/sdks/rust/ups-protocol/Cargo.toml
@@ -0,0 +1,21 @@
+[package]
+name = "rivet-ups-protocol"
+version.workspace = true
+authors.workspace = true
+license.workspace = true
+edition.workspace = true
+
+[dependencies]
+anyhow.workspace = true
+base64.workspace = true
+rivet-util.workspace = true
+serde_bare.workspace = true
+serde.workspace = true
+versioned-data-util.workspace = true
+
+[build-dependencies]
+bare_gen.workspace = true
+indoc.workspace = true
+prettyplease.workspace = true
+serde_json.workspace = true
+syn.workspace = true
\ No newline at end of file
diff --git a/sdks/rust/ups-protocol/build.rs b/sdks/rust/ups-protocol/build.rs
new file mode 100644
index 0000000000..6c5936b1e0
--- /dev/null
+++ b/sdks/rust/ups-protocol/build.rs
@@ -0,0 +1,71 @@
+use std::{env, fs, path::Path};
+
+use indoc::formatdoc;
+
+mod rust {
+	use super::*;
+
+	pub fn generate_sdk(schema_dir: &Path) {
+		let out_dir = env::var("OUT_DIR").unwrap();
+		let out_path = Path::new(&out_dir);
+		let mut all_names = Vec::new();
+
+		for entry in fs::read_dir(schema_dir).unwrap().flatten() {
+			let path = entry.path();
+
+			if path.is_dir() {
+				continue;
+			}
+
+			let bare_name = path
+				.file_name()
+				.unwrap()
+				.to_str()
+				.unwrap()
+				.rsplit_once('.')
+				.unwrap()
+				.0;
+
+			let content =
+				prettyplease::unparse(&syn::parse2(bare_gen::bare_schema(&path)).unwrap());
+			fs::write(out_path.join(format!("{bare_name}_generated.rs")), content).unwrap();
+
+			all_names.push(bare_name.to_string());
+		}
+
+		let mut mod_content = String::new();
+		mod_content.push_str("// Auto-generated module file for BARE schemas\n\n");
+
+		// Generate module declarations for each version
+		for name in all_names {
+			mod_content.push_str(&formatdoc!(
+				r#"
+				pub mod {name} {{
+					include!(concat!(env!("OUT_DIR"), "/{name}_generated.rs"));
+				}}
+				"#,
+			));
+		}
+
+		let mod_file_path = out_path.join("combined_imports.rs");
+		fs::write(&mod_file_path, mod_content).expect("Failed to write combined_imports.rs");
+	}
+}
+
+fn main() {
+	let manifest_dir = env::var("CARGO_MANIFEST_DIR").unwrap();
+	let workspace_root = Path::new(&manifest_dir)
+		.parent()
+		.and_then(|p| p.parent())
+		.and_then(|p| p.parent())
+		.expect("Failed to find workspace root");
+
+	let schema_dir = workspace_root
+		.join("sdks")
+		.join("schemas")
+		.join("ups-protocol");
+
+	println!("cargo:rerun-if-changed={}", schema_dir.display());
+
+	rust::generate_sdk(&schema_dir);
+}
diff --git a/sdks/rust/ups-protocol/src/generated.rs b/sdks/rust/ups-protocol/src/generated.rs
new file mode 100644
index 0000000000..84801af8dc
--- /dev/null
+++ b/sdks/rust/ups-protocol/src/generated.rs
@@ -0,0 +1 @@
+include!(concat!(env!("OUT_DIR"), "/combined_imports.rs"));
diff --git a/sdks/rust/ups-protocol/src/lib.rs b/sdks/rust/ups-protocol/src/lib.rs
new file mode 100644
index 0000000000..676c99e464
--- /dev/null
+++ b/sdks/rust/ups-protocol/src/lib.rs
@@ -0,0 +1,7 @@
+pub mod generated;
+pub mod versioned;
+
+// Re-export latest
+pub use generated::v1::*;
+
+pub const PROTOCOL_VERSION: u16 = 1;
diff --git a/sdks/rust/ups-protocol/src/versioned.rs b/sdks/rust/ups-protocol/src/versioned.rs
new file mode 100644
index 0000000000..4a739b2f2e
--- /dev/null
+++ b/sdks/rust/ups-protocol/src/versioned.rs
@@ -0,0 +1,48 @@
+use anyhow::{Ok, Result, bail};
+use versioned_data_util::OwnedVersionedData;
+
+use crate::{PROTOCOL_VERSION, generated::v1};
+
+pub enum UpsMessage {
+	V1(v1::UpsMessage),
+}
+
+impl OwnedVersionedData for UpsMessage {
+	type Latest = v1::UpsMessage;
+
+	fn latest(latest: v1::UpsMessage) -> Self {
+		UpsMessage::V1(latest)
+	}
+
+	fn into_latest(self) -> Result<v1::UpsMessage> {
+		#[allow(irrefutable_let_patterns)]
+		if let UpsMessage::V1(data) = self {
+			Ok(data)
+		} else {
+			bail!("version not latest");
+		}
+	}
+
+	fn deserialize_version(payload: &[u8], version: u16) -> Result<Self> {
+		match version {
+			1 => Ok(UpsMessage::V1(serde_bare::from_slice(payload)?)),
+			_ => bail!("invalid version: {version}"),
+		}
+	}
+
+	fn serialize_version(self, _version: u16) -> Result<Vec<u8>> {
+		match self {
+			UpsMessage::V1(data) => serde_bare::to_vec(&data).map_err(Into::into),
+		}
+	}
+}
+
+impl UpsMessage {
+	pub fn deserialize(buf: &[u8]) -> Result<v1::UpsMessage> {
+		<Self as OwnedVersionedData>::deserialize(buf, PROTOCOL_VERSION)
+	}
+
+	pub fn serialize(self) -> Result<Vec<u8>> {
+		<Self as OwnedVersionedData>::serialize(self, PROTOCOL_VERSION)
+	}
+}
diff --git a/sdks/schemas/ups-protocol/v1.bare b/sdks/schemas/ups-protocol/v1.bare
new file mode 100644
index 0000000000..426e200789
--- /dev/null
+++ b/sdks/schemas/ups-protocol/v1.bare
@@ -0,0 +1,23 @@
+type UuidV4 data[16]
+
+type MessageStart struct {
+	message_id: UuidV4
+	chunk_count: u32
+	reply_subject: optional<str>
+	payload: data
+}
+
+type MessageChunk struct {
+	message_id: UuidV4
+	chunk_index: u32
+	payload: data
+}
+
+type MessageBody union {
+	MessageStart |
+	MessageChunk
+}
+
+type UpsMessage struct {
+	body: MessageBody
+}
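
For reference, a minimal sketch (not part of the diff above) of how a publisher could frame a large payload with the v1 envelope from v1.bare. It assumes bare_gen maps data to Vec<u8>, data[16] to [u8; 16], optional<str> to Option<String>, and the MessageBody union to a Rust enum with one variant per member; in this change the actual chunking lives inside universalpubsub's publish path (driven by PublishOpts), so encode_chunked here is a hypothetical helper, not the shipped implementation.

use anyhow::Result;
use rivet_ups_protocol::{MessageBody, MessageChunk, MessageStart, UpsMessage, versioned};

// Hypothetical helper: split `payload` into one MessageStart frame plus
// MessageChunk frames, each serialized through the versioned wrapper.
fn encode_chunked(message_id: [u8; 16], payload: &[u8], max_chunk: usize) -> Result<Vec<Vec<u8>>> {
	// Guard against an empty payload so the first chunk always exists.
	let chunks: Vec<&[u8]> = if payload.is_empty() {
		vec![&payload[..]]
	} else {
		payload.chunks(max_chunk).collect()
	};
	let mut frames = Vec::with_capacity(chunks.len());

	// The first frame advertises the total chunk count so a receiver can
	// preallocate and detect when reassembly is complete.
	frames.push(
		versioned::UpsMessage::V1(UpsMessage {
			body: MessageBody::MessageStart(MessageStart {
				message_id,
				chunk_count: chunks.len() as u32,
				reply_subject: None,
				payload: chunks[0].to_vec(),
			}),
		})
		.serialize()?,
	);

	// Follow-up frames carry an explicit index so frames delivered out of
	// order can be reassembled by (message_id, chunk_index).
	for (i, chunk) in chunks.iter().enumerate().skip(1) {
		frames.push(
			versioned::UpsMessage::V1(UpsMessage {
				body: MessageBody::MessageChunk(MessageChunk {
					message_id,
					chunk_index: i as u32,
					payload: chunk.to_vec(),
				}),
			})
			.serialize()?,
		);
	}

	Ok(frames)
}

A receiver would buffer frames by message_id, take chunk_count from the MessageStart, and concatenate payloads in chunk_index order once every frame has arrived, which is what the test_large_payloads integration test above exercises end to end.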