From 9d88491f2f5fd581aef1e82f391e653a887fea34 Mon Sep 17 00:00:00 2001 From: Nathan Flurry Date: Sat, 6 Sep 2025 23:44:02 -0700 Subject: [PATCH 1/5] chore(pegboard): consolidate to single subscriber per gateway --- Cargo.lock | 6 +- packages/common/universalpubsub/Cargo.toml | 2 +- .../src/driver/postgres/mod.rs | 106 ++- packages/core/guard/server/Cargo.toml | 1 + packages/core/guard/server/src/lib.rs | 7 +- packages/core/guard/server/src/routing/mod.rs | 17 +- .../server/src/routing/pegboard_gateway.rs | 26 +- .../server/src/routing/pegboard_tunnel.rs | 5 +- .../core/guard/server/src/shared_state.rs | 32 + packages/core/pegboard-gateway/Cargo.toml | 6 +- packages/core/pegboard-gateway/src/lib.rs | 427 ++++------ .../core/pegboard-gateway/src/shared_state.rs | 286 +++++++ packages/core/pegboard-runner-ws/src/lib.rs | 45 +- packages/core/pegboard-tunnel/Cargo.toml | 3 +- packages/core/pegboard-tunnel/src/lib.rs | 779 ++++-------------- .../core/pegboard-tunnel/tests/integration.rs | 13 +- .../infra/engine/tests/actors_lifecycle.rs | 9 +- .../pegboard/src/ops/runner/get_by_key.rs | 51 ++ .../services/pegboard/src/ops/runner/mod.rs | 1 + .../services/pegboard/src/pubsub_subjects.rs | 68 +- pnpm-lock.yaml | 9 + sdks/rust/runner-protocol/src/protocol.rs | 2 - sdks/rust/runner-protocol/src/versioned.rs | 2 - sdks/rust/tunnel-protocol/build.rs | 89 +- sdks/rust/tunnel-protocol/src/versioned.rs | 66 +- sdks/schemas/runner-protocol/v1.bare | 2 - sdks/schemas/tunnel-protocol/v1.bare | 66 +- sdks/typescript/runner-protocol/src/index.ts | 157 ++-- sdks/typescript/runner/package.json | 1 + sdks/typescript/runner/src/mod.ts | 91 +- sdks/typescript/runner/src/tunnel.ts | 774 ++++++++++------- .../runner/src/websocket-tunnel-adapter.ts | 4 +- sdks/typescript/tunnel-protocol/src/index.ts | 293 +++---- 33 files changed, 1783 insertions(+), 1663 deletions(-) create mode 100644 packages/core/guard/server/src/shared_state.rs create mode 100644 packages/core/pegboard-gateway/src/shared_state.rs create mode 100644 packages/services/pegboard/src/ops/runner/get_by_key.rs diff --git a/Cargo.lock b/Cargo.lock index bcf2da44ee..800d4ca464 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3270,9 +3270,11 @@ dependencies = [ "rivet-guard-core", "rivet-tunnel-protocol", "rivet-util", + "thiserror 1.0.69", "tokio", "tokio-tungstenite", "universalpubsub", + "versioned-data-util", ] [[package]] @@ -3304,6 +3306,7 @@ version = "0.0.1" dependencies = [ "anyhow", "async-trait", + "bytes", "futures", "gasoline", "http-body-util", @@ -4340,6 +4343,7 @@ dependencies = [ "tracing", "udb-util", "universaldb", + "universalpubsub", "url", "uuid", ] @@ -6329,7 +6333,6 @@ dependencies = [ "base64 0.22.1", "deadpool-postgres", "futures-util", - "moka", "rivet-config", "rivet-env", "rivet-error", @@ -6342,6 +6345,7 @@ dependencies = [ "tempfile", "tokio", "tokio-postgres", + "tokio-util", "tracing", "tracing-subscriber", "uuid", diff --git a/packages/common/universalpubsub/Cargo.toml b/packages/common/universalpubsub/Cargo.toml index 37028f9995..3cdb8509b6 100644 --- a/packages/common/universalpubsub/Cargo.toml +++ b/packages/common/universalpubsub/Cargo.toml @@ -12,7 +12,6 @@ async-trait.workspace = true base64.workspace = true deadpool-postgres.workspace = true futures-util.workspace = true -moka.workspace = true rivet-error.workspace = true rivet-ups-protocol.workspace = true serde_json.workspace = true @@ -21,6 +20,7 @@ serde.workspace = true sha2.workspace = true tokio-postgres.workspace = true tokio.workspace = true +tokio-util.workspace = true tracing.workspace = true uuid.workspace = true diff --git a/packages/common/universalpubsub/src/driver/postgres/mod.rs b/packages/common/universalpubsub/src/driver/postgres/mod.rs index 99cb703706..c2df9a36b4 100644 --- a/packages/common/universalpubsub/src/driver/postgres/mod.rs +++ b/packages/common/universalpubsub/src/driver/postgres/mod.rs @@ -1,5 +1,6 @@ +use std::collections::HashMap; use std::hash::{DefaultHasher, Hash, Hasher}; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use anyhow::*; use async_trait::async_trait; @@ -7,7 +8,6 @@ use base64::Engine; use base64::engine::general_purpose::STANDARD_NO_PAD as BASE64; use deadpool_postgres::{Config, ManagerConfig, Pool, PoolConfig, RecyclingMethod, Runtime}; use futures_util::future::poll_fn; -use moka::future::Cache; use tokio_postgres::{AsyncMessage, NoTls}; use tracing::Instrument; @@ -18,6 +18,15 @@ use crate::pubsub::DriverOutput; struct Subscription { // Channel to send messages to this subscription tx: tokio::sync::broadcast::Sender>, + // Cancellation token shared by all subscribers of this subject + token: tokio_util::sync::CancellationToken, +} + +impl Subscription { + fn new(tx: tokio::sync::broadcast::Sender>) -> Self { + let token = tokio_util::sync::CancellationToken::new(); + Self { tx, token } + } } /// > In the default configuration it must be shorter than 8000 bytes @@ -40,7 +49,7 @@ pub const POSTGRES_MAX_MESSAGE_SIZE: usize = pub struct PostgresDriver { pool: Arc, client: Arc, - subscriptions: Cache, + subscriptions: Arc>>, } impl PostgresDriver { @@ -65,8 +74,8 @@ impl PostgresDriver { .context("failed to create postgres pool")?; tracing::debug!("postgres pool created successfully"); - let subscriptions: Cache = - Cache::builder().initial_capacity(5).build(); + let subscriptions: Arc>> = + Arc::new(Mutex::new(HashMap::new())); let subscriptions2 = subscriptions.clone(); let (client, mut conn) = tokio_postgres::connect(&conn_str, tokio_postgres::NoTls).await?; @@ -75,7 +84,9 @@ impl PostgresDriver { loop { match poll_fn(|cx| conn.poll_message(cx)).await { Some(std::result::Result::Ok(AsyncMessage::Notification(note))) => { - if let Some(sub) = subscriptions2.get(note.channel()).await { + if let Some(sub) = + subscriptions2.lock().unwrap().get(note.channel()).cloned() + { let bytes = match BASE64.decode(note.payload()) { std::result::Result::Ok(b) => b, std::result::Result::Err(err) => { @@ -121,7 +132,7 @@ impl PostgresDriver { #[async_trait] impl PubSubDriver for PostgresDriver { async fn subscribe(&self, subject: &str) -> Result { - // TODO: To match NATS implementation, LIST must be pipelined (i.e. wait for the command + // TODO: To match NATS implementation, LISTEN must be pipelined (i.e. wait for the command // to reach the server, but not wait for it to respond). However, this has to ensure that // NOTIFY & LISTEN are called on the same connection (not diff connections in a pool) or // else there will be race conditions where messages might be published before @@ -135,33 +146,57 @@ impl PubSubDriver for PostgresDriver { let hashed = self.hash_subject(subject); // Check if we already have a subscription for this channel - let rx = if let Some(existing_sub) = self.subscriptions.get(&hashed).await { - // Reuse the existing broadcast channel - existing_sub.tx.subscribe() - } else { - // Create a new broadcast channel for this subject - let (tx, rx) = tokio::sync::broadcast::channel(1024); - let subscription = Subscription { tx: tx.clone() }; - - // Register subscription - self.subscriptions - .insert(hashed.clone(), subscription) - .await; - - // Execute LISTEN command on the async client (for receiving notifications) - // This only needs to be done once per channel - let span = tracing::trace_span!("pg_listen"); - self.client - .execute(&format!("LISTEN \"{hashed}\""), &[]) - .instrument(span) - .await?; - - rx - }; + let (rx, drop_guard) = + if let Some(existing_sub) = self.subscriptions.lock().unwrap().get(&hashed).cloned() { + // Reuse the existing broadcast channel + let rx = existing_sub.tx.subscribe(); + let drop_guard = existing_sub.token.clone().drop_guard(); + (rx, drop_guard) + } else { + // Create a new broadcast channel for this subject + let (tx, rx) = tokio::sync::broadcast::channel(1024); + let subscription = Subscription::new(tx.clone()); + + // Register subscription + self.subscriptions + .lock() + .unwrap() + .insert(hashed.clone(), subscription.clone()); + + // Execute LISTEN command on the async client (for receiving notifications) + // This only needs to be done once per channel + let span = tracing::trace_span!("pg_listen"); + self.client + .execute(&format!("LISTEN \"{hashed}\""), &[]) + .instrument(span) + .await?; + + // Spawn a single cleanup task for this subscription waiting on its token + let driver = self.clone(); + let hashed_clone = hashed.clone(); + let tx_clone = tx.clone(); + let token_clone = subscription.token.clone(); + tokio::spawn(async move { + token_clone.cancelled().await; + if tx_clone.receiver_count() == 0 { + let sql = format!("UNLISTEN \"{}\"", hashed_clone); + if let Err(err) = driver.client.execute(sql.as_str(), &[]).await { + tracing::warn!(?err, %hashed_clone, "failed to UNLISTEN channel"); + } else { + tracing::trace!(%hashed_clone, "unlistened channel"); + } + driver.subscriptions.lock().unwrap().remove(&hashed_clone); + } + }); + + let drop_guard = subscription.token.clone().drop_guard(); + (rx, drop_guard) + }; Ok(Box::new(PostgresSubscriber { subject: subject.to_string(), - rx, + rx: Some(rx), + _drop_guard: drop_guard, })) } @@ -191,13 +226,18 @@ impl PubSubDriver for PostgresDriver { pub struct PostgresSubscriber { subject: String, - rx: tokio::sync::broadcast::Receiver>, + rx: Option>>, + _drop_guard: tokio_util::sync::DropGuard, } #[async_trait] impl SubscriberDriver for PostgresSubscriber { async fn next(&mut self) -> Result { - match self.rx.recv().await { + let rx = match self.rx.as_mut() { + Some(rx) => rx, + None => return Ok(DriverOutput::Unsubscribed), + }; + match rx.recv().await { std::result::Result::Ok(payload) => Ok(DriverOutput::Message { subject: self.subject.clone(), payload, diff --git a/packages/core/guard/server/Cargo.toml b/packages/core/guard/server/Cargo.toml index fbca49ce9f..ef8695c6c2 100644 --- a/packages/core/guard/server/Cargo.toml +++ b/packages/core/guard/server/Cargo.toml @@ -20,6 +20,7 @@ hyper-tungstenite.workspace = true tower.workspace = true udb-util.workspace = true universaldb.workspace = true +universalpubsub.workspace = true futures.workspace = true # TODO: Make this use workspace version hyper = "1.6.0" diff --git a/packages/core/guard/server/src/lib.rs b/packages/core/guard/server/src/lib.rs index 7f3c6e4626..f7533372e3 100644 --- a/packages/core/guard/server/src/lib.rs +++ b/packages/core/guard/server/src/lib.rs @@ -5,6 +5,7 @@ pub mod cache; pub mod errors; pub mod middleware; pub mod routing; +pub mod shared_state; pub mod tls; #[tracing::instrument(skip_all)] @@ -26,8 +27,12 @@ pub async fn start(config: rivet_config::Config, pools: rivet_pools::Pools) -> R tracing::warn!("crypto provider already installed in this process"); } + // Share shared context + let shared_state = shared_state::SharedState::new(ctx.ups()?); + shared_state.start().await?; + // Create handlers - let routing_fn = routing::create_routing_function(ctx.clone()); + let routing_fn = routing::create_routing_function(ctx.clone(), shared_state.clone()); let cache_key_fn = cache::create_cache_key_function(ctx.clone()); let middleware_fn = middleware::create_middleware_function(ctx.clone()); let cert_resolver = tls::create_cert_resolver(&ctx).await?; diff --git a/packages/core/guard/server/src/routing/mod.rs b/packages/core/guard/server/src/routing/mod.rs index b8b0d757f6..524d63075d 100644 --- a/packages/core/guard/server/src/routing/mod.rs +++ b/packages/core/guard/server/src/routing/mod.rs @@ -5,7 +5,7 @@ use gas::prelude::*; use hyper::header::HeaderName; use rivet_guard_core::RoutingFn; -use crate::errors; +use crate::{errors, shared_state::SharedState}; mod api_peer; mod api_public; @@ -17,13 +17,14 @@ pub(crate) const X_RIVET_TARGET: HeaderName = HeaderName::from_static("x-rivet-t /// Creates the main routing function that handles all incoming requests #[tracing::instrument(skip_all)] -pub fn create_routing_function(ctx: StandaloneCtx) -> RoutingFn { +pub fn create_routing_function(ctx: StandaloneCtx, shared_state: SharedState) -> RoutingFn { Arc::new( move |hostname: &str, path: &str, port_type: rivet_guard_core::proxy_service::PortType, headers: &hyper::HeaderMap| { let ctx = ctx.clone(); + let shared_state = shared_state.clone(); Box::pin( async move { @@ -41,9 +42,15 @@ pub fn create_routing_function(ctx: StandaloneCtx) -> RoutingFn { return Ok(routing_output); } - if let Some(routing_output) = - pegboard_gateway::route_request(&ctx, target, host, path, headers) - .await? + if let Some(routing_output) = pegboard_gateway::route_request( + &ctx, + &shared_state, + target, + host, + path, + headers, + ) + .await? { return Ok(routing_output); } diff --git a/packages/core/guard/server/src/routing/pegboard_gateway.rs b/packages/core/guard/server/src/routing/pegboard_gateway.rs index fe8a5b2fe5..58088a1ac4 100644 --- a/packages/core/guard/server/src/routing/pegboard_gateway.rs +++ b/packages/core/guard/server/src/routing/pegboard_gateway.rs @@ -6,7 +6,7 @@ use hyper::header::HeaderName; use rivet_guard_core::proxy_service::{RouteConfig, RouteTarget, RoutingOutput, RoutingTimeout}; use udb_util::{SERIALIZABLE, TxnExt}; -use crate::errors; +use crate::{errors, shared_state::SharedState}; const ACTOR_READY_TIMEOUT: Duration = Duration::from_secs(10); pub const X_RIVET_ACTOR: HeaderName = HeaderName::from_static("x-rivet-actor"); @@ -16,6 +16,7 @@ pub const X_RIVET_PORT: HeaderName = HeaderName::from_static("x-rivet-port"); #[tracing::instrument(skip_all)] pub async fn route_request( ctx: &StandaloneCtx, + shared_state: &SharedState, target: &str, _host: &str, path: &str, @@ -73,7 +74,7 @@ pub async fn route_request( let port_name = port_name.to_str()?; // Lookup actor - find_actor(ctx, actor_id, port_name, path).await + find_actor(ctx, shared_state, actor_id, port_name, path).await } struct FoundActor { @@ -86,6 +87,7 @@ struct FoundActor { #[tracing::instrument(skip_all, fields(%actor_id, %port_name, %path))] async fn find_actor( ctx: &StandaloneCtx, + shared_state: &SharedState, actor_id: Id, port_name: &str, path: &str, @@ -158,10 +160,10 @@ async fn find_actor( actor_ids: vec![actor_id], }); let res = tokio::time::timeout(Duration::from_secs(5), get_runner_fut).await??; - let runner_info = res.actors.into_iter().next().filter(|x| x.is_connectable); + let actor = res.actors.into_iter().next().filter(|x| x.is_connectable); - let runner_id = if let Some(runner_info) = runner_info { - runner_info.runner_id + let runner_id = if let Some(actor) = actor { + actor.runner_id } else { tracing::info!(?actor_id, "waiting for actor to become ready"); @@ -185,11 +187,23 @@ async fn find_actor( tracing::debug!(?actor_id, ?runner_id, "actor ready"); + // Get runner key from runner_id + let runner_key = ctx + .udb()? + .run(|tx, _mc| async move { + let txs = tx.subspace(pegboard::keys::subspace()); + let key_key = pegboard::keys::runner::KeyKey::new(runner_id); + txs.read_opt(&key_key, SERIALIZABLE).await + }) + .await? + .context("runner key not found")?; + // Return pegboard-gateway instance let gateway = pegboard_gateway::PegboardGateway::new( ctx.clone(), + shared_state.pegboard_gateway.clone(), actor_id, - runner_id, + runner_key, port_name.to_string(), ); Ok(Some(RoutingOutput::CustomServe(std::sync::Arc::new( diff --git a/packages/core/guard/server/src/routing/pegboard_tunnel.rs b/packages/core/guard/server/src/routing/pegboard_tunnel.rs index a3b6b0eaad..7d69150a03 100644 --- a/packages/core/guard/server/src/routing/pegboard_tunnel.rs +++ b/packages/core/guard/server/src/routing/pegboard_tunnel.rs @@ -12,13 +12,10 @@ pub async fn route_request( _host: &str, _path: &str, ) -> Result> { - // Check target if target != "tunnel" { return Ok(None); } - // Create pegboard-tunnel service instance - let tunnel = pegboard_tunnel::PegboardTunnelCustomServe::new(ctx.clone()).await?; - + let tunnel = pegboard_tunnel::PegboardTunnelCustomServe::new(ctx.clone()); Ok(Some(RoutingOutput::CustomServe(Arc::new(tunnel)))) } diff --git a/packages/core/guard/server/src/shared_state.rs b/packages/core/guard/server/src/shared_state.rs new file mode 100644 index 0000000000..71bc25939e --- /dev/null +++ b/packages/core/guard/server/src/shared_state.rs @@ -0,0 +1,32 @@ +use anyhow::*; +use gas::prelude::*; +use std::{ops::Deref, sync::Arc}; +use universalpubsub::PubSub; + +#[derive(Clone)] +pub struct SharedState(Arc); + +impl SharedState { + pub fn new(pubsub: PubSub) -> SharedState { + SharedState(Arc::new(SharedStateInner { + pegboard_gateway: pegboard_gateway::shared_state::SharedState::new(pubsub), + })) + } + + pub async fn start(&self) -> Result<()> { + self.pegboard_gateway.start().await?; + Ok(()) + } +} + +impl Deref for SharedState { + type Target = SharedStateInner; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +pub struct SharedStateInner { + pub pegboard_gateway: pegboard_gateway::shared_state::SharedState, +} diff --git a/packages/core/pegboard-gateway/Cargo.toml b/packages/core/pegboard-gateway/Cargo.toml index aeeb32bc53..6a3ddc446a 100644 --- a/packages/core/pegboard-gateway/Cargo.toml +++ b/packages/core/pegboard-gateway/Cargo.toml @@ -14,12 +14,14 @@ gas.workspace = true http-body-util.workspace = true hyper = "1.6" hyper-tungstenite.workspace = true -pegboard = { path = "../../services/pegboard" } +pegboard.workspace = true rand.workspace = true rivet-error.workspace = true -rivet-guard-core = { path = "../guard/core" } +rivet-guard-core.workspace = true rivet-tunnel-protocol.workspace = true rivet-util.workspace = true tokio-tungstenite.workspace = true tokio.workspace = true universalpubsub.workspace = true +versioned-data-util.workspace = true +thiserror.workspace = true diff --git a/packages/core/pegboard-gateway/src/lib.rs b/packages/core/pegboard-gateway/src/lib.rs index 09a960fe88..8d752eba42 100644 --- a/packages/core/pegboard-gateway/src/lib.rs +++ b/packages/core/pegboard-gateway/src/lib.rs @@ -4,56 +4,48 @@ use bytes::Bytes; use futures_util::{SinkExt, StreamExt}; use gas::prelude::*; use http_body_util::{BodyExt, Full}; -use hyper::{Request, Response, StatusCode, body::Incoming as BodyIncoming}; +use hyper::{Request, Response, StatusCode}; use hyper_tungstenite::HyperWebsocket; -use pegboard::pubsub_subjects::{ - TunnelHttpResponseSubject, TunnelHttpRunnerSubject, TunnelHttpWebSocketSubject, -}; -use rivet_error::*; use rivet_guard_core::{ custom_serve::CustomServeTrait, proxy_service::{ResponseBody, X_RIVET_ERROR}, request_context::RequestContext, }; use rivet_tunnel_protocol::{ - MessageBody, StreamFinishReason, ToServerRequestFinish, ToServerRequestStart, - ToServerWebSocketClose, ToServerWebSocketMessage, ToServerWebSocketOpen, TunnelMessage, - versioned, + MessageKind, ToServerRequestStart, ToServerWebSocketClose, ToServerWebSocketMessage, + ToServerWebSocketOpen, }; use rivet_util::serde::HashableMap; -use std::result::Result::Ok as ResultOk; -use std::{ - collections::HashMap, - sync::{ - Arc, - atomic::{AtomicU64, Ordering}, - }, - time::Duration, -}; -use tokio::{ - sync::{Mutex, oneshot}, - time::timeout, -}; +use std::time::Duration; use tokio_tungstenite::tungstenite::Message; -use universalpubsub::NextOutput; + +use crate::shared_state::{SharedState, TunnelMessageData}; + +pub mod shared_state; const UPS_REQ_TIMEOUT: Duration = Duration::from_secs(2); pub struct PegboardGateway { ctx: StandaloneCtx, - request_counter: AtomicU64, + shared_state: SharedState, actor_id: Id, - runner_id: Id, + runner_key: String, port_name: String, } impl PegboardGateway { - pub fn new(ctx: StandaloneCtx, actor_id: Id, runner_id: Id, port_name: String) -> Self { + pub fn new( + ctx: StandaloneCtx, + shared_state: SharedState, + actor_id: Id, + runner_key: String, + port_name: String, + ) -> Self { Self { ctx, - request_counter: AtomicU64::new(0), + shared_state, actor_id, - runner_id, + runner_key, port_name, } } @@ -70,7 +62,7 @@ impl CustomServeTrait for PegboardGateway { match res { Result::Ok(x) => Ok(x), Err(err) => { - if is_tunnel_closed_error(&err) { + if is_tunnel_service_unavailable(&err) { // This will force the request to be retried with a new tunnel Ok(Response::builder() .status(StatusCode::SERVICE_UNAVAILABLE) @@ -96,7 +88,7 @@ impl CustomServeTrait for PegboardGateway { { Result::Ok(()) => std::result::Result::<(), (HyperWebsocket, anyhow::Error)>::Ok(()), Result::Err((client_ws, err)) => { - if is_tunnel_closed_error(&err) { + if is_tunnel_service_unavailable(&err) { Err(( client_ws, rivet_guard_core::errors::WebSocketServiceUnavailable.build(), @@ -124,13 +116,10 @@ impl PegboardGateway { .map_err(|_| anyhow!("invalid x-rivet-actor header"))? .to_string(); - // Generate request ID using atomic counter - let request_id = self.request_counter.fetch_add(1, Ordering::SeqCst); - // Extract request parts let mut headers = HashableMap::new(); for (name, value) in req.headers() { - if let ResultOk(value_str) = value.to_str() { + if let Result::Ok(value_str) = value.to_str() { headers.insert(name.to_string(), value_str.to_string()); } } @@ -149,9 +138,21 @@ impl PegboardGateway { .map_err(|e| anyhow!("failed to read body: {}", e))? .to_bytes(); - // Create tunnel message - let request_start = ToServerRequestStart { - request_id, + // Build subject to publish to + let tunnel_subject = pegboard::pubsub_subjects::TunnelRunnerReceiverSubject::new( + &self.runner_key, + &self.port_name, + ) + .to_string(); + + // Start listening for request responses + let (request_id, mut msg_rx) = self + .shared_state + .start_in_flight_request(tunnel_subject) + .await; + + // Start request + let message = MessageKind::ToServerRequestStart(ToServerRequestStart { actor_id: actor_id.clone(), method, path, @@ -162,117 +163,33 @@ impl PegboardGateway { Some(body_bytes.to_vec()) }, stream: false, - }; - - let message = TunnelMessage { - body: MessageBody::ToServerRequestStart(request_start), - }; + }); + self.shared_state.send_message(request_id, message).await?; + + // Wait for response + tracing::info!("starting response handler task"); + let response_start = loop { + let Some(msg) = msg_rx.recv().await else { + tracing::warn!("received no message response"); + return Err(RequestError::ServiceUnavailable.into()); + }; - // Serialize message - let serialized = versioned::TunnelMessage::serialize(versioned::TunnelMessage::V1(message)) - .map_err(|e| anyhow!("failed to serialize message: {}", e))?; - - // Build pubsub topic - let tunnel_subject = TunnelHttpRunnerSubject::new(self.runner_id, &self.port_name); - let topic = tunnel_subject.to_string(); - - tracing::info!( - %topic, - ?self.runner_id, - %self.port_name, - ?request_id, - "publishing request to pubsub" - ); - - // Create response channel - let (response_tx, response_rx) = oneshot::channel(); - let response_map = Arc::new(Mutex::new(HashMap::new())); - response_map.lock().await.insert(request_id, response_tx); - - // Subscribe to response topic - let response_subject = - TunnelHttpResponseSubject::new(self.runner_id, &self.port_name, request_id); - let response_topic = response_subject.to_string(); - - tracing::info!( - ?response_topic, - ?request_id, - "subscribing to response topic" - ); - - let mut subscriber = self.ctx.ups()?.subscribe(&response_topic).await?; - - // Spawn task to handle response - let response_map_clone = response_map.clone(); - tokio::spawn(async move { - tracing::info!("starting response handler task"); - while let ResultOk(NextOutput::Message(msg)) = subscriber.next().await { - // Ack message - match msg.reply(&[]).await { - Result::Ok(_) => {} - Err(err) => { - tracing::warn!(?err, "failed to ack gateway request response message") + match msg { + TunnelMessageData::Message(msg) => match msg { + MessageKind::ToClientResponseStart(response_start) => { + break response_start; } - }; - - tracing::info!( - payload_len = msg.payload.len(), - "received response from pubsub" - ); - if let ResultOk(tunnel_msg) = versioned::TunnelMessage::deserialize(&msg.payload) { - match tunnel_msg.body { - MessageBody::ToClientResponseStart(response_start) => { - tracing::info!(request_id = ?response_start.request_id, status = response_start.status, "received response from tunnel"); - if let Some(tx) = response_map_clone - .lock() - .await - .remove(&response_start.request_id) - { - tracing::info!(request_id = ?response_start.request_id, "sending response to handler"); - let _ = tx.send(response_start); - } else { - tracing::warn!(request_id = ?response_start.request_id, "no handler found for response"); - } - } - _ => { - tracing::warn!("received non-response message from pubsub"); - } + _ => { + tracing::warn!("received non-response message from pubsub"); } - } else { - tracing::error!("failed to deserialize response from pubsub"); + }, + TunnelMessageData::Timeout => { + tracing::warn!("tunnel message timeout"); + return Err(RequestError::ServiceUnavailable.into()); } } - tracing::info!("response handler task ended"); - }); - - // Publish request - self.ctx - .ups()? - .request_with_timeout(&topic, &serialized, UPS_REQ_TIMEOUT) - .await - .map_err(|e| anyhow!("failed to publish request: {}", e))?; - - // Send finish message - let finish_message = TunnelMessage { - body: MessageBody::ToServerRequestFinish(ToServerRequestFinish { - request_id, - reason: StreamFinishReason::Complete, - }), - }; - let finish_serialized = - versioned::TunnelMessage::serialize(versioned::TunnelMessage::V1(finish_message)) - .map_err(|e| anyhow!("failed to serialize finish message: {}", e))?; - self.ctx - .ups()? - .request_with_timeout(&topic, &finish_serialized, UPS_REQ_TIMEOUT) - .await - .map_err(|e| anyhow!("failed to publish finish message: {}", e))?; - - // Wait for response with timeout - let response_start = match timeout(Duration::from_secs(30), response_rx).await { - ResultOk(ResultOk(response)) => response, - _ => return Err(anyhow!("request timed out")), }; + tracing::info!("response handler task ended"); // Build HTTP response let mut response_builder = @@ -309,62 +226,69 @@ impl PegboardGateway { Err(err) => return Err((client_ws, err)), }; - // Generate WebSocket ID using atomic counter - let websocket_id = self.request_counter.fetch_add(1, Ordering::SeqCst); - // Extract headers let mut request_headers = HashableMap::new(); for (name, value) in headers { - if let ResultOk(value_str) = value.to_str() { + if let Result::Ok(value_str) = value.to_str() { request_headers.insert(name.to_string(), value_str.to_string()); } } - let ups = match self.ctx.ups() { - Result::Ok(u) => u, - Err(err) => return Err((client_ws, err.into())), - }; - - // Subscribe to messages from server before informing server that a client websocket is connecting to - // prevent race conditions. - let ws_subject = - TunnelHttpWebSocketSubject::new(self.runner_id, &self.port_name, websocket_id); - let response_topic = ws_subject.to_string(); - let mut subscriber = match ups.subscribe(&response_topic).await { - Result::Ok(sub) => sub, - Err(err) => return Err((client_ws, err.into())), - }; + // Build subject to publish to + let tunnel_subject = pegboard::pubsub_subjects::TunnelRunnerReceiverSubject::new( + &self.runner_key, + &self.port_name, + ) + .to_string(); - // Build pubsub topic - let tunnel_subject = TunnelHttpRunnerSubject::new(self.runner_id, &self.port_name); - let topic = tunnel_subject.to_string(); + // Start listening for WebSocket messages + let (request_id, mut msg_rx) = self + .shared_state + .start_in_flight_request(tunnel_subject.clone()) + .await; // Send WebSocket open message - let open_message = TunnelMessage { - body: MessageBody::ToServerWebSocketOpen(ToServerWebSocketOpen { - actor_id: actor_id.clone(), - web_socket_id: websocket_id, - path: path.to_string(), - headers: request_headers, - }), - }; - - let serialized = - match versioned::TunnelMessage::serialize(versioned::TunnelMessage::V1(open_message)) { - Result::Ok(s) => s, - Err(e) => { - return Err(( - client_ws, - anyhow!("failed to serialize websocket open: {}", e), - )); - } - }; + let open_message = MessageKind::ToServerWebSocketOpen(ToServerWebSocketOpen { + actor_id: actor_id.clone(), + path: path.to_string(), + headers: request_headers, + }); - if let Err(err) = ups - .request_with_timeout(&topic, &serialized, UPS_REQ_TIMEOUT) + if let Err(err) = self + .shared_state + .send_message(request_id, open_message) .await { - return Err((client_ws, err.into())); + return Err((client_ws, err)); + } + + // Wait for WebSocket open acknowledgment + let open_ack_received = loop { + let Some(msg) = msg_rx.recv().await else { + tracing::warn!("received no websocket open response"); + return Err((client_ws, RequestError::ServiceUnavailable.into())); + }; + + match msg { + TunnelMessageData::Message(MessageKind::ToClientWebSocketOpen) => { + break true; + } + TunnelMessageData::Message(MessageKind::ToClientWebSocketClose(close)) => { + tracing::info!(?close, "websocket closed before opening"); + return Err((client_ws, RequestError::ServiceUnavailable.into())); + } + TunnelMessageData::Timeout => { + tracing::warn!("websocket open timeout"); + return Err((client_ws, RequestError::ServiceUnavailable.into())); + } + _ => { + tracing::warn!("received unexpected message while waiting for websocket open"); + } + } + }; + + if !open_ack_received { + return Err((client_ws, anyhow!("failed to open websocket"))); } // Accept the WebSocket @@ -379,33 +303,30 @@ impl PegboardGateway { let (mut ws_sink, mut ws_stream) = ws_stream.split(); // Spawn task to forward messages from server to client + let mut msg_rx_for_task = msg_rx; tokio::spawn(async move { - while let ResultOk(NextOutput::Message(msg)) = subscriber.next().await { - // Ack message - match msg.reply(&[]).await { - Result::Ok(_) => {} - Err(err) => { - tracing::warn!(?err, "failed to ack gateway websocket message") - } - }; - - if let ResultOk(tunnel_msg) = versioned::TunnelMessage::deserialize(&msg.payload) { - match tunnel_msg.body { - MessageBody::ToClientWebSocketMessage(ws_msg) => { - if ws_msg.web_socket_id == websocket_id { - let msg = if ws_msg.binary { - Message::Binary(ws_msg.data.into()) - } else { - Message::Text( - String::from_utf8_lossy(&ws_msg.data).into_owned().into(), - ) - }; - let _ = ws_sink.send(msg).await; - } + while let Some(msg) = msg_rx_for_task.recv().await { + match msg { + TunnelMessageData::Message(MessageKind::ToClientWebSocketMessage(ws_msg)) => { + let msg = if ws_msg.binary { + Message::Binary(ws_msg.data.into()) + } else { + Message::Text(String::from_utf8_lossy(&ws_msg.data).into_owned().into()) + }; + if let Err(e) = ws_sink.send(msg).await { + tracing::warn!(?e, "failed to send websocket message to client"); + break; } - MessageBody::ToClientWebSocketClose(_) => break, - _ => {} } + TunnelMessageData::Message(MessageKind::ToClientWebSocketClose(close)) => { + tracing::info!(?close, "server closed websocket"); + break; + } + TunnelMessageData::Timeout => { + tracing::warn!("websocket message timeout"); + break; + } + _ => {} } } }); @@ -414,25 +335,14 @@ impl PegboardGateway { let mut close_reason = None; while let Some(msg) = ws_stream.next().await { match msg { - ResultOk(Message::Binary(data)) => { - let ws_message = TunnelMessage { - body: MessageBody::ToServerWebSocketMessage(ToServerWebSocketMessage { - web_socket_id: websocket_id, + Result::Ok(Message::Binary(data)) => { + let ws_message = + MessageKind::ToServerWebSocketMessage(ToServerWebSocketMessage { data: data.into(), binary: true, - }), - }; - let serialized = match versioned::TunnelMessage::serialize( - versioned::TunnelMessage::V1(ws_message), - ) { - Result::Ok(s) => s, - Err(_) => break, - }; - if let Err(err) = ups - .request_with_timeout(&topic, &serialized, UPS_REQ_TIMEOUT) - .await - { - if is_tunnel_closed_error(&err) { + }); + if let Err(err) = self.shared_state.send_message(request_id, ws_message).await { + if is_tunnel_service_unavailable(&err) { tracing::warn!("tunnel closed sending binary message"); close_reason = Some("Tunnel closed".to_string()); break; @@ -441,25 +351,14 @@ impl PegboardGateway { } } } - ResultOk(Message::Text(text)) => { - let ws_message = TunnelMessage { - body: MessageBody::ToServerWebSocketMessage(ToServerWebSocketMessage { - web_socket_id: websocket_id, + Result::Ok(Message::Text(text)) => { + let ws_message = + MessageKind::ToServerWebSocketMessage(ToServerWebSocketMessage { data: text.as_bytes().to_vec(), binary: false, - }), - }; - let serialized = match versioned::TunnelMessage::serialize( - versioned::TunnelMessage::V1(ws_message), - ) { - Result::Ok(s) => s, - Err(_) => break, - }; - if let Err(err) = ups - .request_with_timeout(&topic, &serialized, UPS_REQ_TIMEOUT) - .await - { - if is_tunnel_closed_error(&err) { + }); + if let Err(err) = self.shared_state.send_message(request_id, ws_message).await { + if is_tunnel_service_unavailable(&err) { tracing::warn!("tunnel closed sending text message"); close_reason = Some("Tunnel closed".to_string()); break; @@ -468,32 +367,23 @@ impl PegboardGateway { } } } - ResultOk(Message::Close(_)) | Err(_) => break, + Result::Ok(Message::Close(_)) | Err(_) => break, _ => {} } } // Send WebSocket close message - let close_message = TunnelMessage { - body: MessageBody::ToServerWebSocketClose(ToServerWebSocketClose { - web_socket_id: websocket_id, - code: None, - reason: close_reason, - }), - }; - - let serialized = match versioned::TunnelMessage::serialize(versioned::TunnelMessage::V1( - close_message, - )) { - Result::Ok(s) => s, - Err(_) => Vec::new(), - }; + let close_message = MessageKind::ToServerWebSocketClose(ToServerWebSocketClose { + code: None, + reason: close_reason, + }); - if let Err(err) = ups - .request_with_timeout(&topic, &serialized, UPS_REQ_TIMEOUT) + if let Err(err) = self + .shared_state + .send_message(request_id, close_message) .await { - if is_tunnel_closed_error(&err) { + if is_tunnel_service_unavailable(&err) { tracing::warn!("tunnel closed sending close message"); } else { tracing::error!(?err, "error sending close message"); @@ -504,14 +394,13 @@ impl PegboardGateway { } } +#[derive(thiserror::Error, Debug)] +enum RequestError { + #[error("service unavailable")] + ServiceUnavailable, +} + /// Determines if the tunnel is closed by if the UPS service is no longer responding. -fn is_tunnel_closed_error(err: &anyhow::Error) -> bool { - if let Some(err) = err.chain().find_map(|x| x.downcast_ref::()) - && err.group() == "ups" - && err.code() == "request_timeout" - { - true - } else { - false - } +fn is_tunnel_service_unavailable(err: &anyhow::Error) -> bool { + err.chain().any(|x| x.is::()) } diff --git a/packages/core/pegboard-gateway/src/shared_state.rs b/packages/core/pegboard-gateway/src/shared_state.rs new file mode 100644 index 0000000000..75b485679b --- /dev/null +++ b/packages/core/pegboard-gateway/src/shared_state.rs @@ -0,0 +1,286 @@ +use anyhow::*; +use gas::prelude::*; +use rivet_tunnel_protocol::{ + MessageId, MessageKind, PROTOCOL_VERSION, PubSubMessage, RequestId, versioned, +}; +use std::{ + collections::HashMap, + ops::Deref, + sync::Arc, + time::{Duration, Instant}, +}; +use tokio::sync::{Mutex, mpsc}; +use universalpubsub::{NextOutput, PubSub, PublishOpts, Subscriber}; +use versioned_data_util::OwnedVersionedData as _; + +const GC_INTERVAL: Duration = Duration::from_secs(60); +const MESSAGE_ACK_TIMEOUT: Duration = Duration::from_secs(5); + +struct InFlightRequest { + /// UPS subject to send messages to for this request. + receiver_subject: String, + /// Sender for incoming messages to this request. + msg_tx: mpsc::Sender, + /// True once first message for this request has been sent (so runner learned reply_to). + opened: bool, +} + +struct PendingMessage { + request_id: RequestId, + send_instant: Instant, +} + +pub enum TunnelMessageData { + Message(MessageKind), + Timeout, +} + +pub struct SharedStateInner { + ups: PubSub, + receiver_subject: String, + requests_in_flight: Mutex>, + pending_messages: Mutex>, +} + +#[derive(Clone)] +pub struct SharedState(Arc); + +impl SharedState { + pub fn new(ups: PubSub) -> Self { + let gateway_id = Uuid::new_v4(); + let receiver_subject = + pegboard::pubsub_subjects::TunnelGatewayReceiverSubject::new(gateway_id).to_string(); + + Self(Arc::new(SharedStateInner { + ups, + receiver_subject, + requests_in_flight: Mutex::new(HashMap::new()), + pending_messages: Mutex::new(HashMap::new()), + })) + } + + pub async fn start(&self) -> Result<()> { + let sub = self.ups.subscribe(&self.receiver_subject).await?; + + let self_clone = self.clone(); + tokio::spawn(async move { self_clone.receiver(sub).await }); + + let self_clone = self.clone(); + tokio::spawn(async move { self_clone.gc().await }); + + Ok(()) + } + + pub async fn send_message( + &self, + request_id: RequestId, + message_kind: MessageKind, + ) -> Result<()> { + let message_id = Uuid::new_v4().as_bytes().clone(); + + // Get subject and whether this is the first message for this request + let (tunnel_receiver_subject, include_reply_to) = { + let mut requests_in_flight = self.requests_in_flight.lock().await; + if let Some(req) = requests_in_flight.get_mut(&request_id) { + let receiver_subject = req.receiver_subject.clone(); + let include_reply_to = !req.opened; + if include_reply_to { + // Mark as opened so subsequent messages skip reply_to + req.opened = true; + } + (receiver_subject, include_reply_to) + } else { + bail!("request not in flight") + } + }; + + // Save pending message + { + let mut pending_messages = self.pending_messages.lock().await; + pending_messages.insert( + message_id, + PendingMessage { + request_id, + send_instant: Instant::now(), + }, + ); + } + + // Send message + let message = PubSubMessage { + request_id, + message_id, + // Only send reply to subject on the first message for this request. This reduces + // overhead of subsequent messages. + reply_to: if include_reply_to { + Some(self.receiver_subject.clone()) + } else { + None + }, + message_kind, + }; + let message_serialized = versioned::PubSubMessage::latest(message) + .serialize_with_embedded_version(PROTOCOL_VERSION)?; + self.ups + .publish( + &tunnel_receiver_subject, + &message_serialized, + PublishOpts::one(), + ) + .await?; + + Ok(()) + } + + pub async fn start_in_flight_request( + &self, + receiver_subject: String, + ) -> (RequestId, mpsc::Receiver) { + let id = Uuid::new_v4().into_bytes(); + let (msg_tx, msg_rx) = mpsc::channel(128); + self.requests_in_flight.lock().await.insert( + id, + InFlightRequest { + receiver_subject, + msg_tx, + opened: false, + }, + ); + (id, msg_rx) + } + + async fn receiver(&self, mut sub: Subscriber) { + while let Result::Ok(NextOutput::Message(msg)) = sub.next().await { + tracing::info!( + payload_len = msg.payload.len(), + "received message from pubsub" + ); + + match versioned::PubSubMessage::deserialize_with_embedded_version(&msg.payload) { + Result::Ok(msg) => { + tracing::debug!( + ?msg.request_id, + ?msg.message_id, + "successfully deserialized message" + ); + if let MessageKind::Ack = &msg.message_kind { + // Handle ack message + + let mut pending_messages = self.pending_messages.lock().await; + if pending_messages.remove(&msg.message_id).is_none() { + tracing::warn!( + "pending message does not exist or ack received after message body" + ); + } + } else { + // Forward message to receiver + + // Send message to sender using request_id directly + let requests_in_flight = self.requests_in_flight.lock().await; + let Some(in_flight) = requests_in_flight.get(&msg.request_id) else { + tracing::debug!( + ?msg.request_id, + "in flight has already been disconnected" + ); + continue; + }; + tracing::debug!( + ?msg.request_id, + "forwarding message to request handler" + ); + let _ = in_flight + .msg_tx + .send(TunnelMessageData::Message(msg.message_kind)) + .await; + + // Send ack + let ups_clone = self.ups.clone(); + let receiver_subject = in_flight.receiver_subject.clone(); + let ack_message = PubSubMessage { + request_id: msg.request_id, + message_id: Uuid::new_v4().into_bytes(), + reply_to: None, + message_kind: MessageKind::Ack, + }; + let ack_message_serialized = + match versioned::PubSubMessage::latest(ack_message) + .serialize_with_embedded_version(PROTOCOL_VERSION) + { + Result::Ok(x) => x, + Err(err) => { + tracing::error!(?err, "failed to serialize ack"); + continue; + } + }; + tokio::spawn(async move { + if let Result::Err(err) = ups_clone + .publish( + &receiver_subject, + &ack_message_serialized, + PublishOpts::one(), + ) + .await + { + tracing::warn!(?err, "failed to ack message") + } + }); + } + } + Result::Err(err) => { + tracing::error!(?err, "failed to parse message"); + } + } + } + } + + async fn gc(&self) { + let mut interval = tokio::time::interval(GC_INTERVAL); + loop { + interval.tick().await; + + let now = Instant::now(); + + // Purge unacked messages + { + let mut pending_messages = self.pending_messages.lock().await; + let mut removed_req_ids = Vec::new(); + pending_messages.retain(|_k, v| { + if now.duration_since(v.send_instant) > MESSAGE_ACK_TIMEOUT { + // Expired + removed_req_ids.push(v.request_id.clone()); + false + } else { + true + } + }); + + // Close in-flight messages + let requests_in_flight = self.requests_in_flight.lock().await; + for req_id in removed_req_ids { + if let Some(x) = requests_in_flight.get(&req_id) { + let _ = x.msg_tx.send(TunnelMessageData::Timeout); + } else { + tracing::warn!( + ?req_id, + "message expired for in flight that does not exist" + ); + } + } + } + + // Purge no longer in flight + { + let mut requests_in_flight = self.requests_in_flight.lock().await; + requests_in_flight.retain(|_k, v| !v.msg_tx.is_closed()); + } + } + } +} + +impl Deref for SharedState { + type Target = SharedStateInner; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} diff --git a/packages/core/pegboard-runner-ws/src/lib.rs b/packages/core/pegboard-runner-ws/src/lib.rs index 32de54f51d..576c7d567e 100644 --- a/packages/core/pegboard-runner-ws/src/lib.rs +++ b/packages/core/pegboard-runner-ws/src/lib.rs @@ -270,6 +270,7 @@ async fn build_connection( UrlData { protocol_version, namespace, + runner_key, }: UrlData, ) -> Result<(Id, Arc)> { let namespace = ctx @@ -300,9 +301,7 @@ async fn build_connection( .map_err(|err: anyhow::Error| WsError::InvalidPacket(err.to_string()).build())?; let (runner_id, workflow_id) = if let protocol::ToServer::Init { - runner_id, name, - key, version, total_slots, addresses_http, @@ -311,7 +310,16 @@ async fn build_connection( .. } = &packet { - let runner_id = if let Some(runner_id) = runner_id { + // Look up existing runner by key + let existing_runner = ctx + .op(pegboard::ops::runner::get_by_key::Input { + namespace_id: namespace.namespace_id, + name: name.clone(), + key: runner_key.clone(), + }) + .await?; + + let runner_id = if let Some(runner) = existing_runner.runner { // IMPORTANT: Before we spawn/get the workflow, we try to update the runner's last ping ts. // This ensures if the workflow is currently checking for expiry that it will not expire // (because we are about to send signals to it) and if it is already expired (but not @@ -319,7 +327,7 @@ async fn build_connection( let update_ping_res = ctx .op(pegboard::ops::runner::update_alloc_idx::Input { runners: vec![pegboard::ops::runner::update_alloc_idx::Runner { - runner_id: *runner_id, + runner_id: runner.runner_id, action: Action::UpdatePing { rtt: 0 }, }], }) @@ -332,11 +340,14 @@ async fn build_connection( .map(|notif| matches!(notif.eligibility, RunnerEligibility::Expired)) .unwrap_or_default() { + // Runner expired, create a new one Id::new_v1(ctx.config().dc_label()) } else { - *runner_id + // Use existing runner + runner.runner_id } } else { + // No existing runner for this key, create a new one Id::new_v1(ctx.config().dc_label()) }; @@ -346,7 +357,7 @@ async fn build_connection( runner_id, namespace_id: namespace.namespace_id, name: name.clone(), - key: key.clone(), + key: runner_key.clone(), version: version.clone(), total_slots: *total_slots, @@ -790,19 +801,17 @@ async fn msg_thread_inner(ctx: &StandaloneCtx, conns: Arc>) struct UrlData { protocol_version: u16, namespace: String, + runner_key: String, } fn parse_url(addr: SocketAddr, uri: hyper::Uri) -> Result { let url = url::Url::parse(&format!("ws://{addr}{uri}"))?; - // Get protocol version from last path segment - let last_segment = url - .path_segments() - .context("invalid url")? - .last() - .context("no path segments")?; - ensure!(last_segment.starts_with('v'), "invalid protocol version"); - let protocol_version = last_segment[1..] + // Read protocol version from query parameters (required) + let protocol_version = url + .query_pairs() + .find_map(|(n, v)| (n == "protocol_version").then_some(v)) + .context("missing `protocol_version` query parameter")? .parse::() .context("invalid protocol version")?; @@ -813,9 +822,17 @@ fn parse_url(addr: SocketAddr, uri: hyper::Uri) -> Result { .context("missing `namespace` query parameter")? .to_string(); + // Read runner key from query parameters (required) + let runner_key = url + .query_pairs() + .find_map(|(n, v)| (n == "runner_key").then_some(v)) + .context("missing `runner_key` query parameter")? + .to_string(); + Ok(UrlData { protocol_version, namespace, + runner_key, }) } diff --git a/packages/core/pegboard-tunnel/Cargo.toml b/packages/core/pegboard-tunnel/Cargo.toml index ad8747ffd8..ef927c3559 100644 --- a/packages/core/pegboard-tunnel/Cargo.toml +++ b/packages/core/pegboard-tunnel/Cargo.toml @@ -8,6 +8,7 @@ edition.workspace = true [dependencies] anyhow.workspace = true async-trait.workspace = true +bytes.workspace = true futures.workspace = true gas.workspace = true http-body-util = "0.1" @@ -37,4 +38,4 @@ versioned-data-util.workspace = true rand.workspace = true rivet-cache.workspace = true rivet-pools.workspace = true -rivet-util.workspace = true \ No newline at end of file +rivet-util.workspace = true diff --git a/packages/core/pegboard-tunnel/src/lib.rs b/packages/core/pegboard-tunnel/src/lib.rs index 37ddc97cb8..a71d3dc62f 100644 --- a/packages/core/pegboard-tunnel/src/lib.rs +++ b/packages/core/pegboard-tunnel/src/lib.rs @@ -1,54 +1,29 @@ -use std::collections::HashMap; -use std::sync::Arc; -use std::time::Duration; - use anyhow::*; use async_trait::async_trait; +use bytes::Bytes; use futures::{SinkExt, StreamExt}; use gas::prelude::*; use http_body_util::Full; -use hyper::body::{Bytes, Incoming as BodyIncoming}; -use hyper::{Request, Response, StatusCode}; -use hyper_tungstenite::tungstenite::Utf8Bytes as WsUtf8Bytes; -use hyper_tungstenite::tungstenite::protocol::frame::CloseFrame as WsCloseFrame; -use hyper_tungstenite::tungstenite::protocol::frame::coding::CloseCode as WsCloseCode; +use hyper::{Response, StatusCode}; use hyper_tungstenite::{HyperWebsocket, tungstenite::Message as WsMessage}; -use pegboard::pubsub_subjects::{ - TunnelHttpResponseSubject, TunnelHttpRunnerSubject, TunnelHttpWebSocketSubject, +use rivet_guard_core::{ + custom_serve::CustomServeTrait, proxy_service::ResponseBody, request_context::RequestContext, +}; +use rivet_tunnel_protocol::{ + MessageKind, PROTOCOL_VERSION, PubSubMessage, RequestId, RunnerMessage, versioned, }; -use rivet_guard_core::custom_serve::CustomServeTrait; -use rivet_guard_core::proxy_service::ResponseBody; -use rivet_guard_core::request_context::RequestContext; -use rivet_pools::Pools; -use rivet_tunnel_protocol::{MessageBody, TunnelMessage, versioned}; -use rivet_util::Id; -use std::net::SocketAddr; -use tokio::net::TcpListener; +use std::{collections::HashMap, sync::Arc}; use tokio::sync::Mutex; -use tokio::sync::RwLock; -use tokio_tungstenite::accept_async; -use tracing::{error, info}; -use universalpubsub::pubsub::NextOutput; - -const UPS_REQ_TIMEOUT: Duration = Duration::from_secs(2); - -struct RunnerConnection { - _runner_id: Id, - _port_name: String, -} - -type Connections = Arc>>>; +use universalpubsub::{PublishOpts, pubsub::NextOutput}; +use versioned_data_util::OwnedVersionedData as _; pub struct PegboardTunnelCustomServe { ctx: StandaloneCtx, - connections: Connections, } impl PegboardTunnelCustomServe { - pub async fn new(ctx: StandaloneCtx) -> Result { - let connections = Arc::new(RwLock::new(HashMap::new())); - - Ok(Self { ctx, connections }) + pub fn new(ctx: StandaloneCtx) -> Self { + Self { ctx } } } @@ -82,654 +57,258 @@ impl CustomServeTrait for PegboardTunnelCustomServe { Result::Ok(u) => u, Err(e) => return Err((client_ws, e.into())), }; - let connections = self.connections.clone(); - // Extract runner_id from query parameters - let runner_id = if let std::result::Result::Ok(url) = - url::Url::parse(&format!("ws://placeholder/{path}")) + // Parse URL to extract runner_id and protocol version + let url = match url::Url::parse(&format!("ws://placeholder/{path}")) { + Result::Ok(u) => u, + Err(e) => return Err((client_ws, e.into())), + }; + + // Extract runner_key from query parameters (required) + let runner_key = match url + .query_pairs() + .find_map(|(n, v)| (n == "runner_key").then_some(v)) + { + Some(key) => key.to_string(), + None => { + return Err((client_ws, anyhow!("runner_key query parameter is required"))); + } + }; + + // Extract protocol version from query parameters (required) + let protocol_version = match url + .query_pairs() + .find_map(|(n, v)| (n == "protocol_version").then_some(v)) + .as_ref() + .and_then(|v| v.parse::().ok()) { - url.query_pairs() - .find_map(|(n, v)| (n == "runner_id").then_some(v)) - .as_ref() - .and_then(|id| Id::parse(id).ok()) - .unwrap_or(Id::nil()) - } else { - Id::nil() + Some(version) => version, + None => { + return Err(( + client_ws, + anyhow!("protocol_version query parameter is required and must be a valid u16"), + )); + } }; let port_name = "main".to_string(); // Use "main" as default port name - info!( - ?runner_id, + tracing::info!( + ?runner_key, ?port_name, + ?protocol_version, ?path, "tunnel WebSocket connection established" ); - let connection_id = Id::nil(); - // Subscribe to pubsub topic for this runner before accepting the client websocket so // that failures can be retried by the proxy. - let topic = TunnelHttpRunnerSubject::new(runner_id, &port_name).to_string(); - info!(%topic, ?runner_id, "subscribing to pubsub topic"); - + let topic = + pegboard::pubsub_subjects::TunnelRunnerReceiverSubject::new(&runner_key, &port_name) + .to_string(); + tracing::info!(%topic, ?runner_key, "subscribing to runner receiver topic"); let mut sub = match ups.subscribe(&topic).await { Result::Ok(s) => s, Err(e) => return Err((client_ws, e.into())), }; + // Accept WS let ws_stream = match client_ws.await { Result::Ok(ws) => ws, Err(e) => { // Handshake already in progress; cannot retry safely here - error!(error=?e, "client websocket await failed"); + tracing::error!(error=?e, "client websocket await failed"); return std::result::Result::<(), (HyperWebsocket, anyhow::Error)>::Ok(()); } }; - // Split WebSocket stream into read and write halves let (ws_write, mut ws_read) = ws_stream.split(); let ws_write = Arc::new(tokio::sync::Mutex::new(ws_write)); - // Store connection - let connection = Arc::new(RunnerConnection { - _runner_id: runner_id, - _port_name: port_name.clone(), - }); + struct ActiveRequest { + /// Subject to send replies to. + reply_to: String, + } - connections - .write() - .await - .insert(connection_id, connection.clone()); + // Active HTTP & WebSocket requests. They are separate but use the same mechanism to + // maintain state. + let active_requests = Arc::new(Mutex::new(HashMap::::new())); - // Handle bidirectional message forwarding + // Forward pubsub -> WebSocket let ws_write_pubsub_to_ws = ws_write.clone(); - let connections_clone = connections.clone(); let ups_clone = ups.clone(); - - // Task for forwarding pubsub -> WebSocket + let active_requests_clone = active_requests.clone(); let pubsub_to_ws = tokio::spawn(async move { - info!("starting pubsub to WebSocket forwarding task"); - while let ::std::result::Result::Ok(NextOutput::Message(msg)) = sub.next().await { - // Ack message - match msg.reply(&[]).await { - Result::Ok(_) => {} + while let Result::Ok(NextOutput::Message(ups_msg)) = sub.next().await { + tracing::info!( + payload_len = ups_msg.payload.len(), + "received message from pubsub, forwarding to WebSocket" + ); + + // Parse message + let msg = match versioned::PubSubMessage::deserialize_with_embedded_version( + &ups_msg.payload, + ) { + Result::Ok(x) => x, Err(err) => { - tracing::warn!(?err, "failed to ack gateway request response message") + tracing::error!(?err, "failed to parse tunnel message"); + continue; } }; - info!( - payload_len = msg.payload.len(), - "received message from pubsub, forwarding to WebSocket" - ); + // Save active request + if let Some(reply_to) = msg.reply_to { + let mut active_requests = active_requests_clone.lock().await; + active_requests.insert(msg.request_id, ActiveRequest { reply_to }); + } + + // If terminal, remove active request tracking + if is_message_kind_request_close(&msg.message_kind) { + let mut active_requests = active_requests_clone.lock().await; + active_requests.remove(&msg.request_id); + } + // Forward raw message to WebSocket - let ws_msg = WsMessage::Binary(msg.payload.to_vec().into()); + let tunnel_msg = match versioned::RunnerMessage::latest(RunnerMessage { + request_id: msg.request_id, + message_id: msg.message_id, + message_kind: msg.message_kind, + }) + .serialize_version(protocol_version) + { + Result::Ok(x) => x, + Err(err) => { + tracing::error!(?err, "failed to serialize tunnel message"); + continue; + } + }; + let ws_msg = WsMessage::Binary(tunnel_msg.into()); { let mut stream = ws_write_pubsub_to_ws.lock().await; if let Err(e) = stream.send(ws_msg).await { - error!(?e, "failed to send message to WebSocket"); + tracing::error!(?e, "failed to send message to WebSocket"); break; } } } - info!("pubsub to WebSocket forwarding task ended"); + tracing::info!("pubsub to WebSocket forwarding task ended"); }); - // Task for forwarding WebSocket -> pubsub - let ws_write_ws_to_pubsub = ws_write.clone(); + // Forward WebSocket -> pubsub + let active_requests_clone = active_requests.clone(); + let runner_key_clone = runner_key.clone(); let ws_to_pubsub = tokio::spawn(async move { - info!("starting WebSocket to pubsub forwarding task"); + tracing::info!("starting WebSocket to pubsub forwarding task"); while let Some(msg) = ws_read.next().await { match msg { - ::std::result::Result::Ok(WsMessage::Binary(data)) => { - info!( + Result::Ok(WsMessage::Binary(data)) => { + tracing::info!( data_len = data.len(), "received binary message from WebSocket" ); - // Parse the tunnel message to extract request_id - match versioned::TunnelMessage::deserialize(&data) { - ::std::result::Result::Ok(tunnel_msg) => { - // Handle different message types - match &tunnel_msg.body { - MessageBody::ToClientResponseStart(resp) => { - info!(?resp.request_id, status = resp.status, "forwarding HTTP response to pubsub"); - let response_topic = TunnelHttpResponseSubject::new( - runner_id, - &port_name, - resp.request_id, - ) - .to_string(); - - info!(%response_topic, ?resp.request_id, "publishing HTTP response to pubsub"); - if let Err(e) = ups_clone - .request_with_timeout( - &response_topic, - &data.to_vec(), - UPS_REQ_TIMEOUT, - ) - .await - { - let err_any: anyhow::Error = e.into(); - if is_tunnel_closed_error(&err_any) { - info!( - "tunnel closed while publishing HTTP response; closing client websocket" - ); - // Close client websocket with reason - send_tunnel_closed_close_hyper( - &ws_write_ws_to_pubsub, - ) - .await; - break; - } else { - error!(?err_any, ?resp.request_id, "failed to publish HTTP response to pubsub"); - } - } else { - info!(?resp.request_id, "successfully published HTTP response to pubsub"); - } - } - MessageBody::ToClientWebSocketMessage(ws_msg) => { - info!(?ws_msg.web_socket_id, "forwarding WebSocket message to pubsub"); - // Forward WebSocket messages to the topic that pegboard-gateway subscribes to - let ws_topic = TunnelHttpWebSocketSubject::new( - runner_id, - &port_name, - ws_msg.web_socket_id, - ) - .to_string(); - - info!(%ws_topic, ?ws_msg.web_socket_id, "publishing WebSocket message to pubsub"); - - if let Err(e) = ups_clone - .request_with_timeout( - &ws_topic, - &data.to_vec(), - UPS_REQ_TIMEOUT, - ) - .await - { - let err_any: anyhow::Error = e.into(); - if is_tunnel_closed_error(&err_any) { - info!( - "tunnel closed while publishing WebSocket message; closing client websocket" - ); - // Close client websocket with reason - send_tunnel_closed_close_hyper( - &ws_write_ws_to_pubsub, - ) - .await; - break; - } else { - error!(?err_any, ?ws_msg.web_socket_id, "failed to publish WebSocket message to pubsub"); - } - } else { - info!(?ws_msg.web_socket_id, "successfully published WebSocket message to pubsub"); - } - } - MessageBody::ToClientWebSocketOpen(ws_open) => { - info!(?ws_open.web_socket_id, "forwarding WebSocket open to pubsub"); - let ws_topic = TunnelHttpWebSocketSubject::new( - runner_id, - &port_name, - ws_open.web_socket_id, - ) - .to_string(); + // Parse message + let msg = match versioned::RunnerMessage::deserialize_version( + &data, + protocol_version, + ) + .and_then(|x| x.into_latest()) + { + Result::Ok(x) => x, + Err(err) => { + tracing::error!(?err, "failed to deserialize message"); + continue; + } + }; + + // Determine reply to subject + let request_id = msg.request_id; + let reply_to = { + let active_requests = active_requests_clone.lock().await; + if let Some(req) = active_requests.get(&request_id) { + req.reply_to.clone() + } else { + tracing::warn!( + "no active request for tunnel message, may have timed out" + ); + continue; + } + }; - if let Err(e) = ups_clone - .request_with_timeout( - &ws_topic, - &data.to_vec(), - UPS_REQ_TIMEOUT, - ) - .await - { - let err_any: anyhow::Error = e.into(); - if is_tunnel_closed_error(&err_any) { - info!( - "tunnel closed while publishing WebSocket open; closing client websocket" - ); - // Close client websocket with reason - send_tunnel_closed_close_hyper( - &ws_write_ws_to_pubsub, - ) - .await; - break; - } else { - error!(?err_any, ?ws_open.web_socket_id, "failed to publish WebSocket open to pubsub"); - } - } else { - info!(?ws_open.web_socket_id, "successfully published WebSocket open to pubsub"); - } - } - MessageBody::ToClientWebSocketClose(ws_close) => { - info!(?ws_close.web_socket_id, "forwarding WebSocket close to pubsub"); - let ws_topic = TunnelHttpWebSocketSubject::new( - runner_id, - &port_name, - ws_close.web_socket_id, - ) - .to_string(); + // Remove active request entries when terminal + if is_message_kind_request_close(&msg.message_kind) { + let mut active_requests = active_requests_clone.lock().await; + active_requests.remove(&request_id); + } - if let Err(e) = ups_clone - .request_with_timeout( - &ws_topic, - &data.to_vec(), - UPS_REQ_TIMEOUT, - ) - .await - { - let err_any: anyhow::Error = e.into(); - if is_tunnel_closed_error(&err_any) { - info!( - "tunnel closed while publishing WebSocket close; closing client websocket" - ); - // Close client websocket with reason - send_tunnel_closed_close_hyper( - &ws_write_ws_to_pubsub, - ) - .await; - break; - } else { - error!(?err_any, ?ws_close.web_socket_id, "failed to publish WebSocket close to pubsub"); - } - } else { - info!(?ws_close.web_socket_id, "successfully published WebSocket close to pubsub"); - } - } - _ => { - // For other message types, we might not need to forward to pubsub - info!( - "Received non-response message from WebSocket, skipping pubsub forward" - ); - continue; - } + // Publish message to UPS + let message_serialized = + match versioned::PubSubMessage::latest(PubSubMessage { + request_id: msg.request_id, + message_id: msg.message_id, + reply_to: None, + message_kind: msg.message_kind, + }) + .serialize_with_embedded_version(PROTOCOL_VERSION) + { + Result::Ok(x) => x, + Err(err) => { + tracing::error!(?err, "failed to serialize tunnel to gateway"); + continue; } - } - ::std::result::Result::Err(e) => { - error!(?e, "failed to deserialize tunnel message from WebSocket"); + }; + match ups_clone + .publish(&reply_to, &message_serialized, PublishOpts::one()) + .await + { + Result::Ok(_) => {} + Err(err) => { + tracing::error!(?err, "error publishing ups message"); } } } - ::std::result::Result::Ok(WsMessage::Close(_)) => { - info!(?runner_id, "WebSocket closed"); + Result::Ok(WsMessage::Close(_)) => { + tracing::info!(?runner_key_clone, "WebSocket closed"); break; } - ::std::result::Result::Ok(_) => { + Result::Ok(_) => { // Ignore other message types } Err(e) => { - error!(?e, "WebSocket error"); + tracing::error!(?e, "WebSocket error"); break; } } } - info!("WebSocket to pubsub forwarding task ended"); - - // Clean up connection - connections_clone.write().await.remove(&connection_id); + tracing::info!("WebSocket to pubsub forwarding task ended"); }); // Wait for either task to complete tokio::select! { _ = pubsub_to_ws => { - info!("pubsub to WebSocket task completed"); + tracing::info!("pubsub to WebSocket task completed"); } _ = ws_to_pubsub => { - info!("WebSocket to pubsub task completed"); + tracing::info!("WebSocket to pubsub task completed"); } } // Clean up - connections.write().await.remove(&connection_id); - info!(?runner_id, "connection closed"); + tracing::info!(?runner_key, "connection closed"); std::result::Result::<(), (HyperWebsocket, anyhow::Error)>::Ok(()) } } -// Keep the old start function for backward compatibility in tests -pub async fn start(config: rivet_config::Config, pools: Pools) -> Result<()> { - let cache = rivet_cache::CacheInner::from_env(&config, pools.clone())?; - let ctx = StandaloneCtx::new( - gas::db::DatabaseKv::from_pools(pools.clone()).await?, - config.clone(), - pools.clone(), - cache, - "pegboard-tunnel", - Id::new_v1(config.dc_label()), - Id::new_v1(config.dc_label()), - )?; - - main_loop(ctx).await -} - -async fn main_loop(ctx: gas::prelude::StandaloneCtx) -> Result<()> { - let connections: Connections = Arc::new(RwLock::new(HashMap::new())); - - // Start WebSocket server - // Use pegboard config since pegboard_tunnel doesn't exist - let server_addr = SocketAddr::new( - ctx.config().pegboard().host(), - ctx.config().pegboard().port(), - ); - - info!(?server_addr, "starting pegboard-tunnel"); - - let listener = TcpListener::bind(&server_addr).await?; - - // Accept connections - loop { - let (tcp_stream, addr) = listener.accept().await?; - let connections = connections.clone(); - let ctx = ctx.clone(); - - tokio::spawn(async move { - if let Err(e) = handle_connection(ctx, tcp_stream, addr, connections).await { - error!(?e, ?addr, "connection handler error"); - } - }); +fn is_message_kind_request_close(kind: &MessageKind) -> bool { + match kind { + // HTTP terminal states + MessageKind::ToClientResponseStart(resp) => !resp.stream, + MessageKind::ToClientResponseChunk(chunk) => chunk.finish, + MessageKind::ToClientResponseAbort => true, + // WebSocket terminal states (either side closes) + MessageKind::ToClientWebSocketClose(_) => true, + MessageKind::ToServerWebSocketClose(_) => true, + _ => false, } } - -async fn handle_connection( - ctx: gas::prelude::StandaloneCtx, - tcp_stream: tokio::net::TcpStream, - addr: std::net::SocketAddr, - connections: Connections, -) -> Result<()> { - info!(?addr, "new connection"); - - // Parse WebSocket upgrade request - let ws_stream = accept_async(tcp_stream).await?; - - // For now, we'll expect the runner to send an initial message with its ID - // In production, this would be parsed from the URL path or headers - let runner_id = rivet_util::Id::nil(); // Placeholder - should be extracted from connection - let port_name = "default".to_string(); // Placeholder - should be extracted - - let connection_id = rivet_util::Id::nil(); - - // Subscribe to pubsub topic for this runner using raw pubsub client - let topic = TunnelHttpRunnerSubject::new(runner_id, &port_name).to_string(); - info!(%topic, ?runner_id, "subscribing to pubsub topic"); - - // Get UPS (UniversalPubSub) client - let ups = ctx.pools().ups()?; - let mut sub = ups.subscribe(&topic).await?; - - // Split WebSocket stream into read and write halves - let (ws_write, mut ws_read) = ws_stream.split(); - let ws_write = Arc::new(Mutex::new(ws_write)); - - // Store connection - let connection = Arc::new(RunnerConnection { - _runner_id: runner_id, - _port_name: port_name.clone(), - }); - - connections - .write() - .await - .insert(connection_id, connection.clone()); - - // Handle bidirectional message forwarding - let ws_write_clone = ws_write.clone(); - let connections_clone = connections.clone(); - let ups_clone = ups.clone(); - - // Task for forwarding pubsub -> WebSocket - let pubsub_to_ws = tokio::spawn(async move { - while let ::std::result::Result::Ok(NextOutput::Message(msg)) = sub.next().await { - // Ack message - match msg.reply(&[]).await { - Result::Ok(_) => {} - Err(err) => { - tracing::warn!(?err, "failed to ack gateway request response message") - } - }; - - // Forward raw message to WebSocket - let ws_msg = - tokio_tungstenite::tungstenite::Message::Binary(msg.payload.to_vec().into()); - { - let mut stream = ws_write_clone.lock().await; - if let Err(e) = stream.send(ws_msg).await { - error!(?e, "failed to send message to WebSocket"); - break; - } - } - } - }); - - // Task for forwarding WebSocket -> pubsub - let ws_write_ws_to_pubsub = ws_write.clone(); - let ws_to_pubsub = tokio::spawn(async move { - while let Some(msg) = ws_read.next().await { - match msg { - ::std::result::Result::Ok(tokio_tungstenite::tungstenite::Message::Binary( - data, - )) => { - // Parse the tunnel message to extract request_id - match versioned::TunnelMessage::deserialize(&data) { - ::std::result::Result::Ok(tunnel_msg) => { - // Handle different message types - match &tunnel_msg.body { - MessageBody::ToClientResponseStart(resp) => { - let response_topic = TunnelHttpResponseSubject::new( - runner_id, - &port_name, - resp.request_id, - ) - .to_string(); - - if let Err(e) = ups_clone - .request_with_timeout( - &response_topic, - &data.to_vec(), - UPS_REQ_TIMEOUT, - ) - .await - { - let err_any: anyhow::Error = e.into(); - if is_tunnel_closed_error(&err_any) { - info!( - "tunnel closed while publishing HTTP response; closing client websocket" - ); - // Close client websocket with reason - send_tunnel_closed_close_tokio(&ws_write_ws_to_pubsub) - .await; - break; - } else { - error!(?err_any, ?resp.request_id, "failed to publish HTTP response to pubsub"); - } - } - } - MessageBody::ToClientWebSocketMessage(ws_msg) => { - let ws_topic = TunnelHttpWebSocketSubject::new( - runner_id, - &port_name, - ws_msg.web_socket_id, - ) - .to_string(); - - if let Err(e) = ups_clone - .request_with_timeout( - &ws_topic, - &data.to_vec(), - UPS_REQ_TIMEOUT, - ) - .await - { - let err_any: anyhow::Error = e.into(); - if is_tunnel_closed_error(&err_any) { - info!( - "tunnel closed while publishing WebSocket message; closing client websocket" - ); - // Close client websocket with reason - send_tunnel_closed_close_tokio(&ws_write_ws_to_pubsub) - .await; - break; - } else { - error!(?err_any, ?ws_msg.web_socket_id, "failed to publish WebSocket message to pubsub"); - } - } - } - MessageBody::ToClientWebSocketOpen(ws_open) => { - let ws_topic = TunnelHttpWebSocketSubject::new( - runner_id, - &port_name, - ws_open.web_socket_id, - ) - .to_string(); - - if let Err(e) = ups_clone - .request_with_timeout( - &ws_topic, - &data.to_vec(), - UPS_REQ_TIMEOUT, - ) - .await - { - let err_any: anyhow::Error = e.into(); - if is_tunnel_closed_error(&err_any) { - info!( - "tunnel closed while publishing WebSocket open; closing client websocket" - ); - // Close client websocket with reason - send_tunnel_closed_close_tokio(&ws_write_ws_to_pubsub) - .await; - break; - } else { - error!(?err_any, ?ws_open.web_socket_id, "failed to publish WebSocket open to pubsub"); - } - } - } - MessageBody::ToClientWebSocketClose(ws_close) => { - let ws_topic = TunnelHttpWebSocketSubject::new( - runner_id, - &port_name, - ws_close.web_socket_id, - ) - .to_string(); - - if let Err(e) = ups_clone - .request_with_timeout( - &ws_topic, - &data.to_vec(), - UPS_REQ_TIMEOUT, - ) - .await - { - let err_any: anyhow::Error = e.into(); - if is_tunnel_closed_error(&err_any) { - info!( - "tunnel closed while publishing WebSocket close; closing client websocket" - ); - // Close client websocket with reason - send_tunnel_closed_close_tokio(&ws_write_ws_to_pubsub) - .await; - break; - } else { - error!(?err_any, ?ws_close.web_socket_id, "failed to publish WebSocket close to pubsub"); - } - } - } - _ => { - // For other message types, we might not need to forward to pubsub - info!( - "Received non-response message from WebSocket, skipping pubsub forward" - ); - continue; - } - } - } - ::std::result::Result::Err(e) => { - error!(?e, "failed to deserialize tunnel message from WebSocket"); - } - } - } - ::std::result::Result::Ok(tokio_tungstenite::tungstenite::Message::Close(_)) => { - info!(?runner_id, "WebSocket closed"); - break; - } - ::std::result::Result::Ok(_) => { - // Ignore other message types - } - Err(e) => { - error!(?e, "WebSocket error"); - break; - } - } - } - - // Clean up connection - connections_clone.write().await.remove(&connection_id); - }); - - // Wait for either task to complete - tokio::select! { - _ = pubsub_to_ws => { - info!("pubsub to WebSocket task completed"); - } - _ = ws_to_pubsub => { - info!("WebSocket to pubsub task completed"); - } - } - - // Clean up - connections.write().await.remove(&connection_id); - info!(?runner_id, "connection closed"); - - Ok(()) -} - -/// Determines if the tunnel is closed by if the UPS service is no longer responding. -fn is_tunnel_closed_error(err: &anyhow::Error) -> bool { - if let Some(err) = err - .chain() - .find_map(|x| x.downcast_ref::()) - && err.group() == "ups" - && err.code() == "request_timeout" - { - true - } else { - false - } -} - -// Helper: Build and send a standard tunnel-closed Close frame (hyper-tungstenite) -fn tunnel_closed_close_msg_hyper() -> WsMessage { - WsMessage::Close(Some(WsCloseFrame { - code: WsCloseCode::Error, - reason: WsUtf8Bytes::from_static("Tunnel closed"), - })) -} - -// Helper: Build and send a standard tunnel-closed Close frame (tokio-tungstenite) -fn tunnel_closed_close_msg_tokio() -> tokio_tungstenite::tungstenite::Message { - tokio_tungstenite::tungstenite::Message::Close(Some( - tokio_tungstenite::tungstenite::protocol::frame::CloseFrame { - code: tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode::Error, - reason: tokio_tungstenite::tungstenite::Utf8Bytes::from_static("Tunnel closed"), - }, - )) -} - -// Helper: Send the tunnel-closed Close frame on a hyper-tungstenite sink -async fn send_tunnel_closed_close_hyper(ws_write: &tokio::sync::Mutex) -where - S: futures::Sink + Unpin, -{ - let mut stream = ws_write.lock().await; - let _ = stream.send(tunnel_closed_close_msg_hyper()).await; -} - -// Helper: Send the tunnel-closed Close frame on a tokio-tungstenite sink -async fn send_tunnel_closed_close_tokio(ws_write: &tokio::sync::Mutex) -where - S: futures::Sink + Unpin, -{ - let mut stream = ws_write.lock().await; - let _ = stream.send(tunnel_closed_close_msg_tokio()).await; -} diff --git a/packages/core/pegboard-tunnel/tests/integration.rs b/packages/core/pegboard-tunnel/tests/integration.rs index 2c20bfe638..70051af60a 100644 --- a/packages/core/pegboard-tunnel/tests/integration.rs +++ b/packages/core/pegboard-tunnel/tests/integration.rs @@ -91,7 +91,9 @@ async fn test_pubsub_to_websocket( }; // Serialize the message - let serialized = versioned::TunnelMessage::serialize(versioned::TunnelMessage::V1(message))?; + let serialized = versioned::RunnerMessage::serialize_with_embedded_version( + versioned::RunnerMessage::V1(message), + )?; // Publish to pubsub topic using proper subject let topic = TunnelHttpRunnerSubject::new(&runner_id.to_string(), port_name).to_string(); @@ -105,7 +107,7 @@ async fn test_pubsub_to_websocket( match received? { WsMessage::Binary(data) => { // Deserialize and verify the message - let tunnel_msg = versioned::TunnelMessage::deserialize(&data)?; + let tunnel_msg = versioned::RunnerMessage::deserialize_with_embedded_version(&data)?; match tunnel_msg.body { MessageBody::ToServerRequestStart(req) => { assert_eq!(req.request_id, request_id); @@ -150,7 +152,9 @@ async fn test_websocket_to_pubsub( }; // Serialize and send via WebSocket - let serialized = versioned::TunnelMessage::serialize(versioned::TunnelMessage::V1(message))?; + let serialized = versioned::RunnerMessage::serialize_with_embedded_version( + versioned::RunnerMessage::V1(message), + )?; ws_stream.send(WsMessage::Binary(serialized.into())).await?; // Wait for message on pubsub @@ -159,7 +163,8 @@ async fn test_websocket_to_pubsub( match received { universalpubsub::pubsub::NextOutput::Message(msg) => { // Deserialize and verify the message - let tunnel_msg = versioned::TunnelMessage::deserialize(&msg.payload)?; + let tunnel_msg = + versioned::RunnerMessage::deserialize_with_embedded_version(&msg.payload)?; match tunnel_msg.body { MessageBody::ToClientResponseStart(resp) => { assert_eq!(resp.request_id, request_id); diff --git a/packages/infra/engine/tests/actors_lifecycle.rs b/packages/infra/engine/tests/actors_lifecycle.rs index 85b991eb4a..46aeff7e9d 100644 --- a/packages/infra/engine/tests/actors_lifecycle.rs +++ b/packages/infra/engine/tests/actors_lifecycle.rs @@ -4,7 +4,7 @@ use std::time::Duration; #[test] fn actor_lifecycle_single_dc() { - common::run(common::TestOpts::new(2), |ctx| async move { + common::run(common::TestOpts::new(1), |ctx| async move { actor_lifecycle_inner(&ctx, false).await; }); } @@ -29,10 +29,6 @@ async fn actor_lifecycle_inner(ctx: &common::TestCtx, multi_dc: bool) { let actor_id = common::create_actor(&namespace, target_dc.guard_port()).await; - // TODO: This is a race condition. we might need to move this after the guard ping since guard - // correctly waits for the actor to start. - tokio::time::sleep(Duration::from_millis(500)).await; - // Test ping via guard let ping_response = common::ping_actor_via_guard(ctx.leader_dc().guard_port(), &actor_id, "main").await; @@ -52,11 +48,10 @@ async fn actor_lifecycle_inner(ctx: &common::TestCtx, multi_dc: bool) { // Destroy tracing::info!("destroying actor"); - tokio::time::sleep(Duration::from_millis(500)).await; common::destroy_actor(&actor_id, &namespace, target_dc.guard_port()).await; - tokio::time::sleep(Duration::from_millis(500)).await; // Validate runner state + tokio::time::sleep(Duration::from_millis(500)).await; assert!( !runner.has_actor(&actor_id).await, "Runner should not have the actor after destroy" diff --git a/packages/services/pegboard/src/ops/runner/get_by_key.rs b/packages/services/pegboard/src/ops/runner/get_by_key.rs new file mode 100644 index 0000000000..5a6ea63bf1 --- /dev/null +++ b/packages/services/pegboard/src/ops/runner/get_by_key.rs @@ -0,0 +1,51 @@ +use anyhow::*; +use gas::prelude::*; +use rivet_types::runners::Runner; +use udb_util::{SERIALIZABLE, TxnExt}; +use universaldb as udb; + +use crate::keys; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Input { + pub namespace_id: Id, + pub name: String, + pub key: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Output { + pub runner: Option, +} + +#[operation] +pub async fn pegboard_runner_get_by_key(ctx: &OperationCtx, input: &Input) -> Result { + let dc_name = ctx.config().dc_name()?.to_string(); + + let runner = ctx + .udb()? + .run(|tx, _mc| { + let dc_name = dc_name.to_string(); + let input = input.clone(); + async move { + let txs = tx.subspace(keys::subspace()); + + // Look up runner by key + let runner_by_key_key = + keys::ns::RunnerByKeyKey::new(input.namespace_id, input.name, input.key); + + let runner_data = txs.read_opt(&runner_by_key_key, SERIALIZABLE).await?; + + if let Some(data) = runner_data { + // Get full runner details using the runner_id + let runner = super::get::get_inner(&dc_name, &tx, data.runner_id).await?; + std::result::Result::<_, udb::FdbBindingError>::Ok(runner) + } else { + std::result::Result::<_, udb::FdbBindingError>::Ok(None) + } + } + }) + .await?; + + Ok(Output { runner }) +} diff --git a/packages/services/pegboard/src/ops/runner/mod.rs b/packages/services/pegboard/src/ops/runner/mod.rs index cd56c69267..1885c60f23 100644 --- a/packages/services/pegboard/src/ops/runner/mod.rs +++ b/packages/services/pegboard/src/ops/runner/mod.rs @@ -1,4 +1,5 @@ pub mod get; +pub mod get_by_key; pub mod list_for_ns; pub mod list_names; pub mod update_alloc_idx; diff --git a/packages/services/pegboard/src/pubsub_subjects.rs b/packages/services/pegboard/src/pubsub_subjects.rs index 13d39b1704..ae9e9438b6 100644 --- a/packages/services/pegboard/src/pubsub_subjects.rs +++ b/packages/services/pegboard/src/pubsub_subjects.rs @@ -1,77 +1,41 @@ -use rivet_util::Id; +use gas::prelude::*; -pub struct TunnelHttpRunnerSubject<'a> { - runner_id: Id, +pub struct TunnelRunnerReceiverSubject<'a> { + runner_key: &'a str, port_name: &'a str, } -impl<'a> TunnelHttpRunnerSubject<'a> { - pub fn new(runner_id: Id, port_name: &'a str) -> Self { +impl<'a> TunnelRunnerReceiverSubject<'a> { + pub fn new(runner_key: &'a str, port_name: &'a str) -> Self { Self { - runner_id, + runner_key, port_name, } } } -impl std::fmt::Display for TunnelHttpRunnerSubject<'_> { +impl std::fmt::Display for TunnelRunnerReceiverSubject<'_> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( f, - "pegboard.tunnel.http.runner.{}.{}", - self.runner_id, self.port_name + "pegboard.tunnel.runner_receiver.{}.{}", + self.runner_key, self.port_name ) } } -pub struct TunnelHttpResponseSubject<'a> { - runner_id: Id, - port_name: &'a str, - request_id: u64, +pub struct TunnelGatewayReceiverSubject { + gateway_id: Uuid, } -impl<'a> TunnelHttpResponseSubject<'a> { - pub fn new(runner_id: Id, port_name: &'a str, request_id: u64) -> Self { - Self { - runner_id, - port_name, - request_id, - } +impl<'a> TunnelGatewayReceiverSubject { + pub fn new(gateway_id: Uuid) -> Self { + Self { gateway_id } } } -impl std::fmt::Display for TunnelHttpResponseSubject<'_> { +impl std::fmt::Display for TunnelGatewayReceiverSubject { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "pegboard.tunnel.http.request.{}.{}.{}", - self.runner_id, self.port_name, self.request_id - ) - } -} - -pub struct TunnelHttpWebSocketSubject<'a> { - runner_id: Id, - port_name: &'a str, - websocket_id: u64, -} - -impl<'a> TunnelHttpWebSocketSubject<'a> { - pub fn new(runner_id: Id, port_name: &'a str, websocket_id: u64) -> Self { - Self { - runner_id, - port_name, - websocket_id, - } - } -} - -impl std::fmt::Display for TunnelHttpWebSocketSubject<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "pegboard.tunnel.http.websocket.{}.{}.{}", - self.runner_id, self.port_name, self.websocket_id - ) + write!(f, "pegboard.gateway.receiver.{}", self.gateway_id) } } diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 62a80f8ad4..91185e953a 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -685,6 +685,9 @@ importers: '@rivetkit/engine-tunnel-protocol': specifier: workspace:* version: link:../tunnel-protocol + uuid: + specifier: ^12.0.0 + version: 12.0.0 ws: specifier: ^8.18.3 version: 8.18.3 @@ -7241,6 +7244,10 @@ packages: util-deprecate@1.0.2: resolution: {integrity: sha512-EPD5q1uXyFxJpCrLnCc1nHnq3gOa6DZBocAIiI2TaSCA7VCJ1UJDMagCzIkXNsUYfD1daK//LTEQ8xiIbrHtcw==} + uuid@12.0.0: + resolution: {integrity: sha512-USe1zesMYh4fjCA8ZH5+X5WIVD0J4V1Jksm1bFTVBX2F/cwSXt0RO5w/3UXbdLKmZX65MiWV+hwhSS8p6oBTGA==} + hasBin: true + validator@13.15.15: resolution: {integrity: sha512-BgWVbCI72aIQy937xbawcs+hrVaN/CZ2UwutgaJ36hGqRrLNM+f5LUT/YPRbo8IV/ASeFzXszezV+y2+rq3l8A==} engines: {node: '>= 0.10'} @@ -14748,6 +14755,8 @@ snapshots: util-deprecate@1.0.2: {} + uuid@12.0.0: {} + validator@13.15.15: {} vary@1.1.2: {} diff --git a/sdks/rust/runner-protocol/src/protocol.rs b/sdks/rust/runner-protocol/src/protocol.rs index 4ca049f634..a037941e2a 100644 --- a/sdks/rust/runner-protocol/src/protocol.rs +++ b/sdks/rust/runner-protocol/src/protocol.rs @@ -20,9 +20,7 @@ pub enum ToClient { #[serde(rename_all = "snake_case")] pub enum ToServer { Init { - runner_id: Option, name: String, - key: String, version: u32, total_slots: u32, diff --git a/sdks/rust/runner-protocol/src/versioned.rs b/sdks/rust/runner-protocol/src/versioned.rs index 2b61c2c2dc..34c54168b6 100644 --- a/sdks/rust/runner-protocol/src/versioned.rs +++ b/sdks/rust/runner-protocol/src/versioned.rs @@ -285,9 +285,7 @@ impl TryFrom for protocol::ToServer { fn try_from(value: v1::ToServer) -> Result { match value { v1::ToServer::ToServerInit(init) => Ok(protocol::ToServer::Init { - runner_id: init.runner_id.map(|id| util::Id::parse(&id)).transpose()?, name: init.name, - key: init.key, version: init.version, total_slots: init.total_slots, addresses_http: init diff --git a/sdks/rust/tunnel-protocol/build.rs b/sdks/rust/tunnel-protocol/build.rs index 453be2c763..f43ed92983 100644 --- a/sdks/rust/tunnel-protocol/build.rs +++ b/sdks/rust/tunnel-protocol/build.rs @@ -1,4 +1,8 @@ -use std::{env, fs, path::Path}; +use std::{ + env, fs, + path::{Path, PathBuf}, + process::Command, +}; use indoc::formatdoc; @@ -52,6 +56,78 @@ mod rust { } } +mod typescript { + use super::*; + + pub fn generate_sdk(schema_dir: &Path) { + let manifest_dir = env::var("CARGO_MANIFEST_DIR").unwrap(); + let workspace_root = Path::new(&manifest_dir) + .parent() + .and_then(|p| p.parent()) + .and_then(|p| p.parent()) + .expect("Failed to find workspace root"); + + let sdk_dir = workspace_root + .join("sdks") + .join("typescript") + .join("tunnel-protocol"); + let src_dir = sdk_dir.join("src"); + + let highest_version_path = super::find_highest_version(schema_dir); + + let _ = fs::remove_dir_all(&src_dir); + if let Err(e) = fs::create_dir_all(&src_dir) { + panic!("Failed to create SDK directory: {}", e); + } + + let output = + Command::new(workspace_root.join("node_modules/@bare-ts/tools/dist/bin/cli.js")) + .arg("compile") + .arg("--generator") + .arg("ts") + .arg(highest_version_path) + .arg("-o") + .arg(src_dir.join("index.ts")) + .output() + .expect("Failed to execute bare compiler for TypeScript"); + + if !output.status.success() { + panic!( + "BARE TypeScript generation failed: {}", + String::from_utf8_lossy(&output.stderr), + ); + } + } +} + +fn find_highest_version(schema_dir: &Path) -> PathBuf { + let mut highest_version = 0; + let mut highest_version_path = PathBuf::new(); + + for entry in fs::read_dir(schema_dir).unwrap().flatten() { + if !entry.path().is_dir() { + let path = entry.path(); + let bare_name = path + .file_name() + .unwrap() + .to_str() + .unwrap() + .split_once('.') + .unwrap() + .0; + + if let Ok(version) = bare_name[1..].parse::() { + if version > highest_version { + highest_version = version; + highest_version_path = path; + } + } + } + } + + highest_version_path +} + fn main() { let manifest_dir = env::var("CARGO_MANIFEST_DIR").unwrap(); let workspace_root = Path::new(&manifest_dir) @@ -68,4 +144,15 @@ fn main() { println!("cargo:rerun-if-changed={}", schema_dir.display()); rust::generate_sdk(&schema_dir); + + // Check if cli.js exists before attempting TypeScript generation + let cli_js_path = workspace_root.join("node_modules/@bare-ts/tools/dist/bin/cli.js"); + if cli_js_path.exists() { + typescript::generate_sdk(&schema_dir); + } else { + println!( + "cargo:warning=TypeScript SDK generation skipped: cli.js not found at {}. Run `pnpm install` to install.", + cli_js_path.display() + ); + } } diff --git a/sdks/rust/tunnel-protocol/src/versioned.rs b/sdks/rust/tunnel-protocol/src/versioned.rs index d37e5d5447..f5208f27fb 100644 --- a/sdks/rust/tunnel-protocol/src/versioned.rs +++ b/sdks/rust/tunnel-protocol/src/versioned.rs @@ -3,20 +3,20 @@ use versioned_data_util::OwnedVersionedData; use crate::{PROTOCOL_VERSION, generated::v1}; -pub enum TunnelMessage { - V1(v1::TunnelMessage), +pub enum RunnerMessage { + V1(v1::RunnerMessage), } -impl OwnedVersionedData for TunnelMessage { - type Latest = v1::TunnelMessage; +impl OwnedVersionedData for RunnerMessage { + type Latest = v1::RunnerMessage; - fn latest(latest: v1::TunnelMessage) -> Self { - TunnelMessage::V1(latest) + fn latest(latest: v1::RunnerMessage) -> Self { + RunnerMessage::V1(latest) } fn into_latest(self) -> Result { #[allow(irrefutable_let_patterns)] - if let TunnelMessage::V1(data) = self { + if let RunnerMessage::V1(data) = self { Ok(data) } else { bail!("version not latest"); @@ -25,20 +25,64 @@ impl OwnedVersionedData for TunnelMessage { fn deserialize_version(payload: &[u8], version: u16) -> Result { match version { - 1 => Ok(TunnelMessage::V1(serde_bare::from_slice(payload)?)), + 1 => Ok(RunnerMessage::V1(serde_bare::from_slice(payload)?)), _ => bail!("invalid version: {version}"), } } fn serialize_version(self, _version: u16) -> Result> { match self { - TunnelMessage::V1(data) => serde_bare::to_vec(&data).map_err(Into::into), + RunnerMessage::V1(data) => serde_bare::to_vec(&data).map_err(Into::into), } } } -impl TunnelMessage { - pub fn deserialize(buf: &[u8]) -> Result { +impl RunnerMessage { + pub fn deserialize(buf: &[u8]) -> Result { + ::deserialize(buf, PROTOCOL_VERSION) + } + + pub fn serialize(self) -> Result> { + ::serialize(self, PROTOCOL_VERSION) + } +} + +pub enum PubSubMessage { + V1(v1::PubSubMessage), +} + +impl OwnedVersionedData for PubSubMessage { + type Latest = v1::PubSubMessage; + + fn latest(latest: v1::PubSubMessage) -> Self { + PubSubMessage::V1(latest) + } + + fn into_latest(self) -> Result { + #[allow(irrefutable_let_patterns)] + if let PubSubMessage::V1(data) = self { + Ok(data) + } else { + bail!("version not latest"); + } + } + + fn deserialize_version(payload: &[u8], version: u16) -> Result { + match version { + 1 => Ok(PubSubMessage::V1(serde_bare::from_slice(payload)?)), + _ => bail!("invalid version: {version}"), + } + } + + fn serialize_version(self, _version: u16) -> Result> { + match self { + PubSubMessage::V1(data) => serde_bare::to_vec(&data).map_err(Into::into), + } + } +} + +impl PubSubMessage { + pub fn deserialize(buf: &[u8]) -> Result { ::deserialize(buf, PROTOCOL_VERSION) } diff --git a/sdks/schemas/runner-protocol/v1.bare b/sdks/schemas/runner-protocol/v1.bare index ed99e59afb..5db5afcea4 100644 --- a/sdks/schemas/runner-protocol/v1.bare +++ b/sdks/schemas/runner-protocol/v1.bare @@ -133,9 +133,7 @@ type CommandWrapper struct { } type ToServerInit struct { - runnerId: optional name: str - key: str version: u32 totalSlots: u32 addressesHttp: optional> diff --git a/sdks/schemas/tunnel-protocol/v1.bare b/sdks/schemas/tunnel-protocol/v1.bare index 5405587c1b..f9e0e9c63f 100644 --- a/sdks/schemas/tunnel-protocol/v1.bare +++ b/sdks/schemas/tunnel-protocol/v1.bare @@ -1,15 +1,13 @@ -type RequestId u64 -type WebSocketId u64 +type RequestId data[16] # UUIDv4 +type MessageId data[16] # UUIDv4 type Id str -type StreamFinishReason enum { - COMPLETE - ABORT -} -# MARK: HTTP Request Forwarding +# MARK: Ack +type Ack void + +# MARK: HTTP type ToServerRequestStart struct { - requestId: RequestId actorId: Id method: str path: str @@ -19,17 +17,13 @@ type ToServerRequestStart struct { } type ToServerRequestChunk struct { - requestId: RequestId body: data + finish: bool } -type ToServerRequestFinish struct { - requestId: RequestId - reason: StreamFinishReason -} +type ToServerRequestAbort void type ToClientResponseStart struct { - requestId: RequestId status: u16 headers: map body: optional @@ -37,60 +31,52 @@ type ToClientResponseStart struct { } type ToClientResponseChunk struct { - requestId: RequestId body: data + finish: bool } -type ToClientResponseFinish struct { - requestId: RequestId - reason: StreamFinishReason -} +type ToClientResponseAbort void -# MARK: WebSocket Forwarding +# MARK: WebSocket type ToServerWebSocketOpen struct { actorId: Id - webSocketId: WebSocketId path: str headers: map } type ToServerWebSocketMessage struct { - webSocketId: WebSocketId data: data binary: bool } type ToServerWebSocketClose struct { - webSocketId: WebSocketId code: optional reason: optional } -type ToClientWebSocketOpen struct { - webSocketId: WebSocketId -} +type ToClientWebSocketOpen void type ToClientWebSocketMessage struct { - webSocketId: WebSocketId data: data binary: bool } type ToClientWebSocketClose struct { - webSocketId: WebSocketId code: optional reason: optional } # MARK: Message -type MessageBody union { +type MessageKind union { + Ack | + # HTTP ToServerRequestStart | ToServerRequestChunk | - ToServerRequestFinish | + ToServerRequestAbort | ToClientResponseStart | ToClientResponseChunk | - ToClientResponseFinish | + ToClientResponseAbort | # WebSocket ToServerWebSocketOpen | @@ -101,7 +87,19 @@ type MessageBody union { ToClientWebSocketClose } -# Main tunnel message -type TunnelMessage struct { - body: MessageBody +# MARK: Message sent over tunnel WebSocket +type RunnerMessage struct { + requestId: RequestId + messageId: MessageId + messageKind: MessageKind +} + +# MARK: Message sent over UPS +type PubSubMessage struct { + requestId: RequestId + messageId: MessageId + # Subject to send replies to. Only sent when opening a new request from gateway -> runner. + replyTo: optional + messageKind: MessageKind } + diff --git a/sdks/typescript/runner-protocol/src/index.ts b/sdks/typescript/runner-protocol/src/index.ts index 45a947c051..96c00212c6 100644 --- a/sdks/typescript/runner-protocol/src/index.ts +++ b/sdks/typescript/runner-protocol/src/index.ts @@ -590,18 +590,7 @@ export function writeCommandWrapper(bc: bare.ByteCursor, x: CommandWrapper): voi writeCommand(bc, x.inner) } -function read3(bc: bare.ByteCursor): Id | null { - return bare.readBool(bc) ? readId(bc) : null -} - -function write3(bc: bare.ByteCursor, x: Id | null): void { - bare.writeBool(bc, x != null) - if (x != null) { - writeId(bc, x) - } -} - -function read4(bc: bare.ByteCursor): ReadonlyMap { +function read3(bc: bare.ByteCursor): ReadonlyMap { const len = bare.readUintSafe(bc) const result = new Map() for (let i = 0; i < len; i++) { @@ -616,7 +605,7 @@ function read4(bc: bare.ByteCursor): ReadonlyMap { return result } -function write4(bc: bare.ByteCursor, x: ReadonlyMap): void { +function write3(bc: bare.ByteCursor, x: ReadonlyMap): void { bare.writeUintSafe(bc, x.size) for (const kv of x) { bare.writeString(bc, kv[0]) @@ -624,18 +613,18 @@ function write4(bc: bare.ByteCursor, x: ReadonlyMap): } } -function read5(bc: bare.ByteCursor): ReadonlyMap | null { - return bare.readBool(bc) ? read4(bc) : null +function read4(bc: bare.ByteCursor): ReadonlyMap | null { + return bare.readBool(bc) ? read3(bc) : null } -function write5(bc: bare.ByteCursor, x: ReadonlyMap | null): void { +function write4(bc: bare.ByteCursor, x: ReadonlyMap | null): void { bare.writeBool(bc, x != null) if (x != null) { - write4(bc, x) + write3(bc, x) } } -function read6(bc: bare.ByteCursor): ReadonlyMap { +function read5(bc: bare.ByteCursor): ReadonlyMap { const len = bare.readUintSafe(bc) const result = new Map() for (let i = 0; i < len; i++) { @@ -650,7 +639,7 @@ function read6(bc: bare.ByteCursor): ReadonlyMap { return result } -function write6(bc: bare.ByteCursor, x: ReadonlyMap): void { +function write5(bc: bare.ByteCursor, x: ReadonlyMap): void { bare.writeUintSafe(bc, x.size) for (const kv of x) { bare.writeString(bc, kv[0]) @@ -658,18 +647,18 @@ function write6(bc: bare.ByteCursor, x: ReadonlyMap): } } -function read7(bc: bare.ByteCursor): ReadonlyMap | null { - return bare.readBool(bc) ? read6(bc) : null +function read6(bc: bare.ByteCursor): ReadonlyMap | null { + return bare.readBool(bc) ? read5(bc) : null } -function write7(bc: bare.ByteCursor, x: ReadonlyMap | null): void { +function write6(bc: bare.ByteCursor, x: ReadonlyMap | null): void { bare.writeBool(bc, x != null) if (x != null) { - write6(bc, x) + write5(bc, x) } } -function read8(bc: bare.ByteCursor): ReadonlyMap { +function read7(bc: bare.ByteCursor): ReadonlyMap { const len = bare.readUintSafe(bc) const result = new Map() for (let i = 0; i < len; i++) { @@ -684,7 +673,7 @@ function read8(bc: bare.ByteCursor): ReadonlyMap { return result } -function write8(bc: bare.ByteCursor, x: ReadonlyMap): void { +function write7(bc: bare.ByteCursor, x: ReadonlyMap): void { bare.writeUintSafe(bc, x.size) for (const kv of x) { bare.writeString(bc, kv[0]) @@ -692,18 +681,18 @@ function write8(bc: bare.ByteCursor, x: ReadonlyMap): } } -function read9(bc: bare.ByteCursor): ReadonlyMap | null { - return bare.readBool(bc) ? read8(bc) : null +function read8(bc: bare.ByteCursor): ReadonlyMap | null { + return bare.readBool(bc) ? read7(bc) : null } -function write9(bc: bare.ByteCursor, x: ReadonlyMap | null): void { +function write8(bc: bare.ByteCursor, x: ReadonlyMap | null): void { bare.writeBool(bc, x != null) if (x != null) { - write8(bc, x) + write7(bc, x) } } -function read10(bc: bare.ByteCursor): ReadonlyMap { +function read9(bc: bare.ByteCursor): ReadonlyMap { const len = bare.readUintSafe(bc) const result = new Map() for (let i = 0; i < len; i++) { @@ -718,7 +707,7 @@ function read10(bc: bare.ByteCursor): ReadonlyMap { return result } -function write10(bc: bare.ByteCursor, x: ReadonlyMap): void { +function write9(bc: bare.ByteCursor, x: ReadonlyMap): void { bare.writeUintSafe(bc, x.size) for (const kv of x) { bare.writeString(bc, kv[0]) @@ -726,22 +715,22 @@ function write10(bc: bare.ByteCursor, x: ReadonlyMap): void { } } -function read11(bc: bare.ByteCursor): ReadonlyMap | null { - return bare.readBool(bc) ? read10(bc) : null +function read10(bc: bare.ByteCursor): ReadonlyMap | null { + return bare.readBool(bc) ? read9(bc) : null } -function write11(bc: bare.ByteCursor, x: ReadonlyMap | null): void { +function write10(bc: bare.ByteCursor, x: ReadonlyMap | null): void { bare.writeBool(bc, x != null) if (x != null) { - write10(bc, x) + write9(bc, x) } } -function read12(bc: bare.ByteCursor): Json | null { +function read11(bc: bare.ByteCursor): Json | null { return bare.readBool(bc) ? readJson(bc) : null } -function write12(bc: bare.ByteCursor, x: Json | null): void { +function write11(bc: bare.ByteCursor, x: Json | null): void { bare.writeBool(bc, x != null) if (x != null) { writeJson(bc, x) @@ -749,9 +738,7 @@ function write12(bc: bare.ByteCursor, x: Json | null): void { } export type ToServerInit = { - readonly runnerId: Id | null readonly name: string - readonly key: string readonly version: u32 readonly totalSlots: u32 readonly addressesHttp: ReadonlyMap | null @@ -764,32 +751,28 @@ export type ToServerInit = { export function readToServerInit(bc: bare.ByteCursor): ToServerInit { return { - runnerId: read3(bc), name: bare.readString(bc), - key: bare.readString(bc), version: bare.readU32(bc), totalSlots: bare.readU32(bc), - addressesHttp: read5(bc), - addressesTcp: read7(bc), - addressesUdp: read9(bc), + addressesHttp: read4(bc), + addressesTcp: read6(bc), + addressesUdp: read8(bc), lastCommandIdx: read1(bc), - prepopulateActorNames: read11(bc), - metadata: read12(bc), + prepopulateActorNames: read10(bc), + metadata: read11(bc), } } export function writeToServerInit(bc: bare.ByteCursor, x: ToServerInit): void { - write3(bc, x.runnerId) bare.writeString(bc, x.name) - bare.writeString(bc, x.key) bare.writeU32(bc, x.version) bare.writeU32(bc, x.totalSlots) - write5(bc, x.addressesHttp) - write7(bc, x.addressesTcp) - write9(bc, x.addressesUdp) + write4(bc, x.addressesHttp) + write6(bc, x.addressesTcp) + write8(bc, x.addressesUdp) write1(bc, x.lastCommandIdx) - write11(bc, x.prepopulateActorNames) - write12(bc, x.metadata) + write10(bc, x.prepopulateActorNames) + write11(bc, x.metadata) } export type ToServerEvents = readonly EventWrapper[] @@ -843,7 +826,7 @@ export function writeToServerPing(bc: bare.ByteCursor, x: ToServerPing): void { bare.writeI64(bc, x.ts) } -function read13(bc: bare.ByteCursor): readonly KvKey[] { +function read12(bc: bare.ByteCursor): readonly KvKey[] { const len = bare.readUintSafe(bc) if (len === 0) { return [] @@ -855,7 +838,7 @@ function read13(bc: bare.ByteCursor): readonly KvKey[] { return result } -function write13(bc: bare.ByteCursor, x: readonly KvKey[]): void { +function write12(bc: bare.ByteCursor, x: readonly KvKey[]): void { bare.writeUintSafe(bc, x.length) for (let i = 0; i < x.length; i++) { writeKvKey(bc, x[i]) @@ -868,30 +851,30 @@ export type KvGetRequest = { export function readKvGetRequest(bc: bare.ByteCursor): KvGetRequest { return { - keys: read13(bc), + keys: read12(bc), } } export function writeKvGetRequest(bc: bare.ByteCursor, x: KvGetRequest): void { - write13(bc, x.keys) + write12(bc, x.keys) } -function read14(bc: bare.ByteCursor): boolean | null { +function read13(bc: bare.ByteCursor): boolean | null { return bare.readBool(bc) ? bare.readBool(bc) : null } -function write14(bc: bare.ByteCursor, x: boolean | null): void { +function write13(bc: bare.ByteCursor, x: boolean | null): void { bare.writeBool(bc, x != null) if (x != null) { bare.writeBool(bc, x) } } -function read15(bc: bare.ByteCursor): u64 | null { +function read14(bc: bare.ByteCursor): u64 | null { return bare.readBool(bc) ? bare.readU64(bc) : null } -function write15(bc: bare.ByteCursor, x: u64 | null): void { +function write14(bc: bare.ByteCursor, x: u64 | null): void { bare.writeBool(bc, x != null) if (x != null) { bare.writeU64(bc, x) @@ -907,18 +890,18 @@ export type KvListRequest = { export function readKvListRequest(bc: bare.ByteCursor): KvListRequest { return { query: readKvListQuery(bc), - reverse: read14(bc), - limit: read15(bc), + reverse: read13(bc), + limit: read14(bc), } } export function writeKvListRequest(bc: bare.ByteCursor, x: KvListRequest): void { writeKvListQuery(bc, x.query) - write14(bc, x.reverse) - write15(bc, x.limit) + write13(bc, x.reverse) + write14(bc, x.limit) } -function read16(bc: bare.ByteCursor): readonly KvValue[] { +function read15(bc: bare.ByteCursor): readonly KvValue[] { const len = bare.readUintSafe(bc) if (len === 0) { return [] @@ -930,7 +913,7 @@ function read16(bc: bare.ByteCursor): readonly KvValue[] { return result } -function write16(bc: bare.ByteCursor, x: readonly KvValue[]): void { +function write15(bc: bare.ByteCursor, x: readonly KvValue[]): void { bare.writeUintSafe(bc, x.length) for (let i = 0; i < x.length; i++) { writeKvValue(bc, x[i]) @@ -944,14 +927,14 @@ export type KvPutRequest = { export function readKvPutRequest(bc: bare.ByteCursor): KvPutRequest { return { - keys: read13(bc), - values: read16(bc), + keys: read12(bc), + values: read15(bc), } } export function writeKvPutRequest(bc: bare.ByteCursor, x: KvPutRequest): void { - write13(bc, x.keys) - write16(bc, x.values) + write12(bc, x.keys) + write15(bc, x.values) } export type KvDeleteRequest = { @@ -960,12 +943,12 @@ export type KvDeleteRequest = { export function readKvDeleteRequest(bc: bare.ByteCursor): KvDeleteRequest { return { - keys: read13(bc), + keys: read12(bc), } } export function writeKvDeleteRequest(bc: bare.ByteCursor, x: KvDeleteRequest): void { - write13(bc, x.keys) + write12(bc, x.keys) } export type KvDropRequest = null @@ -1214,7 +1197,7 @@ export function writeKvErrorResponse(bc: bare.ByteCursor, x: KvErrorResponse): v bare.writeString(bc, x.message) } -function read17(bc: bare.ByteCursor): readonly KvMetadata[] { +function read16(bc: bare.ByteCursor): readonly KvMetadata[] { const len = bare.readUintSafe(bc) if (len === 0) { return [] @@ -1226,7 +1209,7 @@ function read17(bc: bare.ByteCursor): readonly KvMetadata[] { return result } -function write17(bc: bare.ByteCursor, x: readonly KvMetadata[]): void { +function write16(bc: bare.ByteCursor, x: readonly KvMetadata[]): void { bare.writeUintSafe(bc, x.length) for (let i = 0; i < x.length; i++) { writeKvMetadata(bc, x[i]) @@ -1241,16 +1224,16 @@ export type KvGetResponse = { export function readKvGetResponse(bc: bare.ByteCursor): KvGetResponse { return { - keys: read13(bc), - values: read16(bc), - metadata: read17(bc), + keys: read12(bc), + values: read15(bc), + metadata: read16(bc), } } export function writeKvGetResponse(bc: bare.ByteCursor, x: KvGetResponse): void { - write13(bc, x.keys) - write16(bc, x.values) - write17(bc, x.metadata) + write12(bc, x.keys) + write15(bc, x.values) + write16(bc, x.metadata) } export type KvListResponse = { @@ -1261,16 +1244,16 @@ export type KvListResponse = { export function readKvListResponse(bc: bare.ByteCursor): KvListResponse { return { - keys: read13(bc), - values: read16(bc), - metadata: read17(bc), + keys: read12(bc), + values: read15(bc), + metadata: read16(bc), } } export function writeKvListResponse(bc: bare.ByteCursor, x: KvListResponse): void { - write13(bc, x.keys) - write16(bc, x.values) - write17(bc, x.metadata) + write12(bc, x.keys) + write15(bc, x.values) + write16(bc, x.metadata) } export type KvPutResponse = null diff --git a/sdks/typescript/runner/package.json b/sdks/typescript/runner/package.json index 0b7bfdf4c3..984f065374 100644 --- a/sdks/typescript/runner/package.json +++ b/sdks/typescript/runner/package.json @@ -22,6 +22,7 @@ "dependencies": { "@rivetkit/engine-runner-protocol": "workspace:*", "@rivetkit/engine-tunnel-protocol": "workspace:*", + "uuid": "^12.0.0", "ws": "^8.18.3" }, "devDependencies": { diff --git a/sdks/typescript/runner/src/mod.ts b/sdks/typescript/runner/src/mod.ts index 92be4cda50..7d26a37925 100644 --- a/sdks/typescript/runner/src/mod.ts +++ b/sdks/typescript/runner/src/mod.ts @@ -1,16 +1,18 @@ import WebSocket from "ws"; import { importWebSocket } from "./websocket.js"; import * as protocol from "@rivetkit/engine-runner-protocol"; -import { unreachable, calculateBackoff } from "./utils.js"; -import { Tunnel } from "./tunnel.js"; -import { WebSocketTunnelAdapter } from "./websocket-tunnel-adapter.js"; +import { unreachable, calculateBackoff } from "./utils"; +import { Tunnel } from "./tunnel"; +import { WebSocketTunnelAdapter } from "./websocket-tunnel-adapter"; const KV_EXPIRE: number = 30_000; -interface ActorInstance { +export interface ActorInstance { actorId: string; generation: number; config: ActorConfig; + requests: Set; // Track active request IDs + webSockets: Set; // Track active WebSocket IDs } export interface ActorConfig { @@ -60,6 +62,11 @@ interface KvRequestEntry { export class Runner { #config: RunnerConfig; + + get config(): RunnerConfig { + return this.#config; + } + #actors: Map = new Map(); #actorWebSockets: Map> = new Map(); @@ -110,7 +117,7 @@ export class Runner { // MARK: Manage actors sleepActor(actorId: string, generation?: number) { - const actor = this.#getActor(actorId, generation); + const actor = this.getActor(actorId, generation); if (!actor) return; // Keep the actor instance in memory during sleep @@ -126,7 +133,7 @@ export class Runner { // Unregister actor from tunnel if (this.#tunnel) { - this.#tunnel.unregisterActor(actorId); + this.#tunnel.unregisterActor(actor); } this.#sendActorStateUpdate(actorId, actor.generation, "stopped"); @@ -147,7 +154,7 @@ export class Runner { } } - #getActor(actorId: string, generation?: number): ActorInstance | undefined { + getActor(actorId: string, generation?: number): ActorInstance | undefined { const actor = this.#actors.get(actorId); if (!actor) { console.error(`Actor ${actorId} not found`); @@ -363,10 +370,10 @@ export class Runner { const wsEndpoint = endpoint .replace("http://", "ws://") .replace("https://", "wss://"); - return `${wsEndpoint}/v1?namespace=${encodeURIComponent(this.#config.namespace)}`; + return `${wsEndpoint}?protocol_version=1&namespace=${encodeURIComponent(this.#config.namespace)}&runner_key=${encodeURIComponent(this.#config.runnerKey)}`; } - get pegboardRelayUrl() { + get pegboardTunnelUrl() { const endpoint = this.#config.pegboardRelayEndpoint || this.#config.pegboardEndpoint || @@ -374,26 +381,19 @@ export class Runner { const wsEndpoint = endpoint .replace("http://", "ws://") .replace("https://", "wss://"); - // Include runner ID if we have it - if (this.runnerId) { - return `${wsEndpoint}/tunnel?namespace=${encodeURIComponent(this.#config.namespace)}&runner_id=${this.runnerId}`; - } - return `${wsEndpoint}/tunnel?namespace=${encodeURIComponent(this.#config.namespace)}`; + return `${wsEndpoint}?protocol_version=1&namespace=${encodeURIComponent(this.#config.namespace)}&runner_key=${this.#config.runnerKey}`; } async #openTunnelAndWait(): Promise { return new Promise((resolve, reject) => { - const url = this.pegboardRelayUrl; + const url = this.pegboardTunnelUrl; //console.log("[RUNNER] Opening tunnel to:", url); //console.log("[RUNNER] Current runner ID:", this.runnerId || "none"); //console.log("[RUNNER] Active actors count:", this.#actors.size); let connected = false; - this.#tunnel = new Tunnel(url); - this.#tunnel.setCallbacks({ - fetch: this.#config.fetch, - websocket: this.#config.websocket, + this.#tunnel = new Tunnel(this, url, { onConnected: () => { if (!connected) { connected = true; @@ -410,35 +410,9 @@ export class Runner { }, }); this.#tunnel.start(); - - // Re-register all active actors with the new tunnel - for (const actorId of this.#actors.keys()) { - //console.log("[RUNNER] Re-registering actor with tunnel:", actorId); - this.#tunnel.registerActor(actorId); - } }); } - #openTunnel() { - const url = this.pegboardRelayUrl; - //console.log("[RUNNER] Opening tunnel to:", url); - //console.log("[RUNNER] Current runner ID:", this.runnerId || "none"); - //console.log("[RUNNER] Active actors count:", this.#actors.size); - - this.#tunnel = new Tunnel(url); - this.#tunnel.setCallbacks({ - fetch: this.#config.fetch, - websocket: this.#config.websocket, - }); - this.#tunnel.start(); - - // Re-register all active actors with the new tunnel - for (const actorId of this.#actors.keys()) { - //console.log("[RUNNER] Re-registering actor with tunnel:", actorId); - this.#tunnel.registerActor(actorId); - } - } - // MARK: Runner protocol async #openPegboardWebSocket() { const WS = await importWebSocket(); @@ -469,9 +443,7 @@ export class Runner { // Send init message const init: protocol.ToServerInit = { - runnerId: this.runnerId || null, name: this.#config.runnerName, - key: this.#config.runnerKey, version: this.#config.version, totalSlots: this.#config.totalSlots, addressesHttp: new Map(), // No addresses needed with tunnel @@ -560,18 +532,6 @@ export class Runner { // runnerLostThreshold: this.#runnerLostThreshold, //}); - // Only reopen tunnel if we didn't have a runner ID before - // This happens on reconnection after losing connection - if (!hadRunnerId && this.runnerId) { - // Reopen tunnel with runner ID - //console.log("[RUNNER] Received runner ID, reopening tunnel"); - if (this.#tunnel) { - //console.log("[RUNNER] Shutting down existing tunnel"); - this.#tunnel.shutdown(); - } - this.#openTunnel(); - } - // Resend events that haven't been acknowledged this.#resendUnacknowledgedEvents(init.lastEventIdx); @@ -664,21 +624,12 @@ export class Runner { actorId, generation, config: actorConfig, + requests: new Set(), + webSockets: new Set(), }; this.#actors.set(actorId, instance); - // Register actor with tunnel - if (this.#tunnel) { - //console.log("[RUNNER] Registering new actor with tunnel:", actorId); - this.#tunnel.registerActor(actorId); - } else { - console.error( - "[RUNNER] WARNING: No tunnel available to register actor:", - actorId, - ); - } - this.#sendActorStateUpdate(actorId, generation, "running"); // TODO: Add timeout to onActorStart diff --git a/sdks/typescript/runner/src/tunnel.ts b/sdks/typescript/runner/src/tunnel.ts index 7c868c44ce..4a00c878e2 100644 --- a/sdks/typescript/runner/src/tunnel.ts +++ b/sdks/typescript/runner/src/tunnel.ts @@ -1,162 +1,269 @@ import WebSocket from "ws"; import * as tunnel from "@rivetkit/engine-tunnel-protocol"; -import { WebSocketTunnelAdapter } from "./websocket-tunnel-adapter.js"; -import { calculateBackoff } from "./utils.js"; +import { WebSocketTunnelAdapter } from "./websocket-tunnel-adapter"; +import { calculateBackoff } from "./utils"; +import type { Runner, ActorInstance } from "./mod"; +import { v4 as uuidv4 } from "uuid"; + +const GC_INTERVAL = 60000; // 60 seconds +const MESSAGE_ACK_TIMEOUT = 5000; // 5 seconds + +interface PendingRequest { + resolve: (response: Response) => void; + reject: (error: Error) => void; + streamController?: ReadableStreamDefaultController; + actorId?: string; +} + +interface TunnelCallbacks { + onConnected(): void; + onDisconnected(): void; +} + +interface PendingMessage { + sentAt: number; + requestIdStr: string; +} export class Tunnel { #pegboardTunnelUrl: string; - #ws?: WebSocket; - #pendingRequests: Map void; - reject: (error: Error) => void; - streamController?: ReadableStreamDefaultController; - actorId?: string; - }> = new Map(); - #webSockets: Map = new Map(); + + #runner: Runner; + + #tunnelWs?: WebSocket; #shutdown = false; #reconnectTimeout?: NodeJS.Timeout; #reconnectAttempt = 0; - - // Track actors and their connections - #activeActors: Set = new Set(); - #actorRequests: Map> = new Map(); - #actorWebSockets: Map> = new Map(); - - // Callbacks - #onConnected?: () => void; - #onDisconnected?: () => void; - #fetchHandler?: (actorId: string, request: Request) => Promise; - #websocketHandler?: (actorId: string, ws: any, request: Request) => Promise; - - constructor(pegboardTunnelUrl: string) { - this.#pegboardTunnelUrl = pegboardTunnelUrl; - } - setCallbacks(options: { - onConnected?: () => void; - onDisconnected?: () => void; - fetch?: (actorId: string, request: Request) => Promise; - websocket?: (actorId: string, ws: any, request: Request) => Promise; - }) { - this.#onConnected = options.onConnected; - this.#onDisconnected = options.onDisconnected; - this.#fetchHandler = options.fetch; - this.#websocketHandler = options.websocket; + #actorPendingRequests: Map = new Map(); + #actorWebSockets: Map = new Map(); + + #pendingMessages: Map = new Map(); + #gcInterval?: NodeJS.Timeout; + + #callbacks: TunnelCallbacks; + + constructor( + runner: Runner, + pegboardTunnelUrl: string, + callbacks: TunnelCallbacks, + ) { + this.#pegboardTunnelUrl = pegboardTunnelUrl; + this.#runner = runner; + this.#callbacks = callbacks; } start(): void { - if (this.#ws?.readyState === WebSocket.OPEN) { + if (this.#tunnelWs?.readyState === WebSocket.OPEN) { return; } - + this.#connect(); + this.#startGarbageCollector(); } shutdown() { this.#shutdown = true; - + if (this.#reconnectTimeout) { clearTimeout(this.#reconnectTimeout); this.#reconnectTimeout = undefined; } - if (this.#ws) { - this.#ws.close(); - this.#ws = undefined; + if (this.#gcInterval) { + clearInterval(this.#gcInterval); + this.#gcInterval = undefined; + } + + if (this.#tunnelWs) { + this.#tunnelWs.close(); + this.#tunnelWs = undefined; } + // TODO: Should we use unregisterActor instead + // Reject all pending requests - for (const [_, request] of this.#pendingRequests) { + for (const [_, request] of this.#actorPendingRequests) { request.reject(new Error("Tunnel shutting down")); } - this.#pendingRequests.clear(); + this.#actorPendingRequests.clear(); // Close all WebSockets - for (const [_, ws] of this.#webSockets) { + for (const [_, ws] of this.#actorWebSockets) { ws.close(); } - this.#webSockets.clear(); - - // Clear actor tracking - this.#activeActors.clear(); - this.#actorRequests.clear(); this.#actorWebSockets.clear(); } - registerActor(actorId: string) { - this.#activeActors.add(actorId); - this.#actorRequests.set(actorId, new Set()); - this.#actorWebSockets.set(actorId, new Set()); + #sendMessage(requestId: tunnel.RequestId, messageKind: tunnel.MessageKind) { + if (!this.#tunnelWs || this.#tunnelWs.readyState !== WebSocket.OPEN) { + console.warn("Cannot send tunnel message, WebSocket not connected"); + return; + } + + // Build message + const messageId = generateUuidBuffer(); + + const requestIdStr = bufferToString(requestId); + this.#pendingMessages.set(bufferToString(messageId), { + sentAt: Date.now(), + requestIdStr, + }); + + // Send message + const message: tunnel.RunnerMessage = { + requestId, + messageId, + messageKind, + }; + + const encoded = tunnel.encodeRunnerMessage(message); + this.#tunnelWs.send(encoded); } - unregisterActor(actorId: string) { - this.#activeActors.delete(actorId); - - // Terminate all requests for this actor - const requests = this.#actorRequests.get(actorId); - if (requests) { - for (const requestId of requests) { - const pending = this.#pendingRequests.get(requestId); - if (pending) { - pending.reject(new Error(`Actor ${actorId} stopped`)); - this.#pendingRequests.delete(requestId); + #sendAck(requestId: tunnel.RequestId, messageId: tunnel.MessageId) { + if (!this.#tunnelWs || this.#tunnelWs.readyState !== WebSocket.OPEN) { + return; + } + + const message: tunnel.RunnerMessage = { + requestId, + messageId, + messageKind: { tag: "Ack", val: null }, + }; + + const encoded = tunnel.encodeRunnerMessage(message); + this.#tunnelWs.send(encoded); + } + + #startGarbageCollector() { + if (this.#gcInterval) { + clearInterval(this.#gcInterval); + } + + this.#gcInterval = setInterval(() => { + this.#gc(); + }, GC_INTERVAL); + } + + #gc() { + const now = Date.now(); + const messagesToDelete: string[] = []; + + for (const [messageId, pendingMessage] of this.#pendingMessages) { + // Check if message is older than timeout + if ( + now - pendingMessage.sentAt > MESSAGE_ACK_TIMEOUT + ) { + messagesToDelete.push(messageId); + + const requestIdStr = pendingMessage.requestIdStr; + + // Check if this is an HTTP request + const pendingRequest = + this.#actorPendingRequests.get(requestIdStr); + if (pendingRequest) { + // Reject the pending HTTP request + pendingRequest.reject( + new Error("Message acknowledgment timeout"), + ); + + // Close stream controller if it exists + if (pendingRequest.streamController) { + pendingRequest.streamController.error( + new Error("Message acknowledgment timeout"), + ); + } + + // Clean up from actorPendingRequests map + this.#actorPendingRequests.delete(requestIdStr); + } + + // Check if this is a WebSocket + const webSocket = this.#actorWebSockets.get(requestIdStr); + if (webSocket) { + // Close the WebSocket connection + webSocket.close(1000, "Message acknowledgment timeout"); + + // Clean up from actorWebSockets map + this.#actorWebSockets.delete(requestIdStr); } } - this.#actorRequests.delete(actorId); } - + + // Remove timed out messages + for (const messageId of messagesToDelete) { + this.#pendingMessages.delete(messageId); + console.warn(`Purged unacked message: ${messageId}`); + } + } + + unregisterActor(actor: ActorInstance) { + const actorId = actor.actorId; + + // Terminate all requests for this actor + for (const requestId of actor.requests) { + const pending = this.#actorPendingRequests.get(requestId); + if (pending) { + pending.reject(new Error(`Actor ${actorId} stopped`)); + this.#actorPendingRequests.delete(requestId); + } + } + actor.requests.clear(); + // Close all WebSockets for this actor - const webSockets = this.#actorWebSockets.get(actorId); - if (webSockets) { - for (const webSocketId of webSockets) { - const ws = this.#webSockets.get(webSocketId); - if (ws) { - ws.close(1000, "Actor stopped"); - this.#webSockets.delete(webSocketId); - } + for (const webSocketId of actor.webSockets) { + const ws = this.#actorWebSockets.get(webSocketId); + if (ws) { + ws.close(1000, "Actor stopped"); + this.#actorWebSockets.delete(webSocketId); } - this.#actorWebSockets.delete(actorId); } + actor.webSockets.clear(); } async #fetch(actorId: string, request: Request): Promise { // Validate actor exists - if (!this.#activeActors.has(actorId)) { - console.warn(`[TUNNEL] Ignoring request for unknown actor: ${actorId}`); + if (!this.#runner.hasActor(actorId)) { + console.warn( + `[TUNNEL] Ignoring request for unknown actor: ${actorId}`, + ); return new Response("Actor not found", { status: 404 }); } - - if (!this.#fetchHandler) { + + const fetchHandler = this.#runner.config.fetch(actorId, request); + + if (!fetchHandler) { return new Response("Not Implemented", { status: 501 }); } - - return this.#fetchHandler(actorId, request); + + return fetchHandler; } #connect() { if (this.#shutdown) return; try { - this.#ws = new WebSocket(this.#pegboardTunnelUrl, { + this.#tunnelWs = new WebSocket(this.#pegboardTunnelUrl, { headers: { "x-rivet-target": "tunnel", }, }); - this.#ws.binaryType = "arraybuffer"; + this.#tunnelWs.binaryType = "arraybuffer"; - this.#ws.addEventListener("open", () => { + this.#tunnelWs.addEventListener("open", () => { this.#reconnectAttempt = 0; - + if (this.#reconnectTimeout) { clearTimeout(this.#reconnectTimeout); this.#reconnectTimeout = undefined; } - this.#onConnected?.(); + this.#callbacks.onConnected(); }); - this.#ws.addEventListener("message", async (event) => { + this.#tunnelWs.addEventListener("message", async (event) => { try { await this.#handleMessage(event.data as ArrayBuffer); } catch (error) { @@ -164,12 +271,12 @@ export class Tunnel { } }); - this.#ws.addEventListener("error", (event) => { + this.#tunnelWs.addEventListener("error", (event) => { console.error("Tunnel WebSocket error:", event); }); - this.#ws.addEventListener("close", () => { - this.#onDisconnected?.(); + this.#tunnelWs.addEventListener("close", () => { + this.#callbacks.onDisconnected(); if (!this.#shutdown) { this.#scheduleReconnect(); @@ -192,9 +299,8 @@ export class Tunnel { multiplier: 2, jitter: true, }); - - this.#reconnectAttempt++; + this.#reconnectAttempt++; this.#reconnectTimeout = setTimeout(() => { this.#connect(); @@ -202,53 +308,97 @@ export class Tunnel { } async #handleMessage(data: ArrayBuffer) { - const message = tunnel.decodeTunnelMessage(new Uint8Array(data)); - - switch (message.body.tag) { - case "ToServerRequestStart": - await this.#handleRequestStart(message.body.val); - break; - case "ToServerRequestChunk": - await this.#handleRequestChunk(message.body.val); - break; - case "ToServerRequestFinish": - await this.#handleRequestFinish(message.body.val); - break; - case "ToServerWebSocketOpen": - await this.#handleWebSocketOpen(message.body.val); - break; - case "ToServerWebSocketMessage": - await this.#handleWebSocketMessage(message.body.val); - break; - case "ToServerWebSocketClose": - await this.#handleWebSocketClose(message.body.val); - break; - case "ToClientResponseStart": - this.#handleResponseStart(message.body.val); - break; - case "ToClientResponseChunk": - this.#handleResponseChunk(message.body.val); - break; - case "ToClientResponseFinish": - this.#handleResponseFinish(message.body.val); - break; - case "ToClientWebSocketOpen": - this.#handleWebSocketOpenResponse(message.body.val); - break; - case "ToClientWebSocketMessage": - this.#handleWebSocketMessageResponse(message.body.val); - break; - case "ToClientWebSocketClose": - this.#handleWebSocketCloseResponse(message.body.val); - break; + const message = tunnel.decodeRunnerMessage(new Uint8Array(data)); + + if (message.messageKind.tag === "Ack") { + // Mark pending message as acknowledged and remove it + const msgIdStr = bufferToString(message.messageId); + const pending = this.#pendingMessages.get(msgIdStr); + if (pending) { + this.#pendingMessages.delete(msgIdStr); + } + } else { + this.#sendAck(message.requestId, message.messageId); + switch (message.messageKind.tag) { + case "ToServerRequestStart": + await this.#handleRequestStart( + message.requestId, + message.messageKind.val, + ); + break; + case "ToServerRequestChunk": + await this.#handleRequestChunk( + message.requestId, + message.messageKind.val, + ); + break; + case "ToServerRequestAbort": + await this.#handleRequestAbort(message.requestId); + break; + case "ToServerWebSocketOpen": + await this.#handleWebSocketOpen( + message.requestId, + message.messageKind.val, + ); + break; + case "ToServerWebSocketMessage": + await this.#handleWebSocketMessage( + message.requestId, + message.messageKind.val, + ); + break; + case "ToServerWebSocketClose": + await this.#handleWebSocketClose( + message.requestId, + message.messageKind.val, + ); + break; + case "ToClientResponseStart": + this.#handleResponseStart( + message.requestId, + message.messageKind.val, + ); + break; + case "ToClientResponseChunk": + this.#handleResponseChunk( + message.requestId, + message.messageKind.val, + ); + break; + case "ToClientResponseAbort": + this.#handleResponseAbort(message.requestId); + break; + case "ToClientWebSocketOpen": + this.#handleWebSocketOpenResponse( + message.requestId, + message.messageKind.val, + ); + break; + case "ToClientWebSocketMessage": + this.#handleWebSocketMessageResponse( + message.requestId, + message.messageKind.val, + ); + break; + case "ToClientWebSocketClose": + this.#handleWebSocketCloseResponse( + message.requestId, + message.messageKind.val, + ); + break; + } } } - async #handleRequestStart(req: tunnel.ToServerRequestStart) { + async #handleRequestStart( + requestId: ArrayBuffer, + req: tunnel.ToServerRequestStart, + ) { // Track this request for the actor - const requests = this.#actorRequests.get(req.actorId); - if (requests) { - requests.add(req.requestId); + const requestIdStr = bufferToString(requestId); + const actor = this.#runner.getActor(req.actorId); + if (actor) { + actor.requests.add(requestIdStr); } try { @@ -271,12 +421,13 @@ export class Tunnel { const stream = new ReadableStream({ start: (controller) => { // Store controller for chunks - const existing = this.#pendingRequests.get(req.requestId); + const existing = + this.#actorPendingRequests.get(requestIdStr); if (existing) { existing.streamController = controller; existing.actorId = req.actorId; } else { - this.#pendingRequests.set(req.requestId, { + this.#actorPendingRequests.set(requestIdStr, { resolve: () => {}, reject: () => {}, streamController: controller, @@ -293,194 +444,193 @@ export class Tunnel { } as any); // Call fetch handler with validation - const response = await this.#fetch(req.actorId, streamingRequest); - await this.#sendResponse(req.requestId, response); + const response = await this.#fetch( + req.actorId, + streamingRequest, + ); + await this.#sendResponse(requestId, response); } else { // Non-streaming request const response = await this.#fetch(req.actorId, request); - await this.#sendResponse(req.requestId, response); + await this.#sendResponse(requestId, response); } } catch (error) { console.error("Error handling request:", error); - this.#sendResponseError(req.requestId, 500, "Internal Server Error"); + this.#sendResponseError(requestId, 500, "Internal Server Error"); } finally { // Clean up request tracking - const requests = this.#actorRequests.get(req.actorId); - if (requests) { - requests.delete(req.requestId); + const actor = this.#runner.getActor(req.actorId); + if (actor) { + actor.requests.delete(requestIdStr); } } } - async #handleRequestChunk(chunk: tunnel.ToServerRequestChunk) { - const pending = this.#pendingRequests.get(chunk.requestId); + async #handleRequestChunk( + requestId: ArrayBuffer, + chunk: tunnel.ToServerRequestChunk, + ) { + const requestIdStr = bufferToString(requestId); + const pending = this.#actorPendingRequests.get(requestIdStr); if (pending?.streamController) { pending.streamController.enqueue(new Uint8Array(chunk.body)); + if (chunk.finish) { + pending.streamController.close(); + this.#actorPendingRequests.delete(requestIdStr); + } } } - async #handleRequestFinish(finish: tunnel.ToServerRequestFinish) { - const pending = this.#pendingRequests.get(finish.requestId); + async #handleRequestAbort(requestId: ArrayBuffer) { + const requestIdStr = bufferToString(requestId); + const pending = this.#actorPendingRequests.get(requestIdStr); if (pending?.streamController) { - if (finish.reason === tunnel.StreamFinishReason.Complete) { - pending.streamController.close(); - } else { - pending.streamController.error(new Error("Request aborted")); - } + pending.streamController.error(new Error("Request aborted")); } - this.#pendingRequests.delete(finish.requestId); + this.#actorPendingRequests.delete(requestIdStr); } - async #sendResponse(requestId: bigint, response: Response) { + async #sendResponse(requestId: ArrayBuffer, response: Response) { // Always treat responses as non-streaming for now // In the future, we could detect streaming responses based on: // - Transfer-Encoding: chunked // - Content-Type: text/event-stream // - Explicit stream flag from the handler - + // Read the body first to get the actual content const body = response.body ? await response.arrayBuffer() : null; - + // Convert headers to map and add Content-Length if not present const headers = new Map(); response.headers.forEach((value, key) => { headers.set(key, value); }); - + // Add Content-Length header if we have a body and it's not already set if (body && !headers.has("content-length")) { headers.set("content-length", String(body.byteLength)); } // Send as non-streaming response - this.#send({ - body: { - tag: "ToClientResponseStart", - val: { - requestId, - status: response.status as tunnel.u16, - headers, - body: body || null, - stream: false, - }, + this.#sendMessage(requestId, { + tag: "ToClientResponseStart", + val: { + status: response.status as tunnel.u16, + headers, + body: body || null, + stream: false, }, }); } - #sendResponseError(requestId: bigint, status: number, message: string) { + #sendResponseError( + requestId: ArrayBuffer, + status: number, + message: string, + ) { const headers = new Map(); headers.set("content-type", "text/plain"); - this.#send({ - body: { - tag: "ToClientResponseStart", - val: { - requestId, - status: status as tunnel.u16, - headers, - body: new TextEncoder().encode(message).buffer as ArrayBuffer, - stream: false, - }, + this.#sendMessage(requestId, { + tag: "ToClientResponseStart", + val: { + status: status as tunnel.u16, + headers, + body: new TextEncoder().encode(message).buffer as ArrayBuffer, + stream: false, }, }); } - async #handleWebSocketOpen(open: tunnel.ToServerWebSocketOpen) { + async #handleWebSocketOpen( + requestId: ArrayBuffer, + open: tunnel.ToServerWebSocketOpen, + ) { + const webSocketId = bufferToString(requestId); // Validate actor exists - if (!this.#activeActors.has(open.actorId)) { - console.warn(`Ignoring WebSocket for unknown actor: ${open.actorId}`); + const actor = this.#runner.getActor(open.actorId); + if (!actor) { + console.warn( + `Ignoring WebSocket for unknown actor: ${open.actorId}`, + ); // Send close immediately - this.#send({ - body: { - tag: "ToClientWebSocketClose", - val: { - webSocketId: open.webSocketId, - code: 1011, - reason: "Actor not found", - }, + this.#sendMessage(requestId, { + tag: "ToClientWebSocketClose", + val: { + code: 1011, + reason: "Actor not found", }, }); return; } - if (!this.#websocketHandler) { + const websocketHandler = this.#runner.config.websocket; + + if (!websocketHandler) { console.error("No websocket handler configured for tunnel"); // Send close immediately - this.#send({ - body: { - tag: "ToClientWebSocketClose", - val: { - webSocketId: open.webSocketId, - code: 1011, - reason: "Not Implemented", - }, + this.#sendMessage(requestId, { + tag: "ToClientWebSocketClose", + val: { + code: 1011, + reason: "Not Implemented", }, }); return; } // Track this WebSocket for the actor - const webSockets = this.#actorWebSockets.get(open.actorId); - if (webSockets) { - webSockets.add(open.webSocketId); + if (actor) { + actor.webSockets.add(webSocketId); } try { // Create WebSocket adapter const adapter = new WebSocketTunnelAdapter( - open.webSocketId, + webSocketId, (data: ArrayBuffer | string, isBinary: boolean) => { // Send message through tunnel - const dataBuffer = typeof data === "string" - ? new TextEncoder().encode(data).buffer as ArrayBuffer - : data; - - this.#send({ - body: { - tag: "ToClientWebSocketMessage", - val: { - webSocketId: open.webSocketId, - data: dataBuffer, - binary: isBinary, - }, + const dataBuffer = + typeof data === "string" + ? (new TextEncoder().encode(data) + .buffer as ArrayBuffer) + : data; + + this.#sendMessage(requestId, { + tag: "ToClientWebSocketMessage", + val: { + data: dataBuffer, + binary: isBinary, }, }); }, (code?: number, reason?: string) => { // Send close through tunnel - this.#send({ - body: { - tag: "ToClientWebSocketClose", - val: { - webSocketId: open.webSocketId, - code: code || null, - reason: reason || null, - }, + this.#sendMessage(requestId, { + tag: "ToClientWebSocketClose", + val: { + code: code || null, + reason: reason || null, }, }); - + // Remove from map - this.#webSockets.delete(open.webSocketId); - + this.#actorWebSockets.delete(webSocketId); + // Clean up actor tracking - const webSockets = this.#actorWebSockets.get(open.actorId); - if (webSockets) { - webSockets.delete(open.webSocketId); + if (actor) { + actor.webSockets.delete(webSocketId); } - } + }, ); // Store adapter - this.#webSockets.set(open.webSocketId, adapter); + this.#actorWebSockets.set(webSocketId, adapter); // Send open confirmation - this.#send({ - body: { - tag: "ToClientWebSocketOpen", - val: { - webSocketId: open.webSocketId, - }, - }, + this.#sendMessage(requestId, { + tag: "ToClientWebSocketOpen", + val: null, }); // Notify adapter that connection is open @@ -490,7 +640,10 @@ export class Tunnel { // Include original headers from the open message const headerInit: Record = {}; if (open.headers) { - for (const [k, v] of open.headers as ReadonlyMap) { + for (const [k, v] of open.headers as ReadonlyMap< + string, + string + >) { headerInit[k] = v; } } @@ -504,55 +657,68 @@ export class Tunnel { }); // Call websocket handler - await this.#websocketHandler(open.actorId, adapter, request); + await websocketHandler(open.actorId, adapter, request); } catch (error) { console.error("Error handling WebSocket open:", error); - + // Send close on error - this.#send({ - body: { - tag: "ToClientWebSocketClose", - val: { - webSocketId: open.webSocketId, - code: 1011, - reason: "Server Error", - }, + this.#sendMessage(requestId, { + tag: "ToClientWebSocketClose", + val: { + code: 1011, + reason: "Server Error", }, }); - - this.#webSockets.delete(open.webSocketId); - + + this.#actorWebSockets.delete(webSocketId); + // Clean up actor tracking - const webSockets = this.#actorWebSockets.get(open.actorId); - if (webSockets) { - webSockets.delete(open.webSocketId); + if (actor) { + actor.webSockets.delete(webSocketId); } } } - async #handleWebSocketMessage(msg: tunnel.ToServerWebSocketMessage) { - const adapter = this.#webSockets.get(msg.webSocketId); + async #handleWebSocketMessage( + requestId: ArrayBuffer, + msg: tunnel.ToServerWebSocketMessage, + ) { + const webSocketId = bufferToString(requestId); + const adapter = this.#actorWebSockets.get(webSocketId); if (adapter) { const data = msg.binary ? new Uint8Array(msg.data) : new TextDecoder().decode(new Uint8Array(msg.data)); - + adapter._handleMessage(data, msg.binary); } } - async #handleWebSocketClose(close: tunnel.ToServerWebSocketClose) { - const adapter = this.#webSockets.get(close.webSocketId); + async #handleWebSocketClose( + requestId: ArrayBuffer, + close: tunnel.ToServerWebSocketClose, + ) { + const webSocketId = bufferToString(requestId); + const adapter = this.#actorWebSockets.get(webSocketId); if (adapter) { - adapter._handleClose(close.code || undefined, close.reason || undefined); - this.#webSockets.delete(close.webSocketId); + adapter._handleClose( + close.code || undefined, + close.reason || undefined, + ); + this.#actorWebSockets.delete(webSocketId); } } - #handleResponseStart(resp: tunnel.ToClientResponseStart) { - const pending = this.#pendingRequests.get(resp.requestId); + #handleResponseStart( + requestId: ArrayBuffer, + resp: tunnel.ToClientResponseStart, + ) { + const requestIdStr = bufferToString(requestId); + const pending = this.#actorPendingRequests.get(requestIdStr); if (!pending) { - console.warn(`Received response for unknown request ${resp.requestId}`); + console.warn( + `Received response for unknown request ${requestIdStr}`, + ); return; } @@ -585,62 +751,84 @@ export class Tunnel { }); pending.resolve(response); - this.#pendingRequests.delete(resp.requestId); + this.#actorPendingRequests.delete(requestIdStr); } } - #handleResponseChunk(chunk: tunnel.ToClientResponseChunk) { - const pending = this.#pendingRequests.get(chunk.requestId); + #handleResponseChunk( + requestId: ArrayBuffer, + chunk: tunnel.ToClientResponseChunk, + ) { + const requestIdStr = bufferToString(requestId); + const pending = this.#actorPendingRequests.get(requestIdStr); if (pending?.streamController) { pending.streamController.enqueue(new Uint8Array(chunk.body)); + if (chunk.finish) { + pending.streamController.close(); + this.#actorPendingRequests.delete(requestIdStr); + } } } - #handleResponseFinish(finish: tunnel.ToClientResponseFinish) { - const pending = this.#pendingRequests.get(finish.requestId); + #handleResponseAbort(requestId: ArrayBuffer) { + const requestIdStr = bufferToString(requestId); + const pending = this.#actorPendingRequests.get(requestIdStr); if (pending?.streamController) { - if (finish.reason === tunnel.StreamFinishReason.Complete) { - pending.streamController.close(); - } else { - pending.streamController.error(new Error("Response aborted")); - } + pending.streamController.error(new Error("Response aborted")); } - this.#pendingRequests.delete(finish.requestId); + this.#actorPendingRequests.delete(requestIdStr); } - #handleWebSocketOpenResponse(open: tunnel.ToClientWebSocketOpen) { - const adapter = this.#webSockets.get(open.webSocketId); + #handleWebSocketOpenResponse( + requestId: ArrayBuffer, + open: tunnel.ToClientWebSocketOpen, + ) { + const webSocketId = bufferToString(requestId); + const adapter = this.#actorWebSockets.get(webSocketId); if (adapter) { adapter._handleOpen(); } } - #handleWebSocketMessageResponse(msg: tunnel.ToClientWebSocketMessage) { - const adapter = this.#webSockets.get(msg.webSocketId); + #handleWebSocketMessageResponse( + requestId: ArrayBuffer, + msg: tunnel.ToClientWebSocketMessage, + ) { + const webSocketId = bufferToString(requestId); + const adapter = this.#actorWebSockets.get(webSocketId); if (adapter) { const data = msg.binary ? new Uint8Array(msg.data) : new TextDecoder().decode(new Uint8Array(msg.data)); - + adapter._handleMessage(data, msg.binary); } } - #handleWebSocketCloseResponse(close: tunnel.ToClientWebSocketClose) { - const adapter = this.#webSockets.get(close.webSocketId); + #handleWebSocketCloseResponse( + requestId: ArrayBuffer, + close: tunnel.ToClientWebSocketClose, + ) { + const webSocketId = bufferToString(requestId); + const adapter = this.#actorWebSockets.get(webSocketId); if (adapter) { - adapter._handleClose(close.code || undefined, close.reason || undefined); - this.#webSockets.delete(close.webSocketId); + adapter._handleClose( + close.code || undefined, + close.reason || undefined, + ); + this.#actorWebSockets.delete(webSocketId); } } +} - #send(message: tunnel.TunnelMessage) { - if (!this.#ws || this.#ws.readyState !== WebSocket.OPEN) { - console.warn("Cannot send tunnel message, WebSocket not connected"); - return; - } +/** Converts a buffer to a string. Used for storing strings in a lookup map. */ +function bufferToString(buffer: ArrayBuffer): string { + return Buffer.from(buffer).toString("base64"); +} - const encoded = tunnel.encodeTunnelMessage(message); - this.#ws.send(encoded); - } +/** Generates a UUID as bytes. */ +function generateUuidBuffer(): ArrayBuffer { + const buffer = new Uint8Array(16); + uuidv4(undefined, buffer); + return buffer.buffer; } diff --git a/sdks/typescript/runner/src/websocket-tunnel-adapter.ts b/sdks/typescript/runner/src/websocket-tunnel-adapter.ts index bf097f8a5c..b2430d3238 100644 --- a/sdks/typescript/runner/src/websocket-tunnel-adapter.ts +++ b/sdks/typescript/runner/src/websocket-tunnel-adapter.ts @@ -2,7 +2,7 @@ // Implements a subset of the WebSocket interface for compatibility with runner code export class WebSocketTunnelAdapter { - #webSocketId: bigint; + #webSocketId: string; #readyState: number = 0; // CONNECTING #eventListeners: Map void>> = new Map(); #onopen: ((this: any, ev: any) => any) | null = null; @@ -24,7 +24,7 @@ export class WebSocketTunnelAdapter { }> = []; constructor( - webSocketId: bigint, + webSocketId: string, sendCallback: (data: ArrayBuffer | string, isBinary: boolean) => void, closeCallback: (code?: number, reason?: string) => void ) { diff --git a/sdks/typescript/tunnel-protocol/src/index.ts b/sdks/typescript/tunnel-protocol/src/index.ts index 0ea9a5b8ce..9e4b6ac609 100644 --- a/sdks/typescript/tunnel-protocol/src/index.ts +++ b/sdks/typescript/tunnel-protocol/src/index.ts @@ -1,30 +1,38 @@ +import assert from "node:assert" import * as bare from "@bare-ts/lib" const DEFAULT_CONFIG = /* @__PURE__ */ bare.Config({}) export type u16 = number -export type u64 = bigint -export type RequestId = u64 +export type RequestId = ArrayBuffer export function readRequestId(bc: bare.ByteCursor): RequestId { - return bare.readU64(bc) + return bare.readFixedData(bc, 16) } export function writeRequestId(bc: bare.ByteCursor, x: RequestId): void { - bare.writeU64(bc, x) + assert(x.byteLength === 16) + bare.writeFixedData(bc, x) } -export type WebSocketId = u64 +/** + * UUIDv4 + */ +export type MessageId = ArrayBuffer -export function readWebSocketId(bc: bare.ByteCursor): WebSocketId { - return bare.readU64(bc) +export function readMessageId(bc: bare.ByteCursor): MessageId { + return bare.readFixedData(bc, 16) } -export function writeWebSocketId(bc: bare.ByteCursor, x: WebSocketId): void { - bare.writeU64(bc, x) +export function writeMessageId(bc: bare.ByteCursor, x: MessageId): void { + assert(x.byteLength === 16) + bare.writeFixedData(bc, x) } +/** + * UUIDv4 + */ export type Id = string export function readId(bc: bare.ByteCursor): Id { @@ -35,38 +43,10 @@ export function writeId(bc: bare.ByteCursor, x: Id): void { bare.writeString(bc, x) } -export enum StreamFinishReason { - Complete = "Complete", - Abort = "Abort", -} - -export function readStreamFinishReason(bc: bare.ByteCursor): StreamFinishReason { - const offset = bc.offset - const tag = bare.readU8(bc) - switch (tag) { - case 0: - return StreamFinishReason.Complete - case 1: - return StreamFinishReason.Abort - default: { - bc.offset = offset - throw new bare.BareError(offset, "invalid tag") - } - } -} - -export function writeStreamFinishReason(bc: bare.ByteCursor, x: StreamFinishReason): void { - switch (x) { - case StreamFinishReason.Complete: { - bare.writeU8(bc, 0) - break - } - case StreamFinishReason.Abort: { - bare.writeU8(bc, 1) - break - } - } -} +/** + * MARK: Ack + */ +export type Ack = null function read0(bc: bare.ByteCursor): ReadonlyMap { const len = bare.readUintSafe(bc) @@ -103,10 +83,9 @@ function write1(bc: bare.ByteCursor, x: ArrayBuffer | null): void { } /** - * MARK: HTTP Request Forwarding + * MARK: HTTP */ export type ToServerRequestStart = { - readonly requestId: RequestId readonly actorId: Id readonly method: string readonly path: string @@ -117,7 +96,6 @@ export type ToServerRequestStart = { export function readToServerRequestStart(bc: bare.ByteCursor): ToServerRequestStart { return { - requestId: readRequestId(bc), actorId: readId(bc), method: bare.readString(bc), path: bare.readString(bc), @@ -128,7 +106,6 @@ export function readToServerRequestStart(bc: bare.ByteCursor): ToServerRequestSt } export function writeToServerRequestStart(bc: bare.ByteCursor, x: ToServerRequestStart): void { - writeRequestId(bc, x.requestId) writeId(bc, x.actorId) bare.writeString(bc, x.method) bare.writeString(bc, x.path) @@ -138,41 +115,25 @@ export function writeToServerRequestStart(bc: bare.ByteCursor, x: ToServerReques } export type ToServerRequestChunk = { - readonly requestId: RequestId readonly body: ArrayBuffer + readonly finish: boolean } export function readToServerRequestChunk(bc: bare.ByteCursor): ToServerRequestChunk { return { - requestId: readRequestId(bc), body: bare.readData(bc), + finish: bare.readBool(bc), } } export function writeToServerRequestChunk(bc: bare.ByteCursor, x: ToServerRequestChunk): void { - writeRequestId(bc, x.requestId) bare.writeData(bc, x.body) + bare.writeBool(bc, x.finish) } -export type ToServerRequestFinish = { - readonly requestId: RequestId - readonly reason: StreamFinishReason -} - -export function readToServerRequestFinish(bc: bare.ByteCursor): ToServerRequestFinish { - return { - requestId: readRequestId(bc), - reason: readStreamFinishReason(bc), - } -} - -export function writeToServerRequestFinish(bc: bare.ByteCursor, x: ToServerRequestFinish): void { - writeRequestId(bc, x.requestId) - writeStreamFinishReason(bc, x.reason) -} +export type ToServerRequestAbort = null export type ToClientResponseStart = { - readonly requestId: RequestId readonly status: u16 readonly headers: ReadonlyMap readonly body: ArrayBuffer | null @@ -181,7 +142,6 @@ export type ToClientResponseStart = { export function readToClientResponseStart(bc: bare.ByteCursor): ToClientResponseStart { return { - requestId: readRequestId(bc), status: bare.readU16(bc), headers: read0(bc), body: read1(bc), @@ -190,7 +150,6 @@ export function readToClientResponseStart(bc: bare.ByteCursor): ToClientResponse } export function writeToClientResponseStart(bc: bare.ByteCursor, x: ToClientResponseStart): void { - writeRequestId(bc, x.requestId) bare.writeU16(bc, x.status) write0(bc, x.headers) write1(bc, x.body) @@ -198,45 +157,29 @@ export function writeToClientResponseStart(bc: bare.ByteCursor, x: ToClientRespo } export type ToClientResponseChunk = { - readonly requestId: RequestId readonly body: ArrayBuffer + readonly finish: boolean } export function readToClientResponseChunk(bc: bare.ByteCursor): ToClientResponseChunk { return { - requestId: readRequestId(bc), body: bare.readData(bc), + finish: bare.readBool(bc), } } export function writeToClientResponseChunk(bc: bare.ByteCursor, x: ToClientResponseChunk): void { - writeRequestId(bc, x.requestId) bare.writeData(bc, x.body) + bare.writeBool(bc, x.finish) } -export type ToClientResponseFinish = { - readonly requestId: RequestId - readonly reason: StreamFinishReason -} - -export function readToClientResponseFinish(bc: bare.ByteCursor): ToClientResponseFinish { - return { - requestId: readRequestId(bc), - reason: readStreamFinishReason(bc), - } -} - -export function writeToClientResponseFinish(bc: bare.ByteCursor, x: ToClientResponseFinish): void { - writeRequestId(bc, x.requestId) - writeStreamFinishReason(bc, x.reason) -} +export type ToClientResponseAbort = null /** - * MARK: WebSocket Forwarding + * MARK: WebSocket */ export type ToServerWebSocketOpen = { readonly actorId: Id - readonly webSocketId: WebSocketId readonly path: string readonly headers: ReadonlyMap } @@ -244,7 +187,6 @@ export type ToServerWebSocketOpen = { export function readToServerWebSocketOpen(bc: bare.ByteCursor): ToServerWebSocketOpen { return { actorId: readId(bc), - webSocketId: readWebSocketId(bc), path: bare.readString(bc), headers: read0(bc), } @@ -252,27 +194,23 @@ export function readToServerWebSocketOpen(bc: bare.ByteCursor): ToServerWebSocke export function writeToServerWebSocketOpen(bc: bare.ByteCursor, x: ToServerWebSocketOpen): void { writeId(bc, x.actorId) - writeWebSocketId(bc, x.webSocketId) bare.writeString(bc, x.path) write0(bc, x.headers) } export type ToServerWebSocketMessage = { - readonly webSocketId: WebSocketId readonly data: ArrayBuffer readonly binary: boolean } export function readToServerWebSocketMessage(bc: bare.ByteCursor): ToServerWebSocketMessage { return { - webSocketId: readWebSocketId(bc), data: bare.readData(bc), binary: bare.readBool(bc), } } export function writeToServerWebSocketMessage(bc: bare.ByteCursor, x: ToServerWebSocketMessage): void { - writeWebSocketId(bc, x.webSocketId) bare.writeData(bc, x.data) bare.writeBool(bc, x.binary) } @@ -300,75 +238,54 @@ function write3(bc: bare.ByteCursor, x: string | null): void { } export type ToServerWebSocketClose = { - readonly webSocketId: WebSocketId readonly code: u16 | null readonly reason: string | null } export function readToServerWebSocketClose(bc: bare.ByteCursor): ToServerWebSocketClose { return { - webSocketId: readWebSocketId(bc), code: read2(bc), reason: read3(bc), } } export function writeToServerWebSocketClose(bc: bare.ByteCursor, x: ToServerWebSocketClose): void { - writeWebSocketId(bc, x.webSocketId) write2(bc, x.code) write3(bc, x.reason) } -export type ToClientWebSocketOpen = { - readonly webSocketId: WebSocketId -} - -export function readToClientWebSocketOpen(bc: bare.ByteCursor): ToClientWebSocketOpen { - return { - webSocketId: readWebSocketId(bc), - } -} - -export function writeToClientWebSocketOpen(bc: bare.ByteCursor, x: ToClientWebSocketOpen): void { - writeWebSocketId(bc, x.webSocketId) -} +export type ToClientWebSocketOpen = null export type ToClientWebSocketMessage = { - readonly webSocketId: WebSocketId readonly data: ArrayBuffer readonly binary: boolean } export function readToClientWebSocketMessage(bc: bare.ByteCursor): ToClientWebSocketMessage { return { - webSocketId: readWebSocketId(bc), data: bare.readData(bc), binary: bare.readBool(bc), } } export function writeToClientWebSocketMessage(bc: bare.ByteCursor, x: ToClientWebSocketMessage): void { - writeWebSocketId(bc, x.webSocketId) bare.writeData(bc, x.data) bare.writeBool(bc, x.binary) } export type ToClientWebSocketClose = { - readonly webSocketId: WebSocketId readonly code: u16 | null readonly reason: string | null } export function readToClientWebSocketClose(bc: bare.ByteCursor): ToClientWebSocketClose { return { - webSocketId: readWebSocketId(bc), code: read2(bc), reason: read3(bc), } } export function writeToClientWebSocketClose(bc: bare.ByteCursor, x: ToClientWebSocketClose): void { - writeWebSocketId(bc, x.webSocketId) write2(bc, x.code) write3(bc, x.reason) } @@ -376,16 +293,17 @@ export function writeToClientWebSocketClose(bc: bare.ByteCursor, x: ToClientWebS /** * MARK: Message */ -export type MessageBody = +export type MessageKind = + | { readonly tag: "Ack"; readonly val: Ack } /** * HTTP */ | { readonly tag: "ToServerRequestStart"; readonly val: ToServerRequestStart } | { readonly tag: "ToServerRequestChunk"; readonly val: ToServerRequestChunk } - | { readonly tag: "ToServerRequestFinish"; readonly val: ToServerRequestFinish } + | { readonly tag: "ToServerRequestAbort"; readonly val: ToServerRequestAbort } | { readonly tag: "ToClientResponseStart"; readonly val: ToClientResponseStart } | { readonly tag: "ToClientResponseChunk"; readonly val: ToClientResponseChunk } - | { readonly tag: "ToClientResponseFinish"; readonly val: ToClientResponseFinish } + | { readonly tag: "ToClientResponseAbort"; readonly val: ToClientResponseAbort } /** * WebSocket */ @@ -396,33 +314,35 @@ export type MessageBody = | { readonly tag: "ToClientWebSocketMessage"; readonly val: ToClientWebSocketMessage } | { readonly tag: "ToClientWebSocketClose"; readonly val: ToClientWebSocketClose } -export function readMessageBody(bc: bare.ByteCursor): MessageBody { +export function readMessageKind(bc: bare.ByteCursor): MessageKind { const offset = bc.offset const tag = bare.readU8(bc) switch (tag) { case 0: - return { tag: "ToServerRequestStart", val: readToServerRequestStart(bc) } + return { tag: "Ack", val: null } case 1: - return { tag: "ToServerRequestChunk", val: readToServerRequestChunk(bc) } + return { tag: "ToServerRequestStart", val: readToServerRequestStart(bc) } case 2: - return { tag: "ToServerRequestFinish", val: readToServerRequestFinish(bc) } + return { tag: "ToServerRequestChunk", val: readToServerRequestChunk(bc) } case 3: - return { tag: "ToClientResponseStart", val: readToClientResponseStart(bc) } + return { tag: "ToServerRequestAbort", val: null } case 4: - return { tag: "ToClientResponseChunk", val: readToClientResponseChunk(bc) } + return { tag: "ToClientResponseStart", val: readToClientResponseStart(bc) } case 5: - return { tag: "ToClientResponseFinish", val: readToClientResponseFinish(bc) } + return { tag: "ToClientResponseChunk", val: readToClientResponseChunk(bc) } case 6: - return { tag: "ToServerWebSocketOpen", val: readToServerWebSocketOpen(bc) } + return { tag: "ToClientResponseAbort", val: null } case 7: - return { tag: "ToServerWebSocketMessage", val: readToServerWebSocketMessage(bc) } + return { tag: "ToServerWebSocketOpen", val: readToServerWebSocketOpen(bc) } case 8: - return { tag: "ToServerWebSocketClose", val: readToServerWebSocketClose(bc) } + return { tag: "ToServerWebSocketMessage", val: readToServerWebSocketMessage(bc) } case 9: - return { tag: "ToClientWebSocketOpen", val: readToClientWebSocketOpen(bc) } + return { tag: "ToServerWebSocketClose", val: readToServerWebSocketClose(bc) } case 10: - return { tag: "ToClientWebSocketMessage", val: readToClientWebSocketMessage(bc) } + return { tag: "ToClientWebSocketOpen", val: null } case 11: + return { tag: "ToClientWebSocketMessage", val: readToClientWebSocketMessage(bc) } + case 12: return { tag: "ToClientWebSocketClose", val: readToClientWebSocketClose(bc) } default: { bc.offset = offset @@ -431,65 +351,66 @@ export function readMessageBody(bc: bare.ByteCursor): MessageBody { } } -export function writeMessageBody(bc: bare.ByteCursor, x: MessageBody): void { +export function writeMessageKind(bc: bare.ByteCursor, x: MessageKind): void { switch (x.tag) { - case "ToServerRequestStart": { + case "Ack": { bare.writeU8(bc, 0) + break + } + case "ToServerRequestStart": { + bare.writeU8(bc, 1) writeToServerRequestStart(bc, x.val) break } case "ToServerRequestChunk": { - bare.writeU8(bc, 1) + bare.writeU8(bc, 2) writeToServerRequestChunk(bc, x.val) break } - case "ToServerRequestFinish": { - bare.writeU8(bc, 2) - writeToServerRequestFinish(bc, x.val) + case "ToServerRequestAbort": { + bare.writeU8(bc, 3) break } case "ToClientResponseStart": { - bare.writeU8(bc, 3) + bare.writeU8(bc, 4) writeToClientResponseStart(bc, x.val) break } case "ToClientResponseChunk": { - bare.writeU8(bc, 4) + bare.writeU8(bc, 5) writeToClientResponseChunk(bc, x.val) break } - case "ToClientResponseFinish": { - bare.writeU8(bc, 5) - writeToClientResponseFinish(bc, x.val) + case "ToClientResponseAbort": { + bare.writeU8(bc, 6) break } case "ToServerWebSocketOpen": { - bare.writeU8(bc, 6) + bare.writeU8(bc, 7) writeToServerWebSocketOpen(bc, x.val) break } case "ToServerWebSocketMessage": { - bare.writeU8(bc, 7) + bare.writeU8(bc, 8) writeToServerWebSocketMessage(bc, x.val) break } case "ToServerWebSocketClose": { - bare.writeU8(bc, 8) + bare.writeU8(bc, 9) writeToServerWebSocketClose(bc, x.val) break } case "ToClientWebSocketOpen": { - bare.writeU8(bc, 9) - writeToClientWebSocketOpen(bc, x.val) + bare.writeU8(bc, 10) break } case "ToClientWebSocketMessage": { - bare.writeU8(bc, 10) + bare.writeU8(bc, 11) writeToClientWebSocketMessage(bc, x.val) break } case "ToClientWebSocketClose": { - bare.writeU8(bc, 11) + bare.writeU8(bc, 12) writeToClientWebSocketClose(bc, x.val) break } @@ -497,35 +418,89 @@ export function writeMessageBody(bc: bare.ByteCursor, x: MessageBody): void { } /** - * Main tunnel message + * MARK: Message sent over tunnel WebSocket */ -export type TunnelMessage = { - readonly body: MessageBody +export type RunnerMessage = { + readonly requestId: RequestId + readonly messageId: MessageId + readonly messageKind: MessageKind } -export function readTunnelMessage(bc: bare.ByteCursor): TunnelMessage { +export function readRunnerMessage(bc: bare.ByteCursor): RunnerMessage { return { - body: readMessageBody(bc), + requestId: readRequestId(bc), + messageId: readMessageId(bc), + messageKind: readMessageKind(bc), } } -export function writeTunnelMessage(bc: bare.ByteCursor, x: TunnelMessage): void { - writeMessageBody(bc, x.body) +export function writeRunnerMessage(bc: bare.ByteCursor, x: RunnerMessage): void { + writeRequestId(bc, x.requestId) + writeMessageId(bc, x.messageId) + writeMessageKind(bc, x.messageKind) +} + +export function encodeRunnerMessage(x: RunnerMessage, config?: Partial): Uint8Array { + const fullConfig = config != null ? bare.Config(config) : DEFAULT_CONFIG + const bc = new bare.ByteCursor( + new Uint8Array(fullConfig.initialBufferLength), + fullConfig, + ) + writeRunnerMessage(bc, x) + return new Uint8Array(bc.view.buffer, bc.view.byteOffset, bc.offset) +} + +export function decodeRunnerMessage(bytes: Uint8Array): RunnerMessage { + const bc = new bare.ByteCursor(bytes, DEFAULT_CONFIG) + const result = readRunnerMessage(bc) + if (bc.offset < bc.view.byteLength) { + throw new bare.BareError(bc.offset, "remaining bytes") + } + return result +} + +/** + * MARK: Message sent over UPS + */ +export type PubSubMessage = { + readonly requestId: RequestId + readonly messageId: MessageId + /** + * Subject to send replies to. Only sent when opening a new request from gateway -> runner. + */ + readonly replyTo: string | null + readonly messageKind: MessageKind +} + +export function readPubSubMessage(bc: bare.ByteCursor): PubSubMessage { + return { + requestId: readRequestId(bc), + messageId: readMessageId(bc), + replyTo: read3(bc), + messageKind: readMessageKind(bc), + } +} + +export function writePubSubMessage(bc: bare.ByteCursor, x: PubSubMessage): void { + writeRequestId(bc, x.requestId) + writeMessageId(bc, x.messageId) + write3(bc, x.replyTo) + writeMessageKind(bc, x.messageKind) } -export function encodeTunnelMessage(x: TunnelMessage, config?: Partial): Uint8Array { +export function encodePubSubMessage(x: PubSubMessage, config?: Partial): Uint8Array { const fullConfig = config != null ? bare.Config(config) : DEFAULT_CONFIG const bc = new bare.ByteCursor( new Uint8Array(fullConfig.initialBufferLength), fullConfig, ) - writeTunnelMessage(bc, x) + writePubSubMessage(bc, x) return new Uint8Array(bc.view.buffer, bc.view.byteOffset, bc.offset) } -export function decodeTunnelMessage(bytes: Uint8Array): TunnelMessage { +export function decodePubSubMessage(bytes: Uint8Array): PubSubMessage { const bc = new bare.ByteCursor(bytes, DEFAULT_CONFIG) - const result = readTunnelMessage(bc) + const result = readPubSubMessage(bc) if (bc.offset < bc.view.byteLength) { throw new bare.BareError(bc.offset, "remaining bytes") } From 74e5db56c2c121c60091e27807b8a85f1f56d07c Mon Sep 17 00:00:00 2001 From: Nathan Flurry Date: Mon, 8 Sep 2025 17:58:08 -0700 Subject: [PATCH 2/5] chore(ups): add postgres auto-reconnect --- Cargo.lock | 1 + packages/common/test-deps-docker/src/lib.rs | 36 +++ packages/common/universalpubsub/Cargo.toml | 1 + .../src/driver/postgres/mod.rs | 222 ++++++++++++++---- .../common/universalpubsub/tests/reconnect.rs | 217 +++++++++++++++++ 5 files changed, 430 insertions(+), 47 deletions(-) create mode 100644 packages/common/universalpubsub/tests/reconnect.rs diff --git a/Cargo.lock b/Cargo.lock index 800d4ca464..44bd7cd6e0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6338,6 +6338,7 @@ dependencies = [ "rivet-error", "rivet-test-deps-docker", "rivet-ups-protocol", + "rivet-util", "serde", "serde_json", "sha2", diff --git a/packages/common/test-deps-docker/src/lib.rs b/packages/common/test-deps-docker/src/lib.rs index efa46a3152..2ed73f8600 100644 --- a/packages/common/test-deps-docker/src/lib.rs +++ b/packages/common/test-deps-docker/src/lib.rs @@ -73,6 +73,42 @@ impl DockerRunConfig { Ok(true) } + pub async fn restart(&self) -> Result<()> { + let container_id = self + .container_id + .as_ref() + .ok_or_else(|| anyhow!("No container ID found, container not started"))?; + + tracing::debug!( + container_name = %self.container_name, + container_id = %container_id, + "restarting docker container" + ); + + let output = Command::new("docker") + .arg("restart") + .arg(container_id) + .output() + .await?; + + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + anyhow::bail!( + "Failed to restart container {}: {}", + self.container_name, + stderr + ); + } + + tracing::debug!( + container_name = %self.container_name, + container_id = %container_id, + "container restarted successfully" + ); + + Ok(()) + } + pub fn container_id(&self) -> Option<&str> { self.container_id.as_deref() } diff --git a/packages/common/universalpubsub/Cargo.toml b/packages/common/universalpubsub/Cargo.toml index 3cdb8509b6..4ffddeb77c 100644 --- a/packages/common/universalpubsub/Cargo.toml +++ b/packages/common/universalpubsub/Cargo.toml @@ -14,6 +14,7 @@ deadpool-postgres.workspace = true futures-util.workspace = true rivet-error.workspace = true rivet-ups-protocol.workspace = true +rivet-util.workspace = true serde_json.workspace = true versioned-data-util.workspace = true serde.workspace = true diff --git a/packages/common/universalpubsub/src/driver/postgres/mod.rs b/packages/common/universalpubsub/src/driver/postgres/mod.rs index c2df9a36b4..21cf96f5f5 100644 --- a/packages/common/universalpubsub/src/driver/postgres/mod.rs +++ b/packages/common/universalpubsub/src/driver/postgres/mod.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; use std::hash::{DefaultHasher, Hash, Hasher}; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use anyhow::*; use async_trait::async_trait; @@ -8,6 +8,8 @@ use base64::Engine; use base64::engine::general_purpose::STANDARD_NO_PAD as BASE64; use deadpool_postgres::{Config, ManagerConfig, Pool, PoolConfig, RecyclingMethod, Runtime}; use futures_util::future::poll_fn; +use rivet_util::backoff::Backoff; +use tokio::sync::{Mutex, broadcast}; use tokio_postgres::{AsyncMessage, NoTls}; use tracing::Instrument; @@ -17,13 +19,13 @@ use crate::pubsub::DriverOutput; #[derive(Clone)] struct Subscription { // Channel to send messages to this subscription - tx: tokio::sync::broadcast::Sender>, + tx: broadcast::Sender>, // Cancellation token shared by all subscribers of this subject token: tokio_util::sync::CancellationToken, } impl Subscription { - fn new(tx: tokio::sync::broadcast::Sender>) -> Self { + fn new(tx: broadcast::Sender>) -> Self { let token = tokio_util::sync::CancellationToken::new(); Self { tx, token } } @@ -48,8 +50,9 @@ pub const POSTGRES_MAX_MESSAGE_SIZE: usize = #[derive(Clone)] pub struct PostgresDriver { pool: Arc, - client: Arc, + client: Arc>>>, subscriptions: Arc>>, + client_ready: tokio::sync::watch::Receiver, } impl PostgresDriver { @@ -76,48 +79,168 @@ impl PostgresDriver { let subscriptions: Arc>> = Arc::new(Mutex::new(HashMap::new())); - let subscriptions2 = subscriptions.clone(); + let client: Arc>>> = Arc::new(Mutex::new(None)); - let (client, mut conn) = tokio_postgres::connect(&conn_str, tokio_postgres::NoTls).await?; - tokio::spawn(async move { - // NOTE: This loop will stop automatically when client is dropped - loop { - match poll_fn(|cx| conn.poll_message(cx)).await { - Some(std::result::Result::Ok(AsyncMessage::Notification(note))) => { - if let Some(sub) = - subscriptions2.lock().unwrap().get(note.channel()).cloned() - { - let bytes = match BASE64.decode(note.payload()) { - std::result::Result::Ok(b) => b, - std::result::Result::Err(err) => { - tracing::error!(?err, "failed decoding base64"); - break; - } - }; - let _ = sub.tx.send(bytes); - } - } - Some(std::result::Result::Ok(_)) => { - // Ignore other async messages + // Create channel for client ready notifications + let (ready_tx, client_ready) = tokio::sync::watch::channel(false); + + // Spawn connection lifecycle task + tokio::spawn(Self::spawn_connection_lifecycle( + conn_str.clone(), + subscriptions.clone(), + client.clone(), + ready_tx, + )); + + let driver = Self { + pool: Arc::new(pool), + client, + subscriptions, + client_ready, + }; + + // Wait for initial connection to be established + driver.wait_for_client().await?; + + Ok(driver) + } + + /// Manages the connection lifecycle with automatic reconnection + async fn spawn_connection_lifecycle( + conn_str: String, + subscriptions: Arc>>, + client: Arc>>>, + ready_tx: tokio::sync::watch::Sender, + ) { + let mut backoff = Backoff::new(8, None, 1_000, 1_000); + + loop { + match tokio_postgres::connect(&conn_str, tokio_postgres::NoTls).await { + Result::Ok((new_client, conn)) => { + tracing::info!("postgres listen connection established"); + // Reset backoff on successful connection + backoff = Backoff::new(8, None, 1_000, 1_000); + + let new_client = Arc::new(new_client); + + // Update the client reference immediately + *client.lock().await = Some(new_client.clone()); + // Notify that client is ready + let _ = ready_tx.send(true); + + // Get channels to re-subscribe to + let channels: Vec = + subscriptions.lock().await.keys().cloned().collect(); + let needs_resubscribe = !channels.is_empty(); + + if needs_resubscribe { + tracing::debug!( + ?channels, + "will re-subscribe to channels after connection starts" + ); } - Some(std::result::Result::Err(err)) => { - tracing::error!(?err, "async postgres error"); - break; + + // Spawn a task to re-subscribe after a short delay + if needs_resubscribe { + let client_for_resub = new_client.clone(); + let channels_clone = channels.clone(); + tokio::spawn(async move { + tracing::debug!( + ?channels_clone, + "re-subscribing to channels after reconnection" + ); + for channel in &channels_clone { + if let Result::Err(e) = client_for_resub + .execute(&format!("LISTEN \"{}\"", channel), &[]) + .await + { + tracing::error!(?e, %channel, "failed to re-subscribe to channel"); + } else { + tracing::debug!(%channel, "successfully re-subscribed to channel"); + } + } + }); } - None => { - tracing::debug!("async postgres connection closed"); - break; + + // Poll the connection until it closes + Self::poll_connection(conn, subscriptions.clone()).await; + + // Clear the client reference on disconnect + *client.lock().await = None; + // Notify that client is disconnected + let _ = ready_tx.send(false); + } + Result::Err(e) => { + tracing::error!(?e, "failed to connect to postgres, retrying"); + backoff.tick().await; + } + } + } + } + + /// Polls the connection for notifications until it closes or errors + async fn poll_connection( + mut conn: tokio_postgres::Connection< + tokio_postgres::Socket, + tokio_postgres::tls::NoTlsStream, + >, + subscriptions: Arc>>, + ) { + loop { + match poll_fn(|cx| conn.poll_message(cx)).await { + Some(std::result::Result::Ok(AsyncMessage::Notification(note))) => { + tracing::trace!(channel = %note.channel(), "received notification"); + if let Some(sub) = subscriptions.lock().await.get(note.channel()).cloned() { + let bytes = match BASE64.decode(note.payload()) { + std::result::Result::Ok(b) => b, + std::result::Result::Err(err) => { + tracing::error!(?err, "failed decoding base64"); + continue; + } + }; + tracing::trace!(channel = %note.channel(), bytes_len = bytes.len(), "sending to broadcast channel"); + let _ = sub.tx.send(bytes); + } else { + tracing::warn!(channel = %note.channel(), "received notification for unknown channel"); } } + Some(std::result::Result::Ok(_)) => { + // Ignore other async messages + } + Some(std::result::Result::Err(err)) => { + tracing::error!(?err, "postgres connection error, reconnecting"); + break; // Exit loop to reconnect + } + None => { + tracing::warn!("postgres connection closed, reconnecting"); + break; // Exit loop to reconnect + } } - tracing::debug!("listen connection closed"); - }); + } + } - Ok(Self { - pool: Arc::new(pool), - client: Arc::new(client), - subscriptions, + /// Wait for the client to be connected + async fn wait_for_client(&self) -> Result> { + let mut ready_rx = self.client_ready.clone(); + tokio::time::timeout(tokio::time::Duration::from_secs(5), async { + loop { + // Subscribe to changed before attempting to access client + let changed_fut = ready_rx.changed(); + + // Check if client is already available + if let Some(client) = self.client.lock().await.clone() { + return Ok(client); + } + + // Wait for change, will return client if exists on next iteration + changed_fut + .await + .map_err(|_| anyhow!("connection lifecycle task ended"))?; + tracing::debug!("client does not exist immediately after receive ready"); + } }) + .await + .map_err(|_| anyhow!("timeout waiting for postgres client connection"))? } fn hash_subject(&self, subject: &str) -> String { @@ -147,7 +270,7 @@ impl PubSubDriver for PostgresDriver { // Check if we already have a subscription for this channel let (rx, drop_guard) = - if let Some(existing_sub) = self.subscriptions.lock().unwrap().get(&hashed).cloned() { + if let Some(existing_sub) = self.subscriptions.lock().await.get(&hashed).cloned() { // Reuse the existing broadcast channel let rx = existing_sub.tx.subscribe(); let drop_guard = existing_sub.token.clone().drop_guard(); @@ -160,13 +283,15 @@ impl PubSubDriver for PostgresDriver { // Register subscription self.subscriptions .lock() - .unwrap() + .await .insert(hashed.clone(), subscription.clone()); // Execute LISTEN command on the async client (for receiving notifications) // This only needs to be done once per channel + // Wait for client to be connected with retry logic + let client = self.wait_for_client().await?; let span = tracing::trace_span!("pg_listen"); - self.client + client .execute(&format!("LISTEN \"{hashed}\""), &[]) .instrument(span) .await?; @@ -179,13 +304,16 @@ impl PubSubDriver for PostgresDriver { tokio::spawn(async move { token_clone.cancelled().await; if tx_clone.receiver_count() == 0 { - let sql = format!("UNLISTEN \"{}\"", hashed_clone); - if let Err(err) = driver.client.execute(sql.as_str(), &[]).await { - tracing::warn!(?err, %hashed_clone, "failed to UNLISTEN channel"); - } else { - tracing::trace!(%hashed_clone, "unlistened channel"); + let client = driver.client.lock().await.clone(); + if let Some(client) = client { + let sql = format!("UNLISTEN \"{}\"", hashed_clone); + if let Err(err) = client.execute(sql.as_str(), &[]).await { + tracing::warn!(?err, %hashed_clone, "failed to UNLISTEN channel"); + } else { + tracing::trace!(%hashed_clone, "unlistened channel"); + } } - driver.subscriptions.lock().unwrap().remove(&hashed_clone); + driver.subscriptions.lock().await.remove(&hashed_clone); } }); diff --git a/packages/common/universalpubsub/tests/reconnect.rs b/packages/common/universalpubsub/tests/reconnect.rs new file mode 100644 index 0000000000..fa98767e95 --- /dev/null +++ b/packages/common/universalpubsub/tests/reconnect.rs @@ -0,0 +1,217 @@ +use anyhow::*; +use rivet_test_deps_docker::{TestDatabase, TestPubSub}; +use std::{sync::Arc, time::Duration}; +use universalpubsub::{NextOutput, PubSub, PublishOpts}; +use uuid::Uuid; + +fn setup_logging() { + let _ = tracing_subscriber::fmt() + .with_env_filter("debug") + .with_ansi(false) + .with_test_writer() + .try_init(); +} + +#[tokio::test] +async fn test_nats_driver_with_memory_reconnect() { + setup_logging(); + + let test_id = Uuid::new_v4(); + let (pubsub_config, docker_config) = TestPubSub::Nats.config(test_id, 1).await.unwrap(); + let mut docker = docker_config.unwrap(); + docker.start().await.unwrap(); + tokio::time::sleep(Duration::from_secs(1)).await; + + let rivet_config::config::PubSub::Nats(nats) = pubsub_config else { + unreachable!(); + }; + + use std::str::FromStr; + let server_addrs = nats + .addresses + .iter() + .map(|addr| format!("nats://{addr}")) + .map(|url| async_nats::ServerAddr::from_str(url.as_ref())) + .collect::, _>>() + .unwrap(); + + let driver = universalpubsub::driver::nats::NatsDriver::connect( + async_nats::ConnectOptions::new(), + &server_addrs[..], + ) + .await + .unwrap(); + let pubsub = PubSub::new_with_memory_optimization(Arc::new(driver), true); + + test_reconnect_inner(&pubsub, &docker).await; +} + +#[tokio::test] +async fn test_nats_driver_without_memory_reconnect() { + setup_logging(); + + let test_id = Uuid::new_v4(); + let (pubsub_config, docker_config) = TestPubSub::Nats.config(test_id, 1).await.unwrap(); + let mut docker = docker_config.unwrap(); + docker.start().await.unwrap(); + tokio::time::sleep(Duration::from_secs(1)).await; + + let rivet_config::config::PubSub::Nats(nats) = pubsub_config else { + unreachable!(); + }; + + use std::str::FromStr; + let server_addrs = nats + .addresses + .iter() + .map(|addr| format!("nats://{addr}")) + .map(|url| async_nats::ServerAddr::from_str(url.as_ref())) + .collect::, _>>() + .unwrap(); + + let driver = universalpubsub::driver::nats::NatsDriver::connect( + async_nats::ConnectOptions::new(), + &server_addrs[..], + ) + .await + .unwrap(); + let pubsub = PubSub::new_with_memory_optimization(Arc::new(driver), false); + + test_reconnect_inner(&pubsub, &docker).await; +} + +#[tokio::test] +async fn test_postgres_driver_with_memory_reconnect() { + setup_logging(); + + let test_id = Uuid::new_v4(); + let (db_config, docker_config) = TestDatabase::Postgres.config(test_id, 1).await.unwrap(); + let mut docker = docker_config.unwrap(); + docker.start().await.unwrap(); + tokio::time::sleep(Duration::from_secs(5)).await; + + let rivet_config::config::Database::Postgres(pg) = db_config else { + unreachable!(); + }; + let url = pg.url.read().clone(); + + let driver = universalpubsub::driver::postgres::PostgresDriver::connect(url, true) + .await + .unwrap(); + let pubsub = PubSub::new_with_memory_optimization(Arc::new(driver), true); + + test_reconnect_inner(&pubsub, &docker).await; +} + +#[tokio::test] +async fn test_postgres_driver_without_memory_reconnect() { + setup_logging(); + + let test_id = Uuid::new_v4(); + let (db_config, docker_config) = TestDatabase::Postgres.config(test_id, 1).await.unwrap(); + let mut docker = docker_config.unwrap(); + docker.start().await.unwrap(); + tokio::time::sleep(Duration::from_secs(5)).await; + + let rivet_config::config::Database::Postgres(pg) = db_config else { + unreachable!(); + }; + let url = pg.url.read().clone(); + + let driver = universalpubsub::driver::postgres::PostgresDriver::connect(url, false) + .await + .unwrap(); + let pubsub = PubSub::new_with_memory_optimization(Arc::new(driver), false); + + test_reconnect_inner(&pubsub, &docker).await; +} + +async fn test_reconnect_inner(pubsub: &PubSub, docker: &rivet_test_deps_docker::DockerRunConfig) { + tracing::info!("testing reconnect functionality"); + + // Open subscription + let mut subscriber = pubsub.subscribe("test.reconnect").await.unwrap(); + tracing::info!("opened initial subscription"); + + // Test publish/receive message before restart + let message_before = b"message before restart"; + pubsub + .publish("test.reconnect", message_before, PublishOpts::broadcast()) + .await + .unwrap(); + pubsub.flush().await.unwrap(); + + match subscriber.next().await.unwrap() { + NextOutput::Message(msg) => { + assert_eq!( + msg.payload, message_before, + "message before restart should match" + ); + tracing::info!("received message before restart"); + } + NextOutput::Unsubscribed => { + panic!("unexpected unsubscribe before restart"); + } + } + + // Restart container + tracing::info!("restarting docker container"); + docker.restart().await.unwrap(); + + // Give the service time to come back up + tokio::time::sleep(Duration::from_secs(3)).await; + tracing::info!("docker container restarted"); + + // Test publish/receive message after restart + let message_after = b"message after restart"; + + // Retry logic for publish after restart since connection might need to reconnect + let mut retries = 0; + const MAX_RETRIES: u32 = 10; + loop { + match pubsub + .publish("test.reconnect", message_after, PublishOpts::broadcast()) + .await + { + Result::Ok(_) => { + tracing::info!("published message after restart"); + break; + } + Result::Err(e) if retries < MAX_RETRIES => { + retries += 1; + tracing::debug!(?e, retries, "failed to publish after restart, retrying"); + tokio::time::sleep(Duration::from_millis(500)).await; + } + Result::Err(e) => { + panic!("failed to publish after {} retries: {}", MAX_RETRIES, e); + } + } + } + + pubsub.flush().await.unwrap(); + + // Try to receive with timeout to handle reconnection delays + let receive_timeout = Duration::from_secs(10); + let receive_result = tokio::time::timeout(receive_timeout, subscriber.next()).await; + + match receive_result { + Result::Ok(Result::Ok(NextOutput::Message(msg))) => { + assert_eq!( + msg.payload, message_after, + "message after restart should match" + ); + tracing::info!("received message after restart - reconnection successful"); + } + Result::Ok(Result::Ok(NextOutput::Unsubscribed)) => { + panic!("unexpected unsubscribe after restart"); + } + Result::Ok(Result::Err(e)) => { + panic!("error receiving message after restart: {}", e); + } + Result::Err(_) => { + panic!("timeout receiving message after restart"); + } + } + + tracing::info!("reconnect test completed successfully"); +} From e29523a2c49904f573638386a1038549be9e368d Mon Sep 17 00:00:00 2001 From: Nathan Flurry Date: Mon, 8 Sep 2025 20:18:37 -0700 Subject: [PATCH 3/5] chore(ups): handle edge cases with postgres listen/unlisten/notify when disconnected/reconnecting --- out/errors/ups.publish_failed.json | 5 + packages/common/test-deps-docker/src/lib.rs | 72 +++++++ .../src/driver/postgres/mod.rs | 166 +++++++++----- packages/common/universalpubsub/src/errors.rs | 2 + packages/common/universalpubsub/src/pubsub.rs | 31 ++- .../common/universalpubsub/tests/reconnect.rs | 202 +++++++++++++++--- 6 files changed, 395 insertions(+), 83 deletions(-) create mode 100644 out/errors/ups.publish_failed.json diff --git a/out/errors/ups.publish_failed.json b/out/errors/ups.publish_failed.json new file mode 100644 index 0000000000..b4c1ac0b42 --- /dev/null +++ b/out/errors/ups.publish_failed.json @@ -0,0 +1,5 @@ +{ + "code": "publish_failed", + "group": "ups", + "message": "Failed to publish message after retries" +} \ No newline at end of file diff --git a/packages/common/test-deps-docker/src/lib.rs b/packages/common/test-deps-docker/src/lib.rs index 2ed73f8600..644e57ad53 100644 --- a/packages/common/test-deps-docker/src/lib.rs +++ b/packages/common/test-deps-docker/src/lib.rs @@ -109,6 +109,78 @@ impl DockerRunConfig { Ok(()) } + pub async fn stop_container(&self) -> Result<()> { + let container_id = self + .container_id + .as_ref() + .ok_or_else(|| anyhow!("No container ID found, container not started"))?; + + tracing::debug!( + container_name = %self.container_name, + container_id = %container_id, + "stopping docker container" + ); + + let output = Command::new("docker") + .arg("stop") + .arg(container_id) + .output() + .await?; + + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + anyhow::bail!( + "Failed to stop container {}: {}", + self.container_name, + stderr + ); + } + + tracing::debug!( + container_name = %self.container_name, + container_id = %container_id, + "container stopped successfully" + ); + + Ok(()) + } + + pub async fn start_container(&self) -> Result<()> { + let container_id = self + .container_id + .as_ref() + .ok_or_else(|| anyhow!("No container ID found, container not started"))?; + + tracing::debug!( + container_name = %self.container_name, + container_id = %container_id, + "starting docker container" + ); + + let output = Command::new("docker") + .arg("start") + .arg(container_id) + .output() + .await?; + + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + anyhow::bail!( + "Failed to start container {}: {}", + self.container_name, + stderr + ); + } + + tracing::debug!( + container_name = %self.container_name, + container_id = %container_id, + "container started successfully" + ); + + Ok(()) + } + pub fn container_id(&self) -> Option<&str> { self.container_id.as_deref() } diff --git a/packages/common/universalpubsub/src/driver/postgres/mod.rs b/packages/common/universalpubsub/src/driver/postgres/mod.rs index 21cf96f5f5..3115a7ac72 100644 --- a/packages/common/universalpubsub/src/driver/postgres/mod.rs +++ b/packages/common/universalpubsub/src/driver/postgres/mod.rs @@ -112,21 +112,23 @@ impl PostgresDriver { client: Arc>>>, ready_tx: tokio::sync::watch::Sender, ) { - let mut backoff = Backoff::new(8, None, 1_000, 1_000); + let mut backoff = Backoff::default(); loop { match tokio_postgres::connect(&conn_str, tokio_postgres::NoTls).await { Result::Ok((new_client, conn)) => { tracing::info!("postgres listen connection established"); // Reset backoff on successful connection - backoff = Backoff::new(8, None, 1_000, 1_000); + backoff = Backoff::default(); let new_client = Arc::new(new_client); - // Update the client reference immediately - *client.lock().await = Some(new_client.clone()); - // Notify that client is ready - let _ = ready_tx.send(true); + // Spawn the polling task immediately + // This must be done before any operations on the client + let subscriptions_clone = subscriptions.clone(); + let poll_handle = tokio::spawn(async move { + Self::poll_connection(conn, subscriptions_clone).await; + }); // Get channels to re-subscribe to let channels: Vec = @@ -135,38 +137,41 @@ impl PostgresDriver { if needs_resubscribe { tracing::debug!( - ?channels, + channels=?channels.len(), "will re-subscribe to channels after connection starts" ); } - // Spawn a task to re-subscribe after a short delay + // Re-subscribe to channels if needs_resubscribe { - let client_for_resub = new_client.clone(); - let channels_clone = channels.clone(); - tokio::spawn(async move { - tracing::debug!( - ?channels_clone, - "re-subscribing to channels after reconnection" - ); - for channel in &channels_clone { - if let Result::Err(e) = client_for_resub - .execute(&format!("LISTEN \"{}\"", channel), &[]) - .await - { - tracing::error!(?e, %channel, "failed to re-subscribe to channel"); - } else { - tracing::debug!(%channel, "successfully re-subscribed to channel"); - } + tracing::debug!( + channels=?channels.len(), + "re-subscribing to channels after reconnection" + ); + for channel in &channels { + tracing::info!(?channel, "re-subscribing to channel"); + if let Result::Err(e) = new_client + .execute(&format!("LISTEN \"{}\"", channel), &[]) + .await + { + tracing::error!(?e, %channel, "failed to re-subscribe to channel"); + } else { + tracing::debug!(%channel, "successfully re-subscribed to channel"); } - }); + } } - // Poll the connection until it closes - Self::poll_connection(conn, subscriptions.clone()).await; + // Update the client reference and signal ready + // Do this AFTER re-subscribing to ensure LISTEN is complete + *client.lock().await = Some(new_client.clone()); + let _ = ready_tx.send(true); + + // Wait for the polling task to complete (when the connection closes) + let _ = poll_handle.await; // Clear the client reference on disconnect *client.lock().await = None; + // Notify that client is disconnected let _ = ready_tx.send(false); } @@ -208,12 +213,12 @@ impl PostgresDriver { // Ignore other async messages } Some(std::result::Result::Err(err)) => { - tracing::error!(?err, "postgres connection error, reconnecting"); - break; // Exit loop to reconnect + tracing::error!(?err, "postgres connection error"); + break; } None => { - tracing::warn!("postgres connection closed, reconnecting"); - break; // Exit loop to reconnect + tracing::warn!("postgres connection closed"); + break; } } } @@ -224,19 +229,16 @@ impl PostgresDriver { let mut ready_rx = self.client_ready.clone(); tokio::time::timeout(tokio::time::Duration::from_secs(5), async { loop { - // Subscribe to changed before attempting to access client - let changed_fut = ready_rx.changed(); - // Check if client is already available if let Some(client) = self.client.lock().await.clone() { return Ok(client); } - // Wait for change, will return client if exists on next iteration - changed_fut + // Wait for the ready signal to change + ready_rx + .changed() .await .map_err(|_| anyhow!("connection lifecycle task ended"))?; - tracing::debug!("client does not exist immediately after receive ready"); } }) .await @@ -288,13 +290,25 @@ impl PubSubDriver for PostgresDriver { // Execute LISTEN command on the async client (for receiving notifications) // This only needs to be done once per channel - // Wait for client to be connected with retry logic - let client = self.wait_for_client().await?; - let span = tracing::trace_span!("pg_listen"); - client - .execute(&format!("LISTEN \"{hashed}\""), &[]) - .instrument(span) - .await?; + // Try to LISTEN if client is available, but don't fail if disconnected + // The reconnection logic will handle re-subscribing + if let Some(client) = self.client.lock().await.clone() { + let span = tracing::trace_span!("pg_listen"); + match client + .execute(&format!("LISTEN \"{hashed}\""), &[]) + .instrument(span) + .await + { + Result::Ok(_) => { + tracing::debug!(%hashed, "successfully subscribed to channel"); + } + Result::Err(e) => { + tracing::warn!(?e, %hashed, "failed to LISTEN, will retry on reconnection"); + } + } + } else { + tracing::debug!(%hashed, "client not connected, will LISTEN on reconnection"); + } // Spawn a single cleanup task for this subscription waiting on its token let driver = self.clone(); @@ -333,14 +347,66 @@ impl PubSubDriver for PostgresDriver { // Encode payload to base64 and send NOTIFY let encoded = BASE64.encode(payload); - let conn = self.pool.get().await?; let hashed = self.hash_subject(subject); - let span = tracing::trace_span!("pg_notify"); - conn.execute(&format!("NOTIFY \"{hashed}\", '{encoded}'"), &[]) - .instrument(span) - .await?; - Ok(()) + tracing::debug!("attempting to get connection for publish"); + + // Wait for listen connection to be ready first if this channel has subscribers + // This ensures that if we're reconnecting, the LISTEN is re-registered before NOTIFY + if self.subscriptions.lock().await.contains_key(&hashed) { + self.wait_for_client().await?; + } + + // Retry getting a connection from the pool with backoff in case the connection is + // currently disconnected + let mut backoff = Backoff::default(); + let mut last_error = None; + + loop { + match self.pool.get().await { + Result::Ok(conn) => { + // Test the connection with a simple query before using it + match conn.execute("SELECT 1", &[]).await { + Result::Ok(_) => { + // Connection is good, use it for NOTIFY + let span = tracing::trace_span!("pg_notify"); + match conn + .execute(&format!("NOTIFY \"{hashed}\", '{encoded}'"), &[]) + .instrument(span) + .await + { + Result::Ok(_) => return Ok(()), + Result::Err(e) => { + tracing::debug!( + ?e, + "NOTIFY failed, retrying with new connection" + ); + last_error = Some(e.into()); + } + } + } + Result::Err(e) => { + tracing::debug!( + ?e, + "connection test failed, retrying with new connection" + ); + last_error = Some(e.into()); + } + } + } + Result::Err(e) => { + tracing::debug!(?e, "failed to get connection from pool, retrying"); + last_error = Some(e.into()); + } + } + + // Check if we should continue retrying + if !backoff.tick().await { + return Err( + last_error.unwrap_or_else(|| anyhow!("failed to publish after retries")) + ); + } + } } async fn flush(&self) -> Result<()> { diff --git a/packages/common/universalpubsub/src/errors.rs b/packages/common/universalpubsub/src/errors.rs index afab64db4a..69849d4592 100644 --- a/packages/common/universalpubsub/src/errors.rs +++ b/packages/common/universalpubsub/src/errors.rs @@ -6,4 +6,6 @@ use serde::{Deserialize, Serialize}; pub enum Ups { #[error("request_timeout", "Request timeout.")] RequestTimeout, + #[error("publish_failed", "Failed to publish message after retries")] + PublishFailed, } diff --git a/packages/common/universalpubsub/src/pubsub.rs b/packages/common/universalpubsub/src/pubsub.rs index fd24e41ff2..ac6ff27696 100644 --- a/packages/common/universalpubsub/src/pubsub.rs +++ b/packages/common/universalpubsub/src/pubsub.rs @@ -8,6 +8,8 @@ use tokio::sync::broadcast; use tokio::sync::{RwLock, oneshot}; use uuid::Uuid; +use rivet_util::backoff::Backoff; + use crate::chunking::{ChunkTracker, encode_chunk, split_payload_into_chunks}; use crate::driver::{PubSubDriverHandle, PublishOpts, SubscriberDriverHandle}; @@ -131,7 +133,8 @@ impl PubSub { break; } } else { - self.driver.publish(subject, &encoded).await?; + // Use backoff when publishing through the driver + self.publish_with_backoff(subject, &encoded).await?; } } Ok(()) @@ -174,7 +177,26 @@ impl PubSub { break; } } else { - self.driver.publish(subject, &encoded).await?; + // Use backoff when publishing through the driver + self.publish_with_backoff(subject, &encoded).await?; + } + } + Ok(()) + } + + async fn publish_with_backoff(&self, subject: &str, encoded: &[u8]) -> Result<()> { + let mut backoff = Backoff::default(); + loop { + match self.driver.publish(subject, encoded).await { + Result::Ok(_) => break, + Err(err) if !backoff.tick().await => { + tracing::info!(?err, "error publishing, cannot retry again"); + return Err(crate::errors::Ups::PublishFailed.build().into()); + } + Err(err) => { + tracing::info!(?err, "error publishing, retrying"); + // Continue retrying + } } } Ok(()) @@ -293,7 +315,10 @@ impl Subscriber { pub async fn next(&mut self) -> Result { loop { match self.driver.next().await? { - DriverOutput::Message { subject, payload } => { + DriverOutput::Message { + subject: _, + payload, + } => { // Process chunks let mut tracker = self.pubsub.chunk_tracker.lock().unwrap(); match tracker.process_chunk(&payload) { diff --git a/packages/common/universalpubsub/tests/reconnect.rs b/packages/common/universalpubsub/tests/reconnect.rs index fa98767e95..230110789b 100644 --- a/packages/common/universalpubsub/tests/reconnect.rs +++ b/packages/common/universalpubsub/tests/reconnect.rs @@ -43,7 +43,7 @@ async fn test_nats_driver_with_memory_reconnect() { .unwrap(); let pubsub = PubSub::new_with_memory_optimization(Arc::new(driver), true); - test_reconnect_inner(&pubsub, &docker).await; + test_all_inner(&pubsub, &docker).await; } #[tokio::test] @@ -77,7 +77,7 @@ async fn test_nats_driver_without_memory_reconnect() { .unwrap(); let pubsub = PubSub::new_with_memory_optimization(Arc::new(driver), false); - test_reconnect_inner(&pubsub, &docker).await; + test_all_inner(&pubsub, &docker).await; } #[tokio::test] @@ -100,7 +100,7 @@ async fn test_postgres_driver_with_memory_reconnect() { .unwrap(); let pubsub = PubSub::new_with_memory_optimization(Arc::new(driver), true); - test_reconnect_inner(&pubsub, &docker).await; + test_all_inner(&pubsub, &docker).await; } #[tokio::test] @@ -123,7 +123,13 @@ async fn test_postgres_driver_without_memory_reconnect() { .unwrap(); let pubsub = PubSub::new_with_memory_optimization(Arc::new(driver), false); + test_all_inner(&pubsub, &docker).await; +} + +async fn test_all_inner(pubsub: &PubSub, docker: &rivet_test_deps_docker::DockerRunConfig) { test_reconnect_inner(&pubsub, &docker).await; + test_publish_while_stopped(&pubsub, &docker).await; + test_subscribe_while_stopped(&pubsub, &docker).await; } async fn test_reconnect_inner(pubsub: &PubSub, docker: &rivet_test_deps_docker::DockerRunConfig) { @@ -158,35 +164,14 @@ async fn test_reconnect_inner(pubsub: &PubSub, docker: &rivet_test_deps_docker:: tracing::info!("restarting docker container"); docker.restart().await.unwrap(); - // Give the service time to come back up - tokio::time::sleep(Duration::from_secs(3)).await; - tracing::info!("docker container restarted"); - // Test publish/receive message after restart + // + // This should retry under the hood, since the container will still be starting let message_after = b"message after restart"; - - // Retry logic for publish after restart since connection might need to reconnect - let mut retries = 0; - const MAX_RETRIES: u32 = 10; - loop { - match pubsub - .publish("test.reconnect", message_after, PublishOpts::broadcast()) - .await - { - Result::Ok(_) => { - tracing::info!("published message after restart"); - break; - } - Result::Err(e) if retries < MAX_RETRIES => { - retries += 1; - tracing::debug!(?e, retries, "failed to publish after restart, retrying"); - tokio::time::sleep(Duration::from_millis(500)).await; - } - Result::Err(e) => { - panic!("failed to publish after {} retries: {}", MAX_RETRIES, e); - } - } - } + pubsub + .publish("test.reconnect", message_after, PublishOpts::broadcast()) + .await + .unwrap(); pubsub.flush().await.unwrap(); @@ -215,3 +200,160 @@ async fn test_reconnect_inner(pubsub: &PubSub, docker: &rivet_test_deps_docker:: tracing::info!("reconnect test completed successfully"); } + +async fn test_publish_while_stopped( + pubsub: &PubSub, + docker: &rivet_test_deps_docker::DockerRunConfig, +) { + tracing::info!("testing publish while container stopped"); + + // 1. Subscribe + let mut subscriber = pubsub.subscribe("test.publish_stopped").await.unwrap(); + tracing::info!("opened subscription"); + + // 2. Stop container + tracing::info!("stopping docker container"); + docker.stop_container().await.unwrap(); + tokio::time::sleep(Duration::from_secs(2)).await; + + // 3. Publish while stopped (should queue/retry) + let message = b"message while stopped"; + let publish_handle = tokio::spawn({ + let pubsub = pubsub.clone(); + let message = message.to_vec(); + async move { + pubsub + .publish("test.publish_stopped", &message, PublishOpts::broadcast()) + .await + } + }); + + // 4. Start container + tokio::time::sleep(Duration::from_secs(3)).await; + tracing::info!("starting docker container"); + docker.start_container().await.unwrap(); + tokio::time::sleep(Duration::from_secs(5)).await; + + // Wait for publish to complete + publish_handle.await.unwrap().unwrap(); + pubsub.flush().await.unwrap(); + + // 5. Receive message + tracing::info!("waiting for message"); + let receive_timeout = Duration::from_secs(5); + let receive_result = tokio::time::timeout(receive_timeout, subscriber.next()).await; + + match receive_result { + Result::Ok(Result::Ok(NextOutput::Message(msg))) => { + assert_eq!( + msg.payload, message, + "message published while stopped should be received" + ); + tracing::info!("received message published while stopped - reconnection successful"); + } + Result::Ok(Result::Ok(NextOutput::Unsubscribed)) => { + panic!("unexpected unsubscribe"); + } + Result::Ok(Result::Err(e)) => { + panic!("error receiving message: {}", e); + } + Result::Err(_) => { + panic!("timeout receiving message"); + } + } + + tracing::info!("publish while stopped test completed successfully"); +} + +async fn test_subscribe_while_stopped( + pubsub: &PubSub, + docker: &rivet_test_deps_docker::DockerRunConfig, +) { + tracing::info!("testing subscribe while container stopped"); + + // 1. Subscribe & test publish & unsubscribe + let mut subscriber = pubsub.subscribe("test.subscribe_stopped").await.unwrap(); + tracing::info!("opened initial subscription"); + + let test_message = b"test message"; + pubsub + .publish( + "test.subscribe_stopped", + test_message, + PublishOpts::broadcast(), + ) + .await + .unwrap(); + pubsub.flush().await.unwrap(); + + match subscriber.next().await.unwrap() { + NextOutput::Message(msg) => { + assert_eq!(msg.payload, test_message, "initial message should match"); + tracing::info!("received initial test message"); + } + NextOutput::Unsubscribed => { + panic!("unexpected unsubscribe"); + } + } + + drop(subscriber); // Drop to unsubscribe + tracing::info!("unsubscribed from initial subscription"); + + // 2. Stop container + tracing::info!("stopping docker container"); + docker.stop_container().await.unwrap(); + tokio::time::sleep(Duration::from_secs(2)).await; + + // 3. Subscribe while stopped + let subscribe_handle = tokio::spawn({ + let pubsub = pubsub.clone(); + async move { pubsub.subscribe("test.subscribe_stopped").await } + }); + + // 4. Start container + tokio::time::sleep(Duration::from_secs(3)).await; + tracing::info!("starting docker container"); + docker.start_container().await.unwrap(); + tokio::time::sleep(Duration::from_secs(5)).await; + + // Wait for subscription to complete + let mut new_subscriber = subscribe_handle.await.unwrap().unwrap(); + tracing::info!("new subscription established after reconnect"); + + // 5. Publish message + let final_message = b"message after reconnect"; + pubsub + .publish( + "test.subscribe_stopped", + final_message, + PublishOpts::broadcast(), + ) + .await + .unwrap(); + pubsub.flush().await.unwrap(); + + // 6. Receive + let receive_timeout = Duration::from_secs(10); + let receive_result = tokio::time::timeout(receive_timeout, new_subscriber.next()).await; + + match receive_result { + Result::Ok(Result::Ok(NextOutput::Message(msg))) => { + assert_eq!( + msg.payload, final_message, + "message after reconnect should match" + ); + tracing::info!("received message after reconnect - subscription successful"); + } + Result::Ok(Result::Ok(NextOutput::Unsubscribed)) => { + panic!("unexpected unsubscribe"); + } + Result::Ok(Result::Err(e)) => { + panic!("error receiving message: {}", e); + } + Result::Err(_) => { + panic!("timeout receiving message"); + } + } + + tracing::info!("subscribe while stopped test completed successfully"); +} From d4b39b34edd2b6e9d6703c602c51930520931e67 Mon Sep 17 00:00:00 2001 From: Nathan Flurry Date: Mon, 8 Sep 2025 20:42:18 -0700 Subject: [PATCH 4/5] ci: switch to depot --- .github/workflows/release.yaml | 14 +++++++------- .github/workflows/rust.yml | 6 +++--- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 961fbcbfb1..ce35afb4d2 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -79,22 +79,22 @@ jobs: matrix: include: - platform: linux - runner: ubuntu-latest + runner: depot-ubuntu-24.04 target: x86_64-unknown-linux-musl binary_ext: "" arch: x86_64 - platform: windows - runner: ubuntu-latest + runner: depot-ubuntu-24.04 target: x86_64-pc-windows-gnu binary_ext: ".exe" arch: x86_64 - platform: macos - runner: ubuntu-latest + runner: depot-ubuntu-24.04 target: x86_64-apple-darwin binary_ext: "" arch: x86_64 - platform: macos - runner: ubuntu-latest + runner: depot-ubuntu-24.04 target: aarch64-apple-darwin binary_ext: "" arch: aarch64 @@ -155,10 +155,10 @@ jobs: include: # TODO(RVT-4479): Add back ARM builder once manifest generation fixed # - platform: linux/arm64 - # runner: ubuntu-latest + # runner: depot-ubuntu-24.04 # arch_suffix: -arm64 - platform: linux/x86_64 - runner: ubuntu-latest + runner: depot-ubuntu-24.04 # TODO: Replace with appropriate arch_suffix when needed # arch_suffix: -amd64 arch_suffix: '' @@ -246,4 +246,4 @@ jobs: ./scripts/release/main.ts --version "${{ github.event.inputs.version }}" --completeCi else ./scripts/release/main.ts --version "${{ github.event.inputs.version }}" --no-latest --completeCi - fi \ No newline at end of file + fi diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 1671a3fc41..745f5ae9cd 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -43,7 +43,7 @@ jobs: # clippy: # name: Clippy - # runs-on: ubuntu-latest + # runs-on: depot-ubuntu-24.04 # steps: # - uses: actions/checkout@v4 @@ -59,7 +59,7 @@ jobs: check: name: Check - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 steps: - uses: actions/checkout@v4 @@ -77,7 +77,7 @@ jobs: test: name: Test - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 steps: - uses: actions/checkout@v4 From e0afd30b73ad2137e587f0412a6f80c30e19281d Mon Sep 17 00:00:00 2001 From: Nathan Flurry Date: Tue, 9 Sep 2025 09:41:07 -0700 Subject: [PATCH 5/5] chore(runner): handle async actor stop --- sdks/typescript/runner/src/mod.ts | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/sdks/typescript/runner/src/mod.ts b/sdks/typescript/runner/src/mod.ts index 7d26a37925..2623d63848 100644 --- a/sdks/typescript/runner/src/mod.ts +++ b/sdks/typescript/runner/src/mod.ts @@ -127,7 +127,7 @@ export class Runner { // The server will send a StopActor command if it wants to fully stop } - stopActor(actorId: string, generation?: number) { + async stopActor(actorId: string, generation?: number) { const actor = this.#removeActor(actorId, generation); if (!actor) return; @@ -136,11 +136,14 @@ export class Runner { this.#tunnel.unregisterActor(actor); } - this.#sendActorStateUpdate(actorId, actor.generation, "stopped"); - - this.#config.onActorStop(actorId, actor.generation).catch((err) => { + // If onActorStop times out, Pegboard will handle this timeout with ACTOR_STOP_THRESHOLD_DURATION_MS + try { + await this.#config.onActorStop(actorId, actor.generation); + } catch (err) { console.error(`Error in onActorStop for actor ${actorId}:`, err); - }); + } + + this.#sendActorStateUpdate(actorId, actor.generation, "stopped"); } #stopAllActors() {