From 9d88491f2f5fd581aef1e82f391e653a887fea34 Mon Sep 17 00:00:00 2001 From: Nathan Flurry Date: Sat, 6 Sep 2025 23:44:02 -0700 Subject: [PATCH 1/8] chore(pegboard): consolidate to single subscriber per gateway --- Cargo.lock | 6 +- packages/common/universalpubsub/Cargo.toml | 2 +- .../src/driver/postgres/mod.rs | 106 ++- packages/core/guard/server/Cargo.toml | 1 + packages/core/guard/server/src/lib.rs | 7 +- packages/core/guard/server/src/routing/mod.rs | 17 +- .../server/src/routing/pegboard_gateway.rs | 26 +- .../server/src/routing/pegboard_tunnel.rs | 5 +- .../core/guard/server/src/shared_state.rs | 32 + packages/core/pegboard-gateway/Cargo.toml | 6 +- packages/core/pegboard-gateway/src/lib.rs | 427 ++++------ .../core/pegboard-gateway/src/shared_state.rs | 286 +++++++ packages/core/pegboard-runner-ws/src/lib.rs | 45 +- packages/core/pegboard-tunnel/Cargo.toml | 3 +- packages/core/pegboard-tunnel/src/lib.rs | 779 ++++-------------- .../core/pegboard-tunnel/tests/integration.rs | 13 +- .../infra/engine/tests/actors_lifecycle.rs | 9 +- .../pegboard/src/ops/runner/get_by_key.rs | 51 ++ .../services/pegboard/src/ops/runner/mod.rs | 1 + .../services/pegboard/src/pubsub_subjects.rs | 68 +- pnpm-lock.yaml | 9 + sdks/rust/runner-protocol/src/protocol.rs | 2 - sdks/rust/runner-protocol/src/versioned.rs | 2 - sdks/rust/tunnel-protocol/build.rs | 89 +- sdks/rust/tunnel-protocol/src/versioned.rs | 66 +- sdks/schemas/runner-protocol/v1.bare | 2 - sdks/schemas/tunnel-protocol/v1.bare | 66 +- sdks/typescript/runner-protocol/src/index.ts | 157 ++-- sdks/typescript/runner/package.json | 1 + sdks/typescript/runner/src/mod.ts | 91 +- sdks/typescript/runner/src/tunnel.ts | 774 ++++++++++------- .../runner/src/websocket-tunnel-adapter.ts | 4 +- sdks/typescript/tunnel-protocol/src/index.ts | 293 +++---- 33 files changed, 1783 insertions(+), 1663 deletions(-) create mode 100644 packages/core/guard/server/src/shared_state.rs create mode 100644 packages/core/pegboard-gateway/src/shared_state.rs create mode 100644 packages/services/pegboard/src/ops/runner/get_by_key.rs diff --git a/Cargo.lock b/Cargo.lock index bcf2da44ee..800d4ca464 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3270,9 +3270,11 @@ dependencies = [ "rivet-guard-core", "rivet-tunnel-protocol", "rivet-util", + "thiserror 1.0.69", "tokio", "tokio-tungstenite", "universalpubsub", + "versioned-data-util", ] [[package]] @@ -3304,6 +3306,7 @@ version = "0.0.1" dependencies = [ "anyhow", "async-trait", + "bytes", "futures", "gasoline", "http-body-util", @@ -4340,6 +4343,7 @@ dependencies = [ "tracing", "udb-util", "universaldb", + "universalpubsub", "url", "uuid", ] @@ -6329,7 +6333,6 @@ dependencies = [ "base64 0.22.1", "deadpool-postgres", "futures-util", - "moka", "rivet-config", "rivet-env", "rivet-error", @@ -6342,6 +6345,7 @@ dependencies = [ "tempfile", "tokio", "tokio-postgres", + "tokio-util", "tracing", "tracing-subscriber", "uuid", diff --git a/packages/common/universalpubsub/Cargo.toml b/packages/common/universalpubsub/Cargo.toml index 37028f9995..3cdb8509b6 100644 --- a/packages/common/universalpubsub/Cargo.toml +++ b/packages/common/universalpubsub/Cargo.toml @@ -12,7 +12,6 @@ async-trait.workspace = true base64.workspace = true deadpool-postgres.workspace = true futures-util.workspace = true -moka.workspace = true rivet-error.workspace = true rivet-ups-protocol.workspace = true serde_json.workspace = true @@ -21,6 +20,7 @@ serde.workspace = true sha2.workspace = true tokio-postgres.workspace = true tokio.workspace = true 
+tokio-util.workspace = true tracing.workspace = true uuid.workspace = true diff --git a/packages/common/universalpubsub/src/driver/postgres/mod.rs b/packages/common/universalpubsub/src/driver/postgres/mod.rs index 99cb703706..c2df9a36b4 100644 --- a/packages/common/universalpubsub/src/driver/postgres/mod.rs +++ b/packages/common/universalpubsub/src/driver/postgres/mod.rs @@ -1,5 +1,6 @@ +use std::collections::HashMap; use std::hash::{DefaultHasher, Hash, Hasher}; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use anyhow::*; use async_trait::async_trait; @@ -7,7 +8,6 @@ use base64::Engine; use base64::engine::general_purpose::STANDARD_NO_PAD as BASE64; use deadpool_postgres::{Config, ManagerConfig, Pool, PoolConfig, RecyclingMethod, Runtime}; use futures_util::future::poll_fn; -use moka::future::Cache; use tokio_postgres::{AsyncMessage, NoTls}; use tracing::Instrument; @@ -18,6 +18,15 @@ use crate::pubsub::DriverOutput; struct Subscription { // Channel to send messages to this subscription tx: tokio::sync::broadcast::Sender>, + // Cancellation token shared by all subscribers of this subject + token: tokio_util::sync::CancellationToken, +} + +impl Subscription { + fn new(tx: tokio::sync::broadcast::Sender>) -> Self { + let token = tokio_util::sync::CancellationToken::new(); + Self { tx, token } + } } /// > In the default configuration it must be shorter than 8000 bytes @@ -40,7 +49,7 @@ pub const POSTGRES_MAX_MESSAGE_SIZE: usize = pub struct PostgresDriver { pool: Arc, client: Arc, - subscriptions: Cache, + subscriptions: Arc>>, } impl PostgresDriver { @@ -65,8 +74,8 @@ impl PostgresDriver { .context("failed to create postgres pool")?; tracing::debug!("postgres pool created successfully"); - let subscriptions: Cache = - Cache::builder().initial_capacity(5).build(); + let subscriptions: Arc>> = + Arc::new(Mutex::new(HashMap::new())); let subscriptions2 = subscriptions.clone(); let (client, mut conn) = tokio_postgres::connect(&conn_str, tokio_postgres::NoTls).await?; @@ -75,7 +84,9 @@ impl PostgresDriver { loop { match poll_fn(|cx| conn.poll_message(cx)).await { Some(std::result::Result::Ok(AsyncMessage::Notification(note))) => { - if let Some(sub) = subscriptions2.get(note.channel()).await { + if let Some(sub) = + subscriptions2.lock().unwrap().get(note.channel()).cloned() + { let bytes = match BASE64.decode(note.payload()) { std::result::Result::Ok(b) => b, std::result::Result::Err(err) => { @@ -121,7 +132,7 @@ impl PostgresDriver { #[async_trait] impl PubSubDriver for PostgresDriver { async fn subscribe(&self, subject: &str) -> Result { - // TODO: To match NATS implementation, LIST must be pipelined (i.e. wait for the command + // TODO: To match NATS implementation, LISTEN must be pipelined (i.e. wait for the command // to reach the server, but not wait for it to respond). 
However, this has to ensure that // NOTIFY & LISTEN are called on the same connection (not diff connections in a pool) or // else there will be race conditions where messages might be published before @@ -135,33 +146,57 @@ impl PubSubDriver for PostgresDriver { let hashed = self.hash_subject(subject); // Check if we already have a subscription for this channel - let rx = if let Some(existing_sub) = self.subscriptions.get(&hashed).await { - // Reuse the existing broadcast channel - existing_sub.tx.subscribe() - } else { - // Create a new broadcast channel for this subject - let (tx, rx) = tokio::sync::broadcast::channel(1024); - let subscription = Subscription { tx: tx.clone() }; - - // Register subscription - self.subscriptions - .insert(hashed.clone(), subscription) - .await; - - // Execute LISTEN command on the async client (for receiving notifications) - // This only needs to be done once per channel - let span = tracing::trace_span!("pg_listen"); - self.client - .execute(&format!("LISTEN \"{hashed}\""), &[]) - .instrument(span) - .await?; - - rx - }; + let (rx, drop_guard) = + if let Some(existing_sub) = self.subscriptions.lock().unwrap().get(&hashed).cloned() { + // Reuse the existing broadcast channel + let rx = existing_sub.tx.subscribe(); + let drop_guard = existing_sub.token.clone().drop_guard(); + (rx, drop_guard) + } else { + // Create a new broadcast channel for this subject + let (tx, rx) = tokio::sync::broadcast::channel(1024); + let subscription = Subscription::new(tx.clone()); + + // Register subscription + self.subscriptions + .lock() + .unwrap() + .insert(hashed.clone(), subscription.clone()); + + // Execute LISTEN command on the async client (for receiving notifications) + // This only needs to be done once per channel + let span = tracing::trace_span!("pg_listen"); + self.client + .execute(&format!("LISTEN \"{hashed}\""), &[]) + .instrument(span) + .await?; + + // Spawn a single cleanup task for this subscription waiting on its token + let driver = self.clone(); + let hashed_clone = hashed.clone(); + let tx_clone = tx.clone(); + let token_clone = subscription.token.clone(); + tokio::spawn(async move { + token_clone.cancelled().await; + if tx_clone.receiver_count() == 0 { + let sql = format!("UNLISTEN \"{}\"", hashed_clone); + if let Err(err) = driver.client.execute(sql.as_str(), &[]).await { + tracing::warn!(?err, %hashed_clone, "failed to UNLISTEN channel"); + } else { + tracing::trace!(%hashed_clone, "unlistened channel"); + } + driver.subscriptions.lock().unwrap().remove(&hashed_clone); + } + }); + + let drop_guard = subscription.token.clone().drop_guard(); + (rx, drop_guard) + }; Ok(Box::new(PostgresSubscriber { subject: subject.to_string(), - rx, + rx: Some(rx), + _drop_guard: drop_guard, })) } @@ -191,13 +226,18 @@ impl PubSubDriver for PostgresDriver { pub struct PostgresSubscriber { subject: String, - rx: tokio::sync::broadcast::Receiver>, + rx: Option>>, + _drop_guard: tokio_util::sync::DropGuard, } #[async_trait] impl SubscriberDriver for PostgresSubscriber { async fn next(&mut self) -> Result { - match self.rx.recv().await { + let rx = match self.rx.as_mut() { + Some(rx) => rx, + None => return Ok(DriverOutput::Unsubscribed), + }; + match rx.recv().await { std::result::Result::Ok(payload) => Ok(DriverOutput::Message { subject: self.subject.clone(), payload, diff --git a/packages/core/guard/server/Cargo.toml b/packages/core/guard/server/Cargo.toml index fbca49ce9f..ef8695c6c2 100644 --- a/packages/core/guard/server/Cargo.toml +++ 
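The rewritten subscribe path drops the moka cache for a plain mutex-guarded HashMap of per-channel broadcast senders, and ties cleanup to a shared CancellationToken: every subscriber holds a DropGuard, and a per-channel task issues UNLISTEN once a guard drops and no receivers remain. A minimal runnable sketch of that pattern, assuming only tokio and tokio-util (the UNLISTEN body is stubbed):

    use std::time::Duration;
    use tokio_util::sync::CancellationToken;

    #[tokio::main]
    async fn main() {
        let (tx, rx) = tokio::sync::broadcast::channel::<Vec<u8>>(1024);
        let token = CancellationToken::new();

        // Each subscriber pairs a broadcast receiver with a guard; dropping
        // the guard cancels the shared token.
        let guard = token.clone().drop_guard();

        // One cleanup task per channel, spawned alongside the LISTEN command.
        let cleanup_token = token.clone();
        let cleanup_tx = tx.clone();
        tokio::spawn(async move {
            cleanup_token.cancelled().await;
            if cleanup_tx.receiver_count() == 0 {
                // This is where the driver runs `UNLISTEN "<channel>"` and
                // removes the subscriptions map entry.
                println!("last receiver gone, unlistening");
            }
        });

        drop(rx);
        drop(guard); // token cancelled -> cleanup task wakes, checks receiver_count
        tokio::time::sleep(Duration::from_millis(50)).await; // let the task run
    }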
b/packages/core/guard/server/Cargo.toml @@ -20,6 +20,7 @@ hyper-tungstenite.workspace = true tower.workspace = true udb-util.workspace = true universaldb.workspace = true +universalpubsub.workspace = true futures.workspace = true # TODO: Make this use workspace version hyper = "1.6.0" diff --git a/packages/core/guard/server/src/lib.rs b/packages/core/guard/server/src/lib.rs index 7f3c6e4626..f7533372e3 100644 --- a/packages/core/guard/server/src/lib.rs +++ b/packages/core/guard/server/src/lib.rs @@ -5,6 +5,7 @@ pub mod cache; pub mod errors; pub mod middleware; pub mod routing; +pub mod shared_state; pub mod tls; #[tracing::instrument(skip_all)] @@ -26,8 +27,12 @@ pub async fn start(config: rivet_config::Config, pools: rivet_pools::Pools) -> R tracing::warn!("crypto provider already installed in this process"); } + // Share shared context + let shared_state = shared_state::SharedState::new(ctx.ups()?); + shared_state.start().await?; + // Create handlers - let routing_fn = routing::create_routing_function(ctx.clone()); + let routing_fn = routing::create_routing_function(ctx.clone(), shared_state.clone()); let cache_key_fn = cache::create_cache_key_function(ctx.clone()); let middleware_fn = middleware::create_middleware_function(ctx.clone()); let cert_resolver = tls::create_cert_resolver(&ctx).await?; diff --git a/packages/core/guard/server/src/routing/mod.rs b/packages/core/guard/server/src/routing/mod.rs index b8b0d757f6..524d63075d 100644 --- a/packages/core/guard/server/src/routing/mod.rs +++ b/packages/core/guard/server/src/routing/mod.rs @@ -5,7 +5,7 @@ use gas::prelude::*; use hyper::header::HeaderName; use rivet_guard_core::RoutingFn; -use crate::errors; +use crate::{errors, shared_state::SharedState}; mod api_peer; mod api_public; @@ -17,13 +17,14 @@ pub(crate) const X_RIVET_TARGET: HeaderName = HeaderName::from_static("x-rivet-t /// Creates the main routing function that handles all incoming requests #[tracing::instrument(skip_all)] -pub fn create_routing_function(ctx: StandaloneCtx) -> RoutingFn { +pub fn create_routing_function(ctx: StandaloneCtx, shared_state: SharedState) -> RoutingFn { Arc::new( move |hostname: &str, path: &str, port_type: rivet_guard_core::proxy_service::PortType, headers: &hyper::HeaderMap| { let ctx = ctx.clone(); + let shared_state = shared_state.clone(); Box::pin( async move { @@ -41,9 +42,15 @@ pub fn create_routing_function(ctx: StandaloneCtx) -> RoutingFn { return Ok(routing_output); } - if let Some(routing_output) = - pegboard_gateway::route_request(&ctx, target, host, path, headers) - .await? + if let Some(routing_output) = pegboard_gateway::route_request( + &ctx, + &shared_state, + target, + host, + path, + headers, + ) + .await? 
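// Guard now builds a single SharedState at startup and threads it through the
// routing function, so every routed request reuses one gateway subscriber.
// The wiring, condensed from the lib.rs hunk above:
//
//     let shared_state = shared_state::SharedState::new(ctx.ups()?);
//     shared_state.start().await?;
//     let routing_fn = routing::create_routing_function(ctx.clone(), shared_state.clone());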
{ return Ok(routing_output); } diff --git a/packages/core/guard/server/src/routing/pegboard_gateway.rs b/packages/core/guard/server/src/routing/pegboard_gateway.rs index fe8a5b2fe5..58088a1ac4 100644 --- a/packages/core/guard/server/src/routing/pegboard_gateway.rs +++ b/packages/core/guard/server/src/routing/pegboard_gateway.rs @@ -6,7 +6,7 @@ use hyper::header::HeaderName; use rivet_guard_core::proxy_service::{RouteConfig, RouteTarget, RoutingOutput, RoutingTimeout}; use udb_util::{SERIALIZABLE, TxnExt}; -use crate::errors; +use crate::{errors, shared_state::SharedState}; const ACTOR_READY_TIMEOUT: Duration = Duration::from_secs(10); pub const X_RIVET_ACTOR: HeaderName = HeaderName::from_static("x-rivet-actor"); @@ -16,6 +16,7 @@ pub const X_RIVET_PORT: HeaderName = HeaderName::from_static("x-rivet-port"); #[tracing::instrument(skip_all)] pub async fn route_request( ctx: &StandaloneCtx, + shared_state: &SharedState, target: &str, _host: &str, path: &str, @@ -73,7 +74,7 @@ pub async fn route_request( let port_name = port_name.to_str()?; // Lookup actor - find_actor(ctx, actor_id, port_name, path).await + find_actor(ctx, shared_state, actor_id, port_name, path).await } struct FoundActor { @@ -86,6 +87,7 @@ struct FoundActor { #[tracing::instrument(skip_all, fields(%actor_id, %port_name, %path))] async fn find_actor( ctx: &StandaloneCtx, + shared_state: &SharedState, actor_id: Id, port_name: &str, path: &str, @@ -158,10 +160,10 @@ async fn find_actor( actor_ids: vec![actor_id], }); let res = tokio::time::timeout(Duration::from_secs(5), get_runner_fut).await??; - let runner_info = res.actors.into_iter().next().filter(|x| x.is_connectable); + let actor = res.actors.into_iter().next().filter(|x| x.is_connectable); - let runner_id = if let Some(runner_info) = runner_info { - runner_info.runner_id + let runner_id = if let Some(actor) = actor { + actor.runner_id } else { tracing::info!(?actor_id, "waiting for actor to become ready"); @@ -185,11 +187,23 @@ async fn find_actor( tracing::debug!(?actor_id, ?runner_id, "actor ready"); + // Get runner key from runner_id + let runner_key = ctx + .udb()? + .run(|tx, _mc| async move { + let txs = tx.subspace(pegboard::keys::subspace()); + let key_key = pegboard::keys::runner::KeyKey::new(runner_id); + txs.read_opt(&key_key, SERIALIZABLE).await + }) + .await? 
+ .context("runner key not found")?; + // Return pegboard-gateway instance let gateway = pegboard_gateway::PegboardGateway::new( ctx.clone(), + shared_state.pegboard_gateway.clone(), actor_id, - runner_id, + runner_key, port_name.to_string(), ); Ok(Some(RoutingOutput::CustomServe(std::sync::Arc::new( diff --git a/packages/core/guard/server/src/routing/pegboard_tunnel.rs b/packages/core/guard/server/src/routing/pegboard_tunnel.rs index a3b6b0eaad..7d69150a03 100644 --- a/packages/core/guard/server/src/routing/pegboard_tunnel.rs +++ b/packages/core/guard/server/src/routing/pegboard_tunnel.rs @@ -12,13 +12,10 @@ pub async fn route_request( _host: &str, _path: &str, ) -> Result> { - // Check target if target != "tunnel" { return Ok(None); } - // Create pegboard-tunnel service instance - let tunnel = pegboard_tunnel::PegboardTunnelCustomServe::new(ctx.clone()).await?; - + let tunnel = pegboard_tunnel::PegboardTunnelCustomServe::new(ctx.clone()); Ok(Some(RoutingOutput::CustomServe(Arc::new(tunnel)))) } diff --git a/packages/core/guard/server/src/shared_state.rs b/packages/core/guard/server/src/shared_state.rs new file mode 100644 index 0000000000..71bc25939e --- /dev/null +++ b/packages/core/guard/server/src/shared_state.rs @@ -0,0 +1,32 @@ +use anyhow::*; +use gas::prelude::*; +use std::{ops::Deref, sync::Arc}; +use universalpubsub::PubSub; + +#[derive(Clone)] +pub struct SharedState(Arc); + +impl SharedState { + pub fn new(pubsub: PubSub) -> SharedState { + SharedState(Arc::new(SharedStateInner { + pegboard_gateway: pegboard_gateway::shared_state::SharedState::new(pubsub), + })) + } + + pub async fn start(&self) -> Result<()> { + self.pegboard_gateway.start().await?; + Ok(()) + } +} + +impl Deref for SharedState { + type Target = SharedStateInner; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +pub struct SharedStateInner { + pub pegboard_gateway: pegboard_gateway::shared_state::SharedState, +} diff --git a/packages/core/pegboard-gateway/Cargo.toml b/packages/core/pegboard-gateway/Cargo.toml index aeeb32bc53..6a3ddc446a 100644 --- a/packages/core/pegboard-gateway/Cargo.toml +++ b/packages/core/pegboard-gateway/Cargo.toml @@ -14,12 +14,14 @@ gas.workspace = true http-body-util.workspace = true hyper = "1.6" hyper-tungstenite.workspace = true -pegboard = { path = "../../services/pegboard" } +pegboard.workspace = true rand.workspace = true rivet-error.workspace = true -rivet-guard-core = { path = "../guard/core" } +rivet-guard-core.workspace = true rivet-tunnel-protocol.workspace = true rivet-util.workspace = true tokio-tungstenite.workspace = true tokio.workspace = true universalpubsub.workspace = true +versioned-data-util.workspace = true +thiserror.workspace = true diff --git a/packages/core/pegboard-gateway/src/lib.rs b/packages/core/pegboard-gateway/src/lib.rs index 09a960fe88..8d752eba42 100644 --- a/packages/core/pegboard-gateway/src/lib.rs +++ b/packages/core/pegboard-gateway/src/lib.rs @@ -4,56 +4,48 @@ use bytes::Bytes; use futures_util::{SinkExt, StreamExt}; use gas::prelude::*; use http_body_util::{BodyExt, Full}; -use hyper::{Request, Response, StatusCode, body::Incoming as BodyIncoming}; +use hyper::{Request, Response, StatusCode}; use hyper_tungstenite::HyperWebsocket; -use pegboard::pubsub_subjects::{ - TunnelHttpResponseSubject, TunnelHttpRunnerSubject, TunnelHttpWebSocketSubject, -}; -use rivet_error::*; use rivet_guard_core::{ custom_serve::CustomServeTrait, proxy_service::{ResponseBody, X_RIVET_ERROR}, request_context::RequestContext, }; use rivet_tunnel_protocol::{ 
- MessageBody, StreamFinishReason, ToServerRequestFinish, ToServerRequestStart, - ToServerWebSocketClose, ToServerWebSocketMessage, ToServerWebSocketOpen, TunnelMessage, - versioned, + MessageKind, ToServerRequestStart, ToServerWebSocketClose, ToServerWebSocketMessage, + ToServerWebSocketOpen, }; use rivet_util::serde::HashableMap; -use std::result::Result::Ok as ResultOk; -use std::{ - collections::HashMap, - sync::{ - Arc, - atomic::{AtomicU64, Ordering}, - }, - time::Duration, -}; -use tokio::{ - sync::{Mutex, oneshot}, - time::timeout, -}; +use std::time::Duration; use tokio_tungstenite::tungstenite::Message; -use universalpubsub::NextOutput; + +use crate::shared_state::{SharedState, TunnelMessageData}; + +pub mod shared_state; const UPS_REQ_TIMEOUT: Duration = Duration::from_secs(2); pub struct PegboardGateway { ctx: StandaloneCtx, - request_counter: AtomicU64, + shared_state: SharedState, actor_id: Id, - runner_id: Id, + runner_key: String, port_name: String, } impl PegboardGateway { - pub fn new(ctx: StandaloneCtx, actor_id: Id, runner_id: Id, port_name: String) -> Self { + pub fn new( + ctx: StandaloneCtx, + shared_state: SharedState, + actor_id: Id, + runner_key: String, + port_name: String, + ) -> Self { Self { ctx, - request_counter: AtomicU64::new(0), + shared_state, actor_id, - runner_id, + runner_key, port_name, } } @@ -70,7 +62,7 @@ impl CustomServeTrait for PegboardGateway { match res { Result::Ok(x) => Ok(x), Err(err) => { - if is_tunnel_closed_error(&err) { + if is_tunnel_service_unavailable(&err) { // This will force the request to be retried with a new tunnel Ok(Response::builder() .status(StatusCode::SERVICE_UNAVAILABLE) @@ -96,7 +88,7 @@ impl CustomServeTrait for PegboardGateway { { Result::Ok(()) => std::result::Result::<(), (HyperWebsocket, anyhow::Error)>::Ok(()), Result::Err((client_ws, err)) => { - if is_tunnel_closed_error(&err) { + if is_tunnel_service_unavailable(&err) { Err(( client_ws, rivet_guard_core::errors::WebSocketServiceUnavailable.build(), @@ -124,13 +116,10 @@ impl PegboardGateway { .map_err(|_| anyhow!("invalid x-rivet-actor header"))? .to_string(); - // Generate request ID using atomic counter - let request_id = self.request_counter.fetch_add(1, Ordering::SeqCst); - // Extract request parts let mut headers = HashableMap::new(); for (name, value) in req.headers() { - if let ResultOk(value_str) = value.to_str() { + if let Result::Ok(value_str) = value.to_str() { headers.insert(name.to_string(), value_str.to_string()); } } @@ -149,9 +138,21 @@ impl PegboardGateway { .map_err(|e| anyhow!("failed to read body: {}", e))? 
.to_bytes(); - // Create tunnel message - let request_start = ToServerRequestStart { - request_id, + // Build subject to publish to + let tunnel_subject = pegboard::pubsub_subjects::TunnelRunnerReceiverSubject::new( + &self.runner_key, + &self.port_name, + ) + .to_string(); + + // Start listening for request responses + let (request_id, mut msg_rx) = self + .shared_state + .start_in_flight_request(tunnel_subject) + .await; + + // Start request + let message = MessageKind::ToServerRequestStart(ToServerRequestStart { actor_id: actor_id.clone(), method, path, @@ -162,117 +163,33 @@ impl PegboardGateway { Some(body_bytes.to_vec()) }, stream: false, - }; - - let message = TunnelMessage { - body: MessageBody::ToServerRequestStart(request_start), - }; + }); + self.shared_state.send_message(request_id, message).await?; + + // Wait for response + tracing::info!("starting response handler task"); + let response_start = loop { + let Some(msg) = msg_rx.recv().await else { + tracing::warn!("received no message response"); + return Err(RequestError::ServiceUnavailable.into()); + }; - // Serialize message - let serialized = versioned::TunnelMessage::serialize(versioned::TunnelMessage::V1(message)) - .map_err(|e| anyhow!("failed to serialize message: {}", e))?; - - // Build pubsub topic - let tunnel_subject = TunnelHttpRunnerSubject::new(self.runner_id, &self.port_name); - let topic = tunnel_subject.to_string(); - - tracing::info!( - %topic, - ?self.runner_id, - %self.port_name, - ?request_id, - "publishing request to pubsub" - ); - - // Create response channel - let (response_tx, response_rx) = oneshot::channel(); - let response_map = Arc::new(Mutex::new(HashMap::new())); - response_map.lock().await.insert(request_id, response_tx); - - // Subscribe to response topic - let response_subject = - TunnelHttpResponseSubject::new(self.runner_id, &self.port_name, request_id); - let response_topic = response_subject.to_string(); - - tracing::info!( - ?response_topic, - ?request_id, - "subscribing to response topic" - ); - - let mut subscriber = self.ctx.ups()?.subscribe(&response_topic).await?; - - // Spawn task to handle response - let response_map_clone = response_map.clone(); - tokio::spawn(async move { - tracing::info!("starting response handler task"); - while let ResultOk(NextOutput::Message(msg)) = subscriber.next().await { - // Ack message - match msg.reply(&[]).await { - Result::Ok(_) => {} - Err(err) => { - tracing::warn!(?err, "failed to ack gateway request response message") + match msg { + TunnelMessageData::Message(msg) => match msg { + MessageKind::ToClientResponseStart(response_start) => { + break response_start; } - }; - - tracing::info!( - payload_len = msg.payload.len(), - "received response from pubsub" - ); - if let ResultOk(tunnel_msg) = versioned::TunnelMessage::deserialize(&msg.payload) { - match tunnel_msg.body { - MessageBody::ToClientResponseStart(response_start) => { - tracing::info!(request_id = ?response_start.request_id, status = response_start.status, "received response from tunnel"); - if let Some(tx) = response_map_clone - .lock() - .await - .remove(&response_start.request_id) - { - tracing::info!(request_id = ?response_start.request_id, "sending response to handler"); - let _ = tx.send(response_start); - } else { - tracing::warn!(request_id = ?response_start.request_id, "no handler found for response"); - } - } - _ => { - tracing::warn!("received non-response message from pubsub"); - } + _ => { + tracing::warn!("received non-response message from pubsub"); } - } else { - 
tracing::error!("failed to deserialize response from pubsub"); + }, + TunnelMessageData::Timeout => { + tracing::warn!("tunnel message timeout"); + return Err(RequestError::ServiceUnavailable.into()); } } - tracing::info!("response handler task ended"); - }); - - // Publish request - self.ctx - .ups()? - .request_with_timeout(&topic, &serialized, UPS_REQ_TIMEOUT) - .await - .map_err(|e| anyhow!("failed to publish request: {}", e))?; - - // Send finish message - let finish_message = TunnelMessage { - body: MessageBody::ToServerRequestFinish(ToServerRequestFinish { - request_id, - reason: StreamFinishReason::Complete, - }), - }; - let finish_serialized = - versioned::TunnelMessage::serialize(versioned::TunnelMessage::V1(finish_message)) - .map_err(|e| anyhow!("failed to serialize finish message: {}", e))?; - self.ctx - .ups()? - .request_with_timeout(&topic, &finish_serialized, UPS_REQ_TIMEOUT) - .await - .map_err(|e| anyhow!("failed to publish finish message: {}", e))?; - - // Wait for response with timeout - let response_start = match timeout(Duration::from_secs(30), response_rx).await { - ResultOk(ResultOk(response)) => response, - _ => return Err(anyhow!("request timed out")), }; + tracing::info!("response handler task ended"); // Build HTTP response let mut response_builder = @@ -309,62 +226,69 @@ impl PegboardGateway { Err(err) => return Err((client_ws, err)), }; - // Generate WebSocket ID using atomic counter - let websocket_id = self.request_counter.fetch_add(1, Ordering::SeqCst); - // Extract headers let mut request_headers = HashableMap::new(); for (name, value) in headers { - if let ResultOk(value_str) = value.to_str() { + if let Result::Ok(value_str) = value.to_str() { request_headers.insert(name.to_string(), value_str.to_string()); } } - let ups = match self.ctx.ups() { - Result::Ok(u) => u, - Err(err) => return Err((client_ws, err.into())), - }; - - // Subscribe to messages from server before informing server that a client websocket is connecting to - // prevent race conditions. 
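// The per-request pubsub subscription that used to live here is replaced by
// the shared gateway subscriber. Callers now register an in-flight request,
// send frames by request id, and read replies off an mpsc channel. A
// condensed sketch of the calling pattern (error handling elided; see
// handle_request above and handle_websocket below for the real flow):
//
//     let subject = TunnelRunnerReceiverSubject::new(&runner_key, &port_name).to_string();
//     let (request_id, mut msg_rx) = shared_state.start_in_flight_request(subject).await;
//     shared_state.send_message(request_id, MessageKind::ToServerWebSocketOpen(open)).await?;
//     while let Some(msg) = msg_rx.recv().await {
//         match msg {
//             TunnelMessageData::Message(kind) => { /* handle frame */ }
//             TunnelMessageData::Timeout => break, // runner stopped acking
//         }
//     }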
- let ws_subject = - TunnelHttpWebSocketSubject::new(self.runner_id, &self.port_name, websocket_id); - let response_topic = ws_subject.to_string(); - let mut subscriber = match ups.subscribe(&response_topic).await { - Result::Ok(sub) => sub, - Err(err) => return Err((client_ws, err.into())), - }; + // Build subject to publish to + let tunnel_subject = pegboard::pubsub_subjects::TunnelRunnerReceiverSubject::new( + &self.runner_key, + &self.port_name, + ) + .to_string(); - // Build pubsub topic - let tunnel_subject = TunnelHttpRunnerSubject::new(self.runner_id, &self.port_name); - let topic = tunnel_subject.to_string(); + // Start listening for WebSocket messages + let (request_id, mut msg_rx) = self + .shared_state + .start_in_flight_request(tunnel_subject.clone()) + .await; // Send WebSocket open message - let open_message = TunnelMessage { - body: MessageBody::ToServerWebSocketOpen(ToServerWebSocketOpen { - actor_id: actor_id.clone(), - web_socket_id: websocket_id, - path: path.to_string(), - headers: request_headers, - }), - }; - - let serialized = - match versioned::TunnelMessage::serialize(versioned::TunnelMessage::V1(open_message)) { - Result::Ok(s) => s, - Err(e) => { - return Err(( - client_ws, - anyhow!("failed to serialize websocket open: {}", e), - )); - } - }; + let open_message = MessageKind::ToServerWebSocketOpen(ToServerWebSocketOpen { + actor_id: actor_id.clone(), + path: path.to_string(), + headers: request_headers, + }); - if let Err(err) = ups - .request_with_timeout(&topic, &serialized, UPS_REQ_TIMEOUT) + if let Err(err) = self + .shared_state + .send_message(request_id, open_message) .await { - return Err((client_ws, err.into())); + return Err((client_ws, err)); + } + + // Wait for WebSocket open acknowledgment + let open_ack_received = loop { + let Some(msg) = msg_rx.recv().await else { + tracing::warn!("received no websocket open response"); + return Err((client_ws, RequestError::ServiceUnavailable.into())); + }; + + match msg { + TunnelMessageData::Message(MessageKind::ToClientWebSocketOpen) => { + break true; + } + TunnelMessageData::Message(MessageKind::ToClientWebSocketClose(close)) => { + tracing::info!(?close, "websocket closed before opening"); + return Err((client_ws, RequestError::ServiceUnavailable.into())); + } + TunnelMessageData::Timeout => { + tracing::warn!("websocket open timeout"); + return Err((client_ws, RequestError::ServiceUnavailable.into())); + } + _ => { + tracing::warn!("received unexpected message while waiting for websocket open"); + } + } + }; + + if !open_ack_received { + return Err((client_ws, anyhow!("failed to open websocket"))); } // Accept the WebSocket @@ -379,33 +303,30 @@ impl PegboardGateway { let (mut ws_sink, mut ws_stream) = ws_stream.split(); // Spawn task to forward messages from server to client + let mut msg_rx_for_task = msg_rx; tokio::spawn(async move { - while let ResultOk(NextOutput::Message(msg)) = subscriber.next().await { - // Ack message - match msg.reply(&[]).await { - Result::Ok(_) => {} - Err(err) => { - tracing::warn!(?err, "failed to ack gateway websocket message") - } - }; - - if let ResultOk(tunnel_msg) = versioned::TunnelMessage::deserialize(&msg.payload) { - match tunnel_msg.body { - MessageBody::ToClientWebSocketMessage(ws_msg) => { - if ws_msg.web_socket_id == websocket_id { - let msg = if ws_msg.binary { - Message::Binary(ws_msg.data.into()) - } else { - Message::Text( - String::from_utf8_lossy(&ws_msg.data).into_owned().into(), - ) - }; - let _ = ws_sink.send(msg).await; - } + while let Some(msg) = 
msg_rx_for_task.recv().await { + match msg { + TunnelMessageData::Message(MessageKind::ToClientWebSocketMessage(ws_msg)) => { + let msg = if ws_msg.binary { + Message::Binary(ws_msg.data.into()) + } else { + Message::Text(String::from_utf8_lossy(&ws_msg.data).into_owned().into()) + }; + if let Err(e) = ws_sink.send(msg).await { + tracing::warn!(?e, "failed to send websocket message to client"); + break; } - MessageBody::ToClientWebSocketClose(_) => break, - _ => {} } + TunnelMessageData::Message(MessageKind::ToClientWebSocketClose(close)) => { + tracing::info!(?close, "server closed websocket"); + break; + } + TunnelMessageData::Timeout => { + tracing::warn!("websocket message timeout"); + break; + } + _ => {} } } }); @@ -414,25 +335,14 @@ impl PegboardGateway { let mut close_reason = None; while let Some(msg) = ws_stream.next().await { match msg { - ResultOk(Message::Binary(data)) => { - let ws_message = TunnelMessage { - body: MessageBody::ToServerWebSocketMessage(ToServerWebSocketMessage { - web_socket_id: websocket_id, + Result::Ok(Message::Binary(data)) => { + let ws_message = + MessageKind::ToServerWebSocketMessage(ToServerWebSocketMessage { data: data.into(), binary: true, - }), - }; - let serialized = match versioned::TunnelMessage::serialize( - versioned::TunnelMessage::V1(ws_message), - ) { - Result::Ok(s) => s, - Err(_) => break, - }; - if let Err(err) = ups - .request_with_timeout(&topic, &serialized, UPS_REQ_TIMEOUT) - .await - { - if is_tunnel_closed_error(&err) { + }); + if let Err(err) = self.shared_state.send_message(request_id, ws_message).await { + if is_tunnel_service_unavailable(&err) { tracing::warn!("tunnel closed sending binary message"); close_reason = Some("Tunnel closed".to_string()); break; @@ -441,25 +351,14 @@ impl PegboardGateway { } } } - ResultOk(Message::Text(text)) => { - let ws_message = TunnelMessage { - body: MessageBody::ToServerWebSocketMessage(ToServerWebSocketMessage { - web_socket_id: websocket_id, + Result::Ok(Message::Text(text)) => { + let ws_message = + MessageKind::ToServerWebSocketMessage(ToServerWebSocketMessage { data: text.as_bytes().to_vec(), binary: false, - }), - }; - let serialized = match versioned::TunnelMessage::serialize( - versioned::TunnelMessage::V1(ws_message), - ) { - Result::Ok(s) => s, - Err(_) => break, - }; - if let Err(err) = ups - .request_with_timeout(&topic, &serialized, UPS_REQ_TIMEOUT) - .await - { - if is_tunnel_closed_error(&err) { + }); + if let Err(err) = self.shared_state.send_message(request_id, ws_message).await { + if is_tunnel_service_unavailable(&err) { tracing::warn!("tunnel closed sending text message"); close_reason = Some("Tunnel closed".to_string()); break; @@ -468,32 +367,23 @@ impl PegboardGateway { } } } - ResultOk(Message::Close(_)) | Err(_) => break, + Result::Ok(Message::Close(_)) | Err(_) => break, _ => {} } } // Send WebSocket close message - let close_message = TunnelMessage { - body: MessageBody::ToServerWebSocketClose(ToServerWebSocketClose { - web_socket_id: websocket_id, - code: None, - reason: close_reason, - }), - }; - - let serialized = match versioned::TunnelMessage::serialize(versioned::TunnelMessage::V1( - close_message, - )) { - Result::Ok(s) => s, - Err(_) => Vec::new(), - }; + let close_message = MessageKind::ToServerWebSocketClose(ToServerWebSocketClose { + code: None, + reason: close_reason, + }); - if let Err(err) = ups - .request_with_timeout(&topic, &serialized, UPS_REQ_TIMEOUT) + if let Err(err) = self + .shared_state + .send_message(request_id, close_message) .await { - if 
is_tunnel_closed_error(&err) { + if is_tunnel_service_unavailable(&err) { tracing::warn!("tunnel closed sending close message"); } else { tracing::error!(?err, "error sending close message"); @@ -504,14 +394,13 @@ impl PegboardGateway { } } +#[derive(thiserror::Error, Debug)] +enum RequestError { + #[error("service unavailable")] + ServiceUnavailable, +} + /// Determines if the tunnel is closed by if the UPS service is no longer responding. -fn is_tunnel_closed_error(err: &anyhow::Error) -> bool { - if let Some(err) = err.chain().find_map(|x| x.downcast_ref::()) - && err.group() == "ups" - && err.code() == "request_timeout" - { - true - } else { - false - } +fn is_tunnel_service_unavailable(err: &anyhow::Error) -> bool { + err.chain().any(|x| x.is::()) } diff --git a/packages/core/pegboard-gateway/src/shared_state.rs b/packages/core/pegboard-gateway/src/shared_state.rs new file mode 100644 index 0000000000..75b485679b --- /dev/null +++ b/packages/core/pegboard-gateway/src/shared_state.rs @@ -0,0 +1,286 @@ +use anyhow::*; +use gas::prelude::*; +use rivet_tunnel_protocol::{ + MessageId, MessageKind, PROTOCOL_VERSION, PubSubMessage, RequestId, versioned, +}; +use std::{ + collections::HashMap, + ops::Deref, + sync::Arc, + time::{Duration, Instant}, +}; +use tokio::sync::{Mutex, mpsc}; +use universalpubsub::{NextOutput, PubSub, PublishOpts, Subscriber}; +use versioned_data_util::OwnedVersionedData as _; + +const GC_INTERVAL: Duration = Duration::from_secs(60); +const MESSAGE_ACK_TIMEOUT: Duration = Duration::from_secs(5); + +struct InFlightRequest { + /// UPS subject to send messages to for this request. + receiver_subject: String, + /// Sender for incoming messages to this request. + msg_tx: mpsc::Sender, + /// True once first message for this request has been sent (so runner learned reply_to). 
+ opened: bool, +} + +struct PendingMessage { + request_id: RequestId, + send_instant: Instant, +} + +pub enum TunnelMessageData { + Message(MessageKind), + Timeout, +} + +pub struct SharedStateInner { + ups: PubSub, + receiver_subject: String, + requests_in_flight: Mutex>, + pending_messages: Mutex>, +} + +#[derive(Clone)] +pub struct SharedState(Arc); + +impl SharedState { + pub fn new(ups: PubSub) -> Self { + let gateway_id = Uuid::new_v4(); + let receiver_subject = + pegboard::pubsub_subjects::TunnelGatewayReceiverSubject::new(gateway_id).to_string(); + + Self(Arc::new(SharedStateInner { + ups, + receiver_subject, + requests_in_flight: Mutex::new(HashMap::new()), + pending_messages: Mutex::new(HashMap::new()), + })) + } + + pub async fn start(&self) -> Result<()> { + let sub = self.ups.subscribe(&self.receiver_subject).await?; + + let self_clone = self.clone(); + tokio::spawn(async move { self_clone.receiver(sub).await }); + + let self_clone = self.clone(); + tokio::spawn(async move { self_clone.gc().await }); + + Ok(()) + } + + pub async fn send_message( + &self, + request_id: RequestId, + message_kind: MessageKind, + ) -> Result<()> { + let message_id = Uuid::new_v4().as_bytes().clone(); + + // Get subject and whether this is the first message for this request + let (tunnel_receiver_subject, include_reply_to) = { + let mut requests_in_flight = self.requests_in_flight.lock().await; + if let Some(req) = requests_in_flight.get_mut(&request_id) { + let receiver_subject = req.receiver_subject.clone(); + let include_reply_to = !req.opened; + if include_reply_to { + // Mark as opened so subsequent messages skip reply_to + req.opened = true; + } + (receiver_subject, include_reply_to) + } else { + bail!("request not in flight") + } + }; + + // Save pending message + { + let mut pending_messages = self.pending_messages.lock().await; + pending_messages.insert( + message_id, + PendingMessage { + request_id, + send_instant: Instant::now(), + }, + ); + } + + // Send message + let message = PubSubMessage { + request_id, + message_id, + // Only send reply to subject on the first message for this request. This reduces + // overhead of subsequent messages. 
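// Every frame travels in a PubSubMessage envelope keyed by (request_id,
// message_id); the receiving side acks each message_id, and send_message
// keeps a PendingMessage entry until that ack arrives (or the GC times it
// out after MESSAGE_ACK_TIMEOUT). The envelope built here, shown flat for
// reference:
//
//     PubSubMessage {
//         request_id,   // routes to the in-flight request
//         message_id,   // acked individually by the peer
//         reply_to,     // Some(...) only on the first frame of a request
//         message_kind, // the actual tunnel payload
//     }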
+ reply_to: if include_reply_to { + Some(self.receiver_subject.clone()) + } else { + None + }, + message_kind, + }; + let message_serialized = versioned::PubSubMessage::latest(message) + .serialize_with_embedded_version(PROTOCOL_VERSION)?; + self.ups + .publish( + &tunnel_receiver_subject, + &message_serialized, + PublishOpts::one(), + ) + .await?; + + Ok(()) + } + + pub async fn start_in_flight_request( + &self, + receiver_subject: String, + ) -> (RequestId, mpsc::Receiver) { + let id = Uuid::new_v4().into_bytes(); + let (msg_tx, msg_rx) = mpsc::channel(128); + self.requests_in_flight.lock().await.insert( + id, + InFlightRequest { + receiver_subject, + msg_tx, + opened: false, + }, + ); + (id, msg_rx) + } + + async fn receiver(&self, mut sub: Subscriber) { + while let Result::Ok(NextOutput::Message(msg)) = sub.next().await { + tracing::info!( + payload_len = msg.payload.len(), + "received message from pubsub" + ); + + match versioned::PubSubMessage::deserialize_with_embedded_version(&msg.payload) { + Result::Ok(msg) => { + tracing::debug!( + ?msg.request_id, + ?msg.message_id, + "successfully deserialized message" + ); + if let MessageKind::Ack = &msg.message_kind { + // Handle ack message + + let mut pending_messages = self.pending_messages.lock().await; + if pending_messages.remove(&msg.message_id).is_none() { + tracing::warn!( + "pending message does not exist or ack received after message body" + ); + } + } else { + // Forward message to receiver + + // Send message to sender using request_id directly + let requests_in_flight = self.requests_in_flight.lock().await; + let Some(in_flight) = requests_in_flight.get(&msg.request_id) else { + tracing::debug!( + ?msg.request_id, + "in flight has already been disconnected" + ); + continue; + }; + tracing::debug!( + ?msg.request_id, + "forwarding message to request handler" + ); + let _ = in_flight + .msg_tx + .send(TunnelMessageData::Message(msg.message_kind)) + .await; + + // Send ack + let ups_clone = self.ups.clone(); + let receiver_subject = in_flight.receiver_subject.clone(); + let ack_message = PubSubMessage { + request_id: msg.request_id, + message_id: Uuid::new_v4().into_bytes(), + reply_to: None, + message_kind: MessageKind::Ack, + }; + let ack_message_serialized = + match versioned::PubSubMessage::latest(ack_message) + .serialize_with_embedded_version(PROTOCOL_VERSION) + { + Result::Ok(x) => x, + Err(err) => { + tracing::error!(?err, "failed to serialize ack"); + continue; + } + }; + tokio::spawn(async move { + if let Result::Err(err) = ups_clone + .publish( + &receiver_subject, + &ack_message_serialized, + PublishOpts::one(), + ) + .await + { + tracing::warn!(?err, "failed to ack message") + } + }); + } + } + Result::Err(err) => { + tracing::error!(?err, "failed to parse message"); + } + } + } + } + + async fn gc(&self) { + let mut interval = tokio::time::interval(GC_INTERVAL); + loop { + interval.tick().await; + + let now = Instant::now(); + + // Purge unacked messages + { + let mut pending_messages = self.pending_messages.lock().await; + let mut removed_req_ids = Vec::new(); + pending_messages.retain(|_k, v| { + if now.duration_since(v.send_instant) > MESSAGE_ACK_TIMEOUT { + // Expired + removed_req_ids.push(v.request_id.clone()); + false + } else { + true + } + }); + + // Close in-flight messages + let requests_in_flight = self.requests_in_flight.lock().await; + for req_id in removed_req_ids { + if let Some(x) = requests_in_flight.get(&req_id) { + let _ = x.msg_tx.send(TunnelMessageData::Timeout); + } else { + tracing::warn!( + 
?req_id, + "message expired for in flight that does not exist" + ); + } + } + } + + // Purge no longer in flight + { + let mut requests_in_flight = self.requests_in_flight.lock().await; + requests_in_flight.retain(|_k, v| !v.msg_tx.is_closed()); + } + } + } +} + +impl Deref for SharedState { + type Target = SharedStateInner; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} diff --git a/packages/core/pegboard-runner-ws/src/lib.rs b/packages/core/pegboard-runner-ws/src/lib.rs index 32de54f51d..576c7d567e 100644 --- a/packages/core/pegboard-runner-ws/src/lib.rs +++ b/packages/core/pegboard-runner-ws/src/lib.rs @@ -270,6 +270,7 @@ async fn build_connection( UrlData { protocol_version, namespace, + runner_key, }: UrlData, ) -> Result<(Id, Arc)> { let namespace = ctx @@ -300,9 +301,7 @@ async fn build_connection( .map_err(|err: anyhow::Error| WsError::InvalidPacket(err.to_string()).build())?; let (runner_id, workflow_id) = if let protocol::ToServer::Init { - runner_id, name, - key, version, total_slots, addresses_http, @@ -311,7 +310,16 @@ async fn build_connection( .. } = &packet { - let runner_id = if let Some(runner_id) = runner_id { + // Look up existing runner by key + let existing_runner = ctx + .op(pegboard::ops::runner::get_by_key::Input { + namespace_id: namespace.namespace_id, + name: name.clone(), + key: runner_key.clone(), + }) + .await?; + + let runner_id = if let Some(runner) = existing_runner.runner { // IMPORTANT: Before we spawn/get the workflow, we try to update the runner's last ping ts. // This ensures if the workflow is currently checking for expiry that it will not expire // (because we are about to send signals to it) and if it is already expired (but not @@ -319,7 +327,7 @@ async fn build_connection( let update_ping_res = ctx .op(pegboard::ops::runner::update_alloc_idx::Input { runners: vec![pegboard::ops::runner::update_alloc_idx::Runner { - runner_id: *runner_id, + runner_id: runner.runner_id, action: Action::UpdatePing { rtt: 0 }, }], }) @@ -332,11 +340,14 @@ async fn build_connection( .map(|notif| matches!(notif.eligibility, RunnerEligibility::Expired)) .unwrap_or_default() { + // Runner expired, create a new one Id::new_v1(ctx.config().dc_label()) } else { - *runner_id + // Use existing runner + runner.runner_id } } else { + // No existing runner for this key, create a new one Id::new_v1(ctx.config().dc_label()) }; @@ -346,7 +357,7 @@ async fn build_connection( runner_id, namespace_id: namespace.namespace_id, name: name.clone(), - key: key.clone(), + key: runner_key.clone(), version: version.clone(), total_slots: *total_slots, @@ -790,19 +801,17 @@ async fn msg_thread_inner(ctx: &StandaloneCtx, conns: Arc>) struct UrlData { protocol_version: u16, namespace: String, + runner_key: String, } fn parse_url(addr: SocketAddr, uri: hyper::Uri) -> Result { let url = url::Url::parse(&format!("ws://{addr}{uri}"))?; - // Get protocol version from last path segment - let last_segment = url - .path_segments() - .context("invalid url")? - .last() - .context("no path segments")?; - ensure!(last_segment.starts_with('v'), "invalid protocol version"); - let protocol_version = last_segment[1..] + // Read protocol version from query parameters (required) + let protocol_version = url + .query_pairs() + .find_map(|(n, v)| (n == "protocol_version").then_some(v)) + .context("missing `protocol_version` query parameter")? 
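// A runner connection URL now carries all three required parameters in the
// query string rather than encoding the version in the path, e.g. (host,
// port, and values illustrative):
//
//     ws://127.0.0.1:6420/?protocol_version=1&namespace=default&runner_key=my-runner
//
// parse_url returns an error if protocol_version, namespace, or runner_key
// is missing, so malformed runner connections are rejected up front.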
.parse::() .context("invalid protocol version")?; @@ -813,9 +822,17 @@ fn parse_url(addr: SocketAddr, uri: hyper::Uri) -> Result { .context("missing `namespace` query parameter")? .to_string(); + // Read runner key from query parameters (required) + let runner_key = url + .query_pairs() + .find_map(|(n, v)| (n == "runner_key").then_some(v)) + .context("missing `runner_key` query parameter")? + .to_string(); + Ok(UrlData { protocol_version, namespace, + runner_key, }) } diff --git a/packages/core/pegboard-tunnel/Cargo.toml b/packages/core/pegboard-tunnel/Cargo.toml index ad8747ffd8..ef927c3559 100644 --- a/packages/core/pegboard-tunnel/Cargo.toml +++ b/packages/core/pegboard-tunnel/Cargo.toml @@ -8,6 +8,7 @@ edition.workspace = true [dependencies] anyhow.workspace = true async-trait.workspace = true +bytes.workspace = true futures.workspace = true gas.workspace = true http-body-util = "0.1" @@ -37,4 +38,4 @@ versioned-data-util.workspace = true rand.workspace = true rivet-cache.workspace = true rivet-pools.workspace = true -rivet-util.workspace = true \ No newline at end of file +rivet-util.workspace = true diff --git a/packages/core/pegboard-tunnel/src/lib.rs b/packages/core/pegboard-tunnel/src/lib.rs index 37ddc97cb8..a71d3dc62f 100644 --- a/packages/core/pegboard-tunnel/src/lib.rs +++ b/packages/core/pegboard-tunnel/src/lib.rs @@ -1,54 +1,29 @@ -use std::collections::HashMap; -use std::sync::Arc; -use std::time::Duration; - use anyhow::*; use async_trait::async_trait; +use bytes::Bytes; use futures::{SinkExt, StreamExt}; use gas::prelude::*; use http_body_util::Full; -use hyper::body::{Bytes, Incoming as BodyIncoming}; -use hyper::{Request, Response, StatusCode}; -use hyper_tungstenite::tungstenite::Utf8Bytes as WsUtf8Bytes; -use hyper_tungstenite::tungstenite::protocol::frame::CloseFrame as WsCloseFrame; -use hyper_tungstenite::tungstenite::protocol::frame::coding::CloseCode as WsCloseCode; +use hyper::{Response, StatusCode}; use hyper_tungstenite::{HyperWebsocket, tungstenite::Message as WsMessage}; -use pegboard::pubsub_subjects::{ - TunnelHttpResponseSubject, TunnelHttpRunnerSubject, TunnelHttpWebSocketSubject, +use rivet_guard_core::{ + custom_serve::CustomServeTrait, proxy_service::ResponseBody, request_context::RequestContext, +}; +use rivet_tunnel_protocol::{ + MessageKind, PROTOCOL_VERSION, PubSubMessage, RequestId, RunnerMessage, versioned, }; -use rivet_guard_core::custom_serve::CustomServeTrait; -use rivet_guard_core::proxy_service::ResponseBody; -use rivet_guard_core::request_context::RequestContext; -use rivet_pools::Pools; -use rivet_tunnel_protocol::{MessageBody, TunnelMessage, versioned}; -use rivet_util::Id; -use std::net::SocketAddr; -use tokio::net::TcpListener; +use std::{collections::HashMap, sync::Arc}; use tokio::sync::Mutex; -use tokio::sync::RwLock; -use tokio_tungstenite::accept_async; -use tracing::{error, info}; -use universalpubsub::pubsub::NextOutput; - -const UPS_REQ_TIMEOUT: Duration = Duration::from_secs(2); - -struct RunnerConnection { - _runner_id: Id, - _port_name: String, -} - -type Connections = Arc>>>; +use universalpubsub::{PublishOpts, pubsub::NextOutput}; +use versioned_data_util::OwnedVersionedData as _; pub struct PegboardTunnelCustomServe { ctx: StandaloneCtx, - connections: Connections, } impl PegboardTunnelCustomServe { - pub async fn new(ctx: StandaloneCtx) -> Result { - let connections = Arc::new(RwLock::new(HashMap::new())); - - Ok(Self { ctx, connections }) + pub fn new(ctx: StandaloneCtx) -> Self { + Self { ctx } } } @@ -82,654 +57,258 
@@ impl CustomServeTrait for PegboardTunnelCustomServe { Result::Ok(u) => u, Err(e) => return Err((client_ws, e.into())), }; - let connections = self.connections.clone(); - // Extract runner_id from query parameters - let runner_id = if let std::result::Result::Ok(url) = - url::Url::parse(&format!("ws://placeholder/{path}")) + // Parse URL to extract runner_id and protocol version + let url = match url::Url::parse(&format!("ws://placeholder/{path}")) { + Result::Ok(u) => u, + Err(e) => return Err((client_ws, e.into())), + }; + + // Extract runner_key from query parameters (required) + let runner_key = match url + .query_pairs() + .find_map(|(n, v)| (n == "runner_key").then_some(v)) + { + Some(key) => key.to_string(), + None => { + return Err((client_ws, anyhow!("runner_key query parameter is required"))); + } + }; + + // Extract protocol version from query parameters (required) + let protocol_version = match url + .query_pairs() + .find_map(|(n, v)| (n == "protocol_version").then_some(v)) + .as_ref() + .and_then(|v| v.parse::().ok()) { - url.query_pairs() - .find_map(|(n, v)| (n == "runner_id").then_some(v)) - .as_ref() - .and_then(|id| Id::parse(id).ok()) - .unwrap_or(Id::nil()) - } else { - Id::nil() + Some(version) => version, + None => { + return Err(( + client_ws, + anyhow!("protocol_version query parameter is required and must be a valid u16"), + )); + } }; let port_name = "main".to_string(); // Use "main" as default port name - info!( - ?runner_id, + tracing::info!( + ?runner_key, ?port_name, + ?protocol_version, ?path, "tunnel WebSocket connection established" ); - let connection_id = Id::nil(); - // Subscribe to pubsub topic for this runner before accepting the client websocket so // that failures can be retried by the proxy. - let topic = TunnelHttpRunnerSubject::new(runner_id, &port_name).to_string(); - info!(%topic, ?runner_id, "subscribing to pubsub topic"); - + let topic = + pegboard::pubsub_subjects::TunnelRunnerReceiverSubject::new(&runner_key, &port_name) + .to_string(); + tracing::info!(%topic, ?runner_key, "subscribing to runner receiver topic"); let mut sub = match ups.subscribe(&topic).await { Result::Ok(s) => s, Err(e) => return Err((client_ws, e.into())), }; + // Accept WS let ws_stream = match client_ws.await { Result::Ok(ws) => ws, Err(e) => { // Handshake already in progress; cannot retry safely here - error!(error=?e, "client websocket await failed"); + tracing::error!(error=?e, "client websocket await failed"); return std::result::Result::<(), (HyperWebsocket, anyhow::Error)>::Ok(()); } }; - // Split WebSocket stream into read and write halves let (ws_write, mut ws_read) = ws_stream.split(); let ws_write = Arc::new(tokio::sync::Mutex::new(ws_write)); - // Store connection - let connection = Arc::new(RunnerConnection { - _runner_id: runner_id, - _port_name: port_name.clone(), - }); + struct ActiveRequest { + /// Subject to send replies to. + reply_to: String, + } - connections - .write() - .await - .insert(connection_id, connection.clone()); + // Active HTTP & WebSocket requests. They are separate but use the same mechanism to + // maintain state. 
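// The tunnel learns where replies go from the first frame of each request
// (its reply_to subject), keeps that route in active_requests, and evicts it
// once a terminal frame passes through in either direction. Condensed:
//
//     if let Some(reply_to) = msg.reply_to {
//         active_requests.lock().await.insert(msg.request_id, ActiveRequest { reply_to });
//     }
//     if is_message_kind_request_close(&msg.message_kind) {
//         active_requests.lock().await.remove(&msg.request_id);
//     }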
+ let active_requests = Arc::new(Mutex::new(HashMap::::new())); - // Handle bidirectional message forwarding + // Forward pubsub -> WebSocket let ws_write_pubsub_to_ws = ws_write.clone(); - let connections_clone = connections.clone(); let ups_clone = ups.clone(); - - // Task for forwarding pubsub -> WebSocket + let active_requests_clone = active_requests.clone(); let pubsub_to_ws = tokio::spawn(async move { - info!("starting pubsub to WebSocket forwarding task"); - while let ::std::result::Result::Ok(NextOutput::Message(msg)) = sub.next().await { - // Ack message - match msg.reply(&[]).await { - Result::Ok(_) => {} + while let Result::Ok(NextOutput::Message(ups_msg)) = sub.next().await { + tracing::info!( + payload_len = ups_msg.payload.len(), + "received message from pubsub, forwarding to WebSocket" + ); + + // Parse message + let msg = match versioned::PubSubMessage::deserialize_with_embedded_version( + &ups_msg.payload, + ) { + Result::Ok(x) => x, Err(err) => { - tracing::warn!(?err, "failed to ack gateway request response message") + tracing::error!(?err, "failed to parse tunnel message"); + continue; } }; - info!( - payload_len = msg.payload.len(), - "received message from pubsub, forwarding to WebSocket" - ); + // Save active request + if let Some(reply_to) = msg.reply_to { + let mut active_requests = active_requests_clone.lock().await; + active_requests.insert(msg.request_id, ActiveRequest { reply_to }); + } + + // If terminal, remove active request tracking + if is_message_kind_request_close(&msg.message_kind) { + let mut active_requests = active_requests_clone.lock().await; + active_requests.remove(&msg.request_id); + } + // Forward raw message to WebSocket - let ws_msg = WsMessage::Binary(msg.payload.to_vec().into()); + let tunnel_msg = match versioned::RunnerMessage::latest(RunnerMessage { + request_id: msg.request_id, + message_id: msg.message_id, + message_kind: msg.message_kind, + }) + .serialize_version(protocol_version) + { + Result::Ok(x) => x, + Err(err) => { + tracing::error!(?err, "failed to serialize tunnel message"); + continue; + } + }; + let ws_msg = WsMessage::Binary(tunnel_msg.into()); { let mut stream = ws_write_pubsub_to_ws.lock().await; if let Err(e) = stream.send(ws_msg).await { - error!(?e, "failed to send message to WebSocket"); + tracing::error!(?e, "failed to send message to WebSocket"); break; } } } - info!("pubsub to WebSocket forwarding task ended"); + tracing::info!("pubsub to WebSocket forwarding task ended"); }); - // Task for forwarding WebSocket -> pubsub - let ws_write_ws_to_pubsub = ws_write.clone(); + // Forward WebSocket -> pubsub + let active_requests_clone = active_requests.clone(); + let runner_key_clone = runner_key.clone(); let ws_to_pubsub = tokio::spawn(async move { - info!("starting WebSocket to pubsub forwarding task"); + tracing::info!("starting WebSocket to pubsub forwarding task"); while let Some(msg) = ws_read.next().await { match msg { - ::std::result::Result::Ok(WsMessage::Binary(data)) => { - info!( + Result::Ok(WsMessage::Binary(data)) => { + tracing::info!( data_len = data.len(), "received binary message from WebSocket" ); - // Parse the tunnel message to extract request_id - match versioned::TunnelMessage::deserialize(&data) { - ::std::result::Result::Ok(tunnel_msg) => { - // Handle different message types - match &tunnel_msg.body { - MessageBody::ToClientResponseStart(resp) => { - info!(?resp.request_id, status = resp.status, "forwarding HTTP response to pubsub"); - let response_topic = TunnelHttpResponseSubject::new( - 
runner_id, - &port_name, - resp.request_id, - ) - .to_string(); - - info!(%response_topic, ?resp.request_id, "publishing HTTP response to pubsub"); - if let Err(e) = ups_clone - .request_with_timeout( - &response_topic, - &data.to_vec(), - UPS_REQ_TIMEOUT, - ) - .await - { - let err_any: anyhow::Error = e.into(); - if is_tunnel_closed_error(&err_any) { - info!( - "tunnel closed while publishing HTTP response; closing client websocket" - ); - // Close client websocket with reason - send_tunnel_closed_close_hyper( - &ws_write_ws_to_pubsub, - ) - .await; - break; - } else { - error!(?err_any, ?resp.request_id, "failed to publish HTTP response to pubsub"); - } - } else { - info!(?resp.request_id, "successfully published HTTP response to pubsub"); - } - } - MessageBody::ToClientWebSocketMessage(ws_msg) => { - info!(?ws_msg.web_socket_id, "forwarding WebSocket message to pubsub"); - // Forward WebSocket messages to the topic that pegboard-gateway subscribes to - let ws_topic = TunnelHttpWebSocketSubject::new( - runner_id, - &port_name, - ws_msg.web_socket_id, - ) - .to_string(); - - info!(%ws_topic, ?ws_msg.web_socket_id, "publishing WebSocket message to pubsub"); - - if let Err(e) = ups_clone - .request_with_timeout( - &ws_topic, - &data.to_vec(), - UPS_REQ_TIMEOUT, - ) - .await - { - let err_any: anyhow::Error = e.into(); - if is_tunnel_closed_error(&err_any) { - info!( - "tunnel closed while publishing WebSocket message; closing client websocket" - ); - // Close client websocket with reason - send_tunnel_closed_close_hyper( - &ws_write_ws_to_pubsub, - ) - .await; - break; - } else { - error!(?err_any, ?ws_msg.web_socket_id, "failed to publish WebSocket message to pubsub"); - } - } else { - info!(?ws_msg.web_socket_id, "successfully published WebSocket message to pubsub"); - } - } - MessageBody::ToClientWebSocketOpen(ws_open) => { - info!(?ws_open.web_socket_id, "forwarding WebSocket open to pubsub"); - let ws_topic = TunnelHttpWebSocketSubject::new( - runner_id, - &port_name, - ws_open.web_socket_id, - ) - .to_string(); + // Parse message + let msg = match versioned::RunnerMessage::deserialize_version( + &data, + protocol_version, + ) + .and_then(|x| x.into_latest()) + { + Result::Ok(x) => x, + Err(err) => { + tracing::error!(?err, "failed to deserialize message"); + continue; + } + }; + + // Determine reply to subject + let request_id = msg.request_id; + let reply_to = { + let active_requests = active_requests_clone.lock().await; + if let Some(req) = active_requests.get(&request_id) { + req.reply_to.clone() + } else { + tracing::warn!( + "no active request for tunnel message, may have timed out" + ); + continue; + } + }; - if let Err(e) = ups_clone - .request_with_timeout( - &ws_topic, - &data.to_vec(), - UPS_REQ_TIMEOUT, - ) - .await - { - let err_any: anyhow::Error = e.into(); - if is_tunnel_closed_error(&err_any) { - info!( - "tunnel closed while publishing WebSocket open; closing client websocket" - ); - // Close client websocket with reason - send_tunnel_closed_close_hyper( - &ws_write_ws_to_pubsub, - ) - .await; - break; - } else { - error!(?err_any, ?ws_open.web_socket_id, "failed to publish WebSocket open to pubsub"); - } - } else { - info!(?ws_open.web_socket_id, "successfully published WebSocket open to pubsub"); - } - } - MessageBody::ToClientWebSocketClose(ws_close) => { - info!(?ws_close.web_socket_id, "forwarding WebSocket close to pubsub"); - let ws_topic = TunnelHttpWebSocketSubject::new( - runner_id, - &port_name, - ws_close.web_socket_id, - ) - .to_string(); + // Remove 
active request entries when terminal + if is_message_kind_request_close(&msg.message_kind) { + let mut active_requests = active_requests_clone.lock().await; + active_requests.remove(&request_id); + } - if let Err(e) = ups_clone - .request_with_timeout( - &ws_topic, - &data.to_vec(), - UPS_REQ_TIMEOUT, - ) - .await - { - let err_any: anyhow::Error = e.into(); - if is_tunnel_closed_error(&err_any) { - info!( - "tunnel closed while publishing WebSocket close; closing client websocket" - ); - // Close client websocket with reason - send_tunnel_closed_close_hyper( - &ws_write_ws_to_pubsub, - ) - .await; - break; - } else { - error!(?err_any, ?ws_close.web_socket_id, "failed to publish WebSocket close to pubsub"); - } - } else { - info!(?ws_close.web_socket_id, "successfully published WebSocket close to pubsub"); - } - } - _ => { - // For other message types, we might not need to forward to pubsub - info!( - "Received non-response message from WebSocket, skipping pubsub forward" - ); - continue; - } + // Publish message to UPS + let message_serialized = + match versioned::PubSubMessage::latest(PubSubMessage { + request_id: msg.request_id, + message_id: msg.message_id, + reply_to: None, + message_kind: msg.message_kind, + }) + .serialize_with_embedded_version(PROTOCOL_VERSION) + { + Result::Ok(x) => x, + Err(err) => { + tracing::error!(?err, "failed to serialize tunnel to gateway"); + continue; } - } - ::std::result::Result::Err(e) => { - error!(?e, "failed to deserialize tunnel message from WebSocket"); + }; + match ups_clone + .publish(&reply_to, &message_serialized, PublishOpts::one()) + .await + { + Result::Ok(_) => {} + Err(err) => { + tracing::error!(?err, "error publishing ups message"); } } } - ::std::result::Result::Ok(WsMessage::Close(_)) => { - info!(?runner_id, "WebSocket closed"); + Result::Ok(WsMessage::Close(_)) => { + tracing::info!(?runner_key_clone, "WebSocket closed"); break; } - ::std::result::Result::Ok(_) => { + Result::Ok(_) => { // Ignore other message types } Err(e) => { - error!(?e, "WebSocket error"); + tracing::error!(?e, "WebSocket error"); break; } } } - info!("WebSocket to pubsub forwarding task ended"); - - // Clean up connection - connections_clone.write().await.remove(&connection_id); + tracing::info!("WebSocket to pubsub forwarding task ended"); }); // Wait for either task to complete tokio::select! 
{ _ = pubsub_to_ws => { - info!("pubsub to WebSocket task completed"); + tracing::info!("pubsub to WebSocket task completed"); } _ = ws_to_pubsub => { - info!("WebSocket to pubsub task completed"); + tracing::info!("WebSocket to pubsub task completed"); } } // Clean up - connections.write().await.remove(&connection_id); - info!(?runner_id, "connection closed"); + tracing::info!(?runner_key, "connection closed"); std::result::Result::<(), (HyperWebsocket, anyhow::Error)>::Ok(()) } } -// Keep the old start function for backward compatibility in tests -pub async fn start(config: rivet_config::Config, pools: Pools) -> Result<()> { - let cache = rivet_cache::CacheInner::from_env(&config, pools.clone())?; - let ctx = StandaloneCtx::new( - gas::db::DatabaseKv::from_pools(pools.clone()).await?, - config.clone(), - pools.clone(), - cache, - "pegboard-tunnel", - Id::new_v1(config.dc_label()), - Id::new_v1(config.dc_label()), - )?; - - main_loop(ctx).await -} - -async fn main_loop(ctx: gas::prelude::StandaloneCtx) -> Result<()> { - let connections: Connections = Arc::new(RwLock::new(HashMap::new())); - - // Start WebSocket server - // Use pegboard config since pegboard_tunnel doesn't exist - let server_addr = SocketAddr::new( - ctx.config().pegboard().host(), - ctx.config().pegboard().port(), - ); - - info!(?server_addr, "starting pegboard-tunnel"); - - let listener = TcpListener::bind(&server_addr).await?; - - // Accept connections - loop { - let (tcp_stream, addr) = listener.accept().await?; - let connections = connections.clone(); - let ctx = ctx.clone(); - - tokio::spawn(async move { - if let Err(e) = handle_connection(ctx, tcp_stream, addr, connections).await { - error!(?e, ?addr, "connection handler error"); - } - }); +fn is_message_kind_request_close(kind: &MessageKind) -> bool { + match kind { + // HTTP terminal states + MessageKind::ToClientResponseStart(resp) => !resp.stream, + MessageKind::ToClientResponseChunk(chunk) => chunk.finish, + MessageKind::ToClientResponseAbort => true, + // WebSocket terminal states (either side closes) + MessageKind::ToClientWebSocketClose(_) => true, + MessageKind::ToServerWebSocketClose(_) => true, + _ => false, } } - -async fn handle_connection( - ctx: gas::prelude::StandaloneCtx, - tcp_stream: tokio::net::TcpStream, - addr: std::net::SocketAddr, - connections: Connections, -) -> Result<()> { - info!(?addr, "new connection"); - - // Parse WebSocket upgrade request - let ws_stream = accept_async(tcp_stream).await?; - - // For now, we'll expect the runner to send an initial message with its ID - // In production, this would be parsed from the URL path or headers - let runner_id = rivet_util::Id::nil(); // Placeholder - should be extracted from connection - let port_name = "default".to_string(); // Placeholder - should be extracted - - let connection_id = rivet_util::Id::nil(); - - // Subscribe to pubsub topic for this runner using raw pubsub client - let topic = TunnelHttpRunnerSubject::new(runner_id, &port_name).to_string(); - info!(%topic, ?runner_id, "subscribing to pubsub topic"); - - // Get UPS (UniversalPubSub) client - let ups = ctx.pools().ups()?; - let mut sub = ups.subscribe(&topic).await?; - - // Split WebSocket stream into read and write halves - let (ws_write, mut ws_read) = ws_stream.split(); - let ws_write = Arc::new(Mutex::new(ws_write)); - - // Store connection - let connection = Arc::new(RunnerConnection { - _runner_id: runner_id, - _port_name: port_name.clone(), - }); - - connections - .write() - .await - .insert(connection_id, 
connection.clone()); - - // Handle bidirectional message forwarding - let ws_write_clone = ws_write.clone(); - let connections_clone = connections.clone(); - let ups_clone = ups.clone(); - - // Task for forwarding pubsub -> WebSocket - let pubsub_to_ws = tokio::spawn(async move { - while let ::std::result::Result::Ok(NextOutput::Message(msg)) = sub.next().await { - // Ack message - match msg.reply(&[]).await { - Result::Ok(_) => {} - Err(err) => { - tracing::warn!(?err, "failed to ack gateway request response message") - } - }; - - // Forward raw message to WebSocket - let ws_msg = - tokio_tungstenite::tungstenite::Message::Binary(msg.payload.to_vec().into()); - { - let mut stream = ws_write_clone.lock().await; - if let Err(e) = stream.send(ws_msg).await { - error!(?e, "failed to send message to WebSocket"); - break; - } - } - } - }); - - // Task for forwarding WebSocket -> pubsub - let ws_write_ws_to_pubsub = ws_write.clone(); - let ws_to_pubsub = tokio::spawn(async move { - while let Some(msg) = ws_read.next().await { - match msg { - ::std::result::Result::Ok(tokio_tungstenite::tungstenite::Message::Binary( - data, - )) => { - // Parse the tunnel message to extract request_id - match versioned::TunnelMessage::deserialize(&data) { - ::std::result::Result::Ok(tunnel_msg) => { - // Handle different message types - match &tunnel_msg.body { - MessageBody::ToClientResponseStart(resp) => { - let response_topic = TunnelHttpResponseSubject::new( - runner_id, - &port_name, - resp.request_id, - ) - .to_string(); - - if let Err(e) = ups_clone - .request_with_timeout( - &response_topic, - &data.to_vec(), - UPS_REQ_TIMEOUT, - ) - .await - { - let err_any: anyhow::Error = e.into(); - if is_tunnel_closed_error(&err_any) { - info!( - "tunnel closed while publishing HTTP response; closing client websocket" - ); - // Close client websocket with reason - send_tunnel_closed_close_tokio(&ws_write_ws_to_pubsub) - .await; - break; - } else { - error!(?err_any, ?resp.request_id, "failed to publish HTTP response to pubsub"); - } - } - } - MessageBody::ToClientWebSocketMessage(ws_msg) => { - let ws_topic = TunnelHttpWebSocketSubject::new( - runner_id, - &port_name, - ws_msg.web_socket_id, - ) - .to_string(); - - if let Err(e) = ups_clone - .request_with_timeout( - &ws_topic, - &data.to_vec(), - UPS_REQ_TIMEOUT, - ) - .await - { - let err_any: anyhow::Error = e.into(); - if is_tunnel_closed_error(&err_any) { - info!( - "tunnel closed while publishing WebSocket message; closing client websocket" - ); - // Close client websocket with reason - send_tunnel_closed_close_tokio(&ws_write_ws_to_pubsub) - .await; - break; - } else { - error!(?err_any, ?ws_msg.web_socket_id, "failed to publish WebSocket message to pubsub"); - } - } - } - MessageBody::ToClientWebSocketOpen(ws_open) => { - let ws_topic = TunnelHttpWebSocketSubject::new( - runner_id, - &port_name, - ws_open.web_socket_id, - ) - .to_string(); - - if let Err(e) = ups_clone - .request_with_timeout( - &ws_topic, - &data.to_vec(), - UPS_REQ_TIMEOUT, - ) - .await - { - let err_any: anyhow::Error = e.into(); - if is_tunnel_closed_error(&err_any) { - info!( - "tunnel closed while publishing WebSocket open; closing client websocket" - ); - // Close client websocket with reason - send_tunnel_closed_close_tokio(&ws_write_ws_to_pubsub) - .await; - break; - } else { - error!(?err_any, ?ws_open.web_socket_id, "failed to publish WebSocket open to pubsub"); - } - } - } - MessageBody::ToClientWebSocketClose(ws_close) => { - let ws_topic = TunnelHttpWebSocketSubject::new( - 
runner_id, - &port_name, - ws_close.web_socket_id, - ) - .to_string(); - - if let Err(e) = ups_clone - .request_with_timeout( - &ws_topic, - &data.to_vec(), - UPS_REQ_TIMEOUT, - ) - .await - { - let err_any: anyhow::Error = e.into(); - if is_tunnel_closed_error(&err_any) { - info!( - "tunnel closed while publishing WebSocket close; closing client websocket" - ); - // Close client websocket with reason - send_tunnel_closed_close_tokio(&ws_write_ws_to_pubsub) - .await; - break; - } else { - error!(?err_any, ?ws_close.web_socket_id, "failed to publish WebSocket close to pubsub"); - } - } - } - _ => { - // For other message types, we might not need to forward to pubsub - info!( - "Received non-response message from WebSocket, skipping pubsub forward" - ); - continue; - } - } - } - ::std::result::Result::Err(e) => { - error!(?e, "failed to deserialize tunnel message from WebSocket"); - } - } - } - ::std::result::Result::Ok(tokio_tungstenite::tungstenite::Message::Close(_)) => { - info!(?runner_id, "WebSocket closed"); - break; - } - ::std::result::Result::Ok(_) => { - // Ignore other message types - } - Err(e) => { - error!(?e, "WebSocket error"); - break; - } - } - } - - // Clean up connection - connections_clone.write().await.remove(&connection_id); - }); - - // Wait for either task to complete - tokio::select! { - _ = pubsub_to_ws => { - info!("pubsub to WebSocket task completed"); - } - _ = ws_to_pubsub => { - info!("WebSocket to pubsub task completed"); - } - } - - // Clean up - connections.write().await.remove(&connection_id); - info!(?runner_id, "connection closed"); - - Ok(()) -} - -/// Determines if the tunnel is closed by if the UPS service is no longer responding. -fn is_tunnel_closed_error(err: &anyhow::Error) -> bool { - if let Some(err) = err - .chain() - .find_map(|x| x.downcast_ref::()) - && err.group() == "ups" - && err.code() == "request_timeout" - { - true - } else { - false - } -} - -// Helper: Build and send a standard tunnel-closed Close frame (hyper-tungstenite) -fn tunnel_closed_close_msg_hyper() -> WsMessage { - WsMessage::Close(Some(WsCloseFrame { - code: WsCloseCode::Error, - reason: WsUtf8Bytes::from_static("Tunnel closed"), - })) -} - -// Helper: Build and send a standard tunnel-closed Close frame (tokio-tungstenite) -fn tunnel_closed_close_msg_tokio() -> tokio_tungstenite::tungstenite::Message { - tokio_tungstenite::tungstenite::Message::Close(Some( - tokio_tungstenite::tungstenite::protocol::frame::CloseFrame { - code: tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode::Error, - reason: tokio_tungstenite::tungstenite::Utf8Bytes::from_static("Tunnel closed"), - }, - )) -} - -// Helper: Send the tunnel-closed Close frame on a hyper-tungstenite sink -async fn send_tunnel_closed_close_hyper(ws_write: &tokio::sync::Mutex) -where - S: futures::Sink + Unpin, -{ - let mut stream = ws_write.lock().await; - let _ = stream.send(tunnel_closed_close_msg_hyper()).await; -} - -// Helper: Send the tunnel-closed Close frame on a tokio-tungstenite sink -async fn send_tunnel_closed_close_tokio(ws_write: &tokio::sync::Mutex) -where - S: futures::Sink + Unpin, -{ - let mut stream = ws_write.lock().await; - let _ = stream.send(tunnel_closed_close_msg_tokio()).await; -} diff --git a/packages/core/pegboard-tunnel/tests/integration.rs b/packages/core/pegboard-tunnel/tests/integration.rs index 2c20bfe638..70051af60a 100644 --- a/packages/core/pegboard-tunnel/tests/integration.rs +++ b/packages/core/pegboard-tunnel/tests/integration.rs @@ -91,7 +91,9 @@ async fn 
test_pubsub_to_websocket( }; // Serialize the message - let serialized = versioned::TunnelMessage::serialize(versioned::TunnelMessage::V1(message))?; + let serialized = versioned::RunnerMessage::serialize_with_embedded_version( + versioned::RunnerMessage::V1(message), + )?; // Publish to pubsub topic using proper subject let topic = TunnelHttpRunnerSubject::new(&runner_id.to_string(), port_name).to_string(); @@ -105,7 +107,7 @@ async fn test_pubsub_to_websocket( match received? { WsMessage::Binary(data) => { // Deserialize and verify the message - let tunnel_msg = versioned::TunnelMessage::deserialize(&data)?; + let tunnel_msg = versioned::RunnerMessage::deserialize_with_embedded_version(&data)?; match tunnel_msg.body { MessageBody::ToServerRequestStart(req) => { assert_eq!(req.request_id, request_id); @@ -150,7 +152,9 @@ async fn test_websocket_to_pubsub( }; // Serialize and send via WebSocket - let serialized = versioned::TunnelMessage::serialize(versioned::TunnelMessage::V1(message))?; + let serialized = versioned::RunnerMessage::serialize_with_embedded_version( + versioned::RunnerMessage::V1(message), + )?; ws_stream.send(WsMessage::Binary(serialized.into())).await?; // Wait for message on pubsub @@ -159,7 +163,8 @@ async fn test_websocket_to_pubsub( match received { universalpubsub::pubsub::NextOutput::Message(msg) => { // Deserialize and verify the message - let tunnel_msg = versioned::TunnelMessage::deserialize(&msg.payload)?; + let tunnel_msg = + versioned::RunnerMessage::deserialize_with_embedded_version(&msg.payload)?; match tunnel_msg.body { MessageBody::ToClientResponseStart(resp) => { assert_eq!(resp.request_id, request_id); diff --git a/packages/infra/engine/tests/actors_lifecycle.rs b/packages/infra/engine/tests/actors_lifecycle.rs index 85b991eb4a..46aeff7e9d 100644 --- a/packages/infra/engine/tests/actors_lifecycle.rs +++ b/packages/infra/engine/tests/actors_lifecycle.rs @@ -4,7 +4,7 @@ use std::time::Duration; #[test] fn actor_lifecycle_single_dc() { - common::run(common::TestOpts::new(2), |ctx| async move { + common::run(common::TestOpts::new(1), |ctx| async move { actor_lifecycle_inner(&ctx, false).await; }); } @@ -29,10 +29,6 @@ async fn actor_lifecycle_inner(ctx: &common::TestCtx, multi_dc: bool) { let actor_id = common::create_actor(&namespace, target_dc.guard_port()).await; - // TODO: This is a race condition. we might need to move this after the guard ping since guard - // correctly waits for the actor to start. 
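The `pegboard_runner_get_by_key` operation introduced just below resolves a runner from its (namespace_id, name, key) tuple. A minimal sketch of a call site, assuming the usual gasoline `ctx.op(...).await` wiring; the helper, its name, and the `runner_id` field access are illustrative rather than taken from this patch:

use anyhow::Result;
use gas::prelude::*;

// Hypothetical helper: resolve a runner by key and hand back its id, if any.
async fn resolve_runner(ctx: &StandaloneCtx, namespace_id: Id, key: String) -> Result<Option<Id>> {
	let res = ctx
		.op(pegboard::ops::runner::get_by_key::Input {
			namespace_id,
			name: "default".to_string(), // runner name; illustrative
			key,
		})
		.await?;
	Ok(res.runner.map(|runner| runner.runner_id))
}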
- tokio::time::sleep(Duration::from_millis(500)).await; - // Test ping via guard let ping_response = common::ping_actor_via_guard(ctx.leader_dc().guard_port(), &actor_id, "main").await; @@ -52,11 +48,10 @@ async fn actor_lifecycle_inner(ctx: &common::TestCtx, multi_dc: bool) { // Destroy tracing::info!("destroying actor"); - tokio::time::sleep(Duration::from_millis(500)).await; common::destroy_actor(&actor_id, &namespace, target_dc.guard_port()).await; - tokio::time::sleep(Duration::from_millis(500)).await; // Validate runner state + tokio::time::sleep(Duration::from_millis(500)).await; assert!( !runner.has_actor(&actor_id).await, "Runner should not have the actor after destroy" diff --git a/packages/services/pegboard/src/ops/runner/get_by_key.rs b/packages/services/pegboard/src/ops/runner/get_by_key.rs new file mode 100644 index 0000000000..5a6ea63bf1 --- /dev/null +++ b/packages/services/pegboard/src/ops/runner/get_by_key.rs @@ -0,0 +1,51 @@ +use anyhow::*; +use gas::prelude::*; +use rivet_types::runners::Runner; +use udb_util::{SERIALIZABLE, TxnExt}; +use universaldb as udb; + +use crate::keys; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Input { + pub namespace_id: Id, + pub name: String, + pub key: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Output { + pub runner: Option, +} + +#[operation] +pub async fn pegboard_runner_get_by_key(ctx: &OperationCtx, input: &Input) -> Result { + let dc_name = ctx.config().dc_name()?.to_string(); + + let runner = ctx + .udb()? + .run(|tx, _mc| { + let dc_name = dc_name.to_string(); + let input = input.clone(); + async move { + let txs = tx.subspace(keys::subspace()); + + // Look up runner by key + let runner_by_key_key = + keys::ns::RunnerByKeyKey::new(input.namespace_id, input.name, input.key); + + let runner_data = txs.read_opt(&runner_by_key_key, SERIALIZABLE).await?; + + if let Some(data) = runner_data { + // Get full runner details using the runner_id + let runner = super::get::get_inner(&dc_name, &tx, data.runner_id).await?; + std::result::Result::<_, udb::FdbBindingError>::Ok(runner) + } else { + std::result::Result::<_, udb::FdbBindingError>::Ok(None) + } + } + }) + .await?; + + Ok(Output { runner }) +} diff --git a/packages/services/pegboard/src/ops/runner/mod.rs b/packages/services/pegboard/src/ops/runner/mod.rs index cd56c69267..1885c60f23 100644 --- a/packages/services/pegboard/src/ops/runner/mod.rs +++ b/packages/services/pegboard/src/ops/runner/mod.rs @@ -1,4 +1,5 @@ pub mod get; +pub mod get_by_key; pub mod list_for_ns; pub mod list_names; pub mod update_alloc_idx; diff --git a/packages/services/pegboard/src/pubsub_subjects.rs b/packages/services/pegboard/src/pubsub_subjects.rs index 13d39b1704..ae9e9438b6 100644 --- a/packages/services/pegboard/src/pubsub_subjects.rs +++ b/packages/services/pegboard/src/pubsub_subjects.rs @@ -1,77 +1,41 @@ -use rivet_util::Id; +use gas::prelude::*; -pub struct TunnelHttpRunnerSubject<'a> { - runner_id: Id, +pub struct TunnelRunnerReceiverSubject<'a> { + runner_key: &'a str, port_name: &'a str, } -impl<'a> TunnelHttpRunnerSubject<'a> { - pub fn new(runner_id: Id, port_name: &'a str) -> Self { +impl<'a> TunnelRunnerReceiverSubject<'a> { + pub fn new(runner_key: &'a str, port_name: &'a str) -> Self { Self { - runner_id, + runner_key, port_name, } } } -impl std::fmt::Display for TunnelHttpRunnerSubject<'_> { +impl std::fmt::Display for TunnelRunnerReceiverSubject<'_> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( f, - 
"pegboard.tunnel.http.runner.{}.{}", - self.runner_id, self.port_name + "pegboard.tunnel.runner_receiver.{}.{}", + self.runner_key, self.port_name ) } } -pub struct TunnelHttpResponseSubject<'a> { - runner_id: Id, - port_name: &'a str, - request_id: u64, +pub struct TunnelGatewayReceiverSubject { + gateway_id: Uuid, } -impl<'a> TunnelHttpResponseSubject<'a> { - pub fn new(runner_id: Id, port_name: &'a str, request_id: u64) -> Self { - Self { - runner_id, - port_name, - request_id, - } +impl<'a> TunnelGatewayReceiverSubject { + pub fn new(gateway_id: Uuid) -> Self { + Self { gateway_id } } } -impl std::fmt::Display for TunnelHttpResponseSubject<'_> { +impl std::fmt::Display for TunnelGatewayReceiverSubject { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "pegboard.tunnel.http.request.{}.{}.{}", - self.runner_id, self.port_name, self.request_id - ) - } -} - -pub struct TunnelHttpWebSocketSubject<'a> { - runner_id: Id, - port_name: &'a str, - websocket_id: u64, -} - -impl<'a> TunnelHttpWebSocketSubject<'a> { - pub fn new(runner_id: Id, port_name: &'a str, websocket_id: u64) -> Self { - Self { - runner_id, - port_name, - websocket_id, - } - } -} - -impl std::fmt::Display for TunnelHttpWebSocketSubject<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "pegboard.tunnel.http.websocket.{}.{}.{}", - self.runner_id, self.port_name, self.websocket_id - ) + write!(f, "pegboard.gateway.receiver.{}", self.gateway_id) } } diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 62a80f8ad4..91185e953a 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -685,6 +685,9 @@ importers: '@rivetkit/engine-tunnel-protocol': specifier: workspace:* version: link:../tunnel-protocol + uuid: + specifier: ^12.0.0 + version: 12.0.0 ws: specifier: ^8.18.3 version: 8.18.3 @@ -7241,6 +7244,10 @@ packages: util-deprecate@1.0.2: resolution: {integrity: sha512-EPD5q1uXyFxJpCrLnCc1nHnq3gOa6DZBocAIiI2TaSCA7VCJ1UJDMagCzIkXNsUYfD1daK//LTEQ8xiIbrHtcw==} + uuid@12.0.0: + resolution: {integrity: sha512-USe1zesMYh4fjCA8ZH5+X5WIVD0J4V1Jksm1bFTVBX2F/cwSXt0RO5w/3UXbdLKmZX65MiWV+hwhSS8p6oBTGA==} + hasBin: true + validator@13.15.15: resolution: {integrity: sha512-BgWVbCI72aIQy937xbawcs+hrVaN/CZ2UwutgaJ36hGqRrLNM+f5LUT/YPRbo8IV/ASeFzXszezV+y2+rq3l8A==} engines: {node: '>= 0.10'} @@ -14748,6 +14755,8 @@ snapshots: util-deprecate@1.0.2: {} + uuid@12.0.0: {} + validator@13.15.15: {} vary@1.1.2: {} diff --git a/sdks/rust/runner-protocol/src/protocol.rs b/sdks/rust/runner-protocol/src/protocol.rs index 4ca049f634..a037941e2a 100644 --- a/sdks/rust/runner-protocol/src/protocol.rs +++ b/sdks/rust/runner-protocol/src/protocol.rs @@ -20,9 +20,7 @@ pub enum ToClient { #[serde(rename_all = "snake_case")] pub enum ToServer { Init { - runner_id: Option, name: String, - key: String, version: u32, total_slots: u32, diff --git a/sdks/rust/runner-protocol/src/versioned.rs b/sdks/rust/runner-protocol/src/versioned.rs index 2b61c2c2dc..34c54168b6 100644 --- a/sdks/rust/runner-protocol/src/versioned.rs +++ b/sdks/rust/runner-protocol/src/versioned.rs @@ -285,9 +285,7 @@ impl TryFrom for protocol::ToServer { fn try_from(value: v1::ToServer) -> Result { match value { v1::ToServer::ToServerInit(init) => Ok(protocol::ToServer::Init { - runner_id: init.runner_id.map(|id| util::Id::parse(&id)).transpose()?, name: init.name, - key: init.key, version: init.version, total_slots: init.total_slots, addresses_http: init diff --git a/sdks/rust/tunnel-protocol/build.rs 
diff --git a/sdks/rust/tunnel-protocol/build.rs b/sdks/rust/tunnel-protocol/build.rs
index 453be2c763..f43ed92983 100644
--- a/sdks/rust/tunnel-protocol/build.rs
+++ b/sdks/rust/tunnel-protocol/build.rs
@@ -1,4 +1,8 @@
-use std::{env, fs, path::Path};
+use std::{
+	env, fs,
+	path::{Path, PathBuf},
+	process::Command,
+};
 
 use indoc::formatdoc;
 
@@ -52,6 +56,78 @@ mod rust {
 	}
 }
 
+mod typescript {
+	use super::*;
+
+	pub fn generate_sdk(schema_dir: &Path) {
+		let manifest_dir = env::var("CARGO_MANIFEST_DIR").unwrap();
+		let workspace_root = Path::new(&manifest_dir)
+			.parent()
+			.and_then(|p| p.parent())
+			.and_then(|p| p.parent())
+			.expect("Failed to find workspace root");
+
+		let sdk_dir = workspace_root
+			.join("sdks")
+			.join("typescript")
+			.join("tunnel-protocol");
+		let src_dir = sdk_dir.join("src");
+
+		let highest_version_path = super::find_highest_version(schema_dir);
+
+		let _ = fs::remove_dir_all(&src_dir);
+		if let Err(e) = fs::create_dir_all(&src_dir) {
+			panic!("Failed to create SDK directory: {}", e);
+		}
+
+		let output =
+			Command::new(workspace_root.join("node_modules/@bare-ts/tools/dist/bin/cli.js"))
+				.arg("compile")
+				.arg("--generator")
+				.arg("ts")
+				.arg(highest_version_path)
+				.arg("-o")
+				.arg(src_dir.join("index.ts"))
+				.output()
+				.expect("Failed to execute bare compiler for TypeScript");
+
+		if !output.status.success() {
+			panic!(
+				"BARE TypeScript generation failed: {}",
+				String::from_utf8_lossy(&output.stderr),
+			);
+		}
+	}
+}
+
+fn find_highest_version(schema_dir: &Path) -> PathBuf {
+	let mut highest_version = 0;
+	let mut highest_version_path = PathBuf::new();
+
+	for entry in fs::read_dir(schema_dir).unwrap().flatten() {
+		if !entry.path().is_dir() {
+			let path = entry.path();
+			let bare_name = path
+				.file_name()
+				.unwrap()
+				.to_str()
+				.unwrap()
+				.split_once('.')
+				.unwrap()
+				.0;
+
+			if let Ok(version) = bare_name[1..].parse::<u32>() {
+				if version > highest_version {
+					highest_version = version;
+					highest_version_path = path;
+				}
+			}
+		}
+	}
+
+	highest_version_path
+}
+
 fn main() {
 	let manifest_dir = env::var("CARGO_MANIFEST_DIR").unwrap();
 	let workspace_root = Path::new(&manifest_dir)
@@ -68,4 +144,15 @@ fn main() {
 	println!("cargo:rerun-if-changed={}", schema_dir.display());
 
 	rust::generate_sdk(&schema_dir);
+
+	// Check if cli.js exists before attempting TypeScript generation
+	let cli_js_path = workspace_root.join("node_modules/@bare-ts/tools/dist/bin/cli.js");
+	if cli_js_path.exists() {
+		typescript::generate_sdk(&schema_dir);
+	} else {
+		println!(
+			"cargo:warning=TypeScript SDK generation skipped: cli.js not found at {}. Run `pnpm install` to install.",
+			cli_js_path.display()
+		);
+	}
 }
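A note on `find_highest_version` above: it picks the schema file with the largest `vN` prefix, so both generators always compile the newest `.bare` revision. A simplified, self-contained restatement of that selection (it uses `strip_prefix` where build.rs slices with `[1..]`; the file names are hypothetical):

use std::path::PathBuf;

// Pick the "vN.*" name with the largest N; non-matching names are skipped
// because the parse fails instead of panicking.
fn highest(names: &[&str]) -> Option<PathBuf> {
	names
		.iter()
		.filter_map(|name| {
			let stem = name.split_once('.')?.0; // "v2.bare" -> "v2"
			let n = stem.strip_prefix('v')?.parse::<u32>().ok()?;
			Some((n, *name))
		})
		.max_by_key(|(n, _)| *n)
		.map(|(_, name)| PathBuf::from(name))
}

fn main() {
	assert_eq!(
		highest(&["v1.bare", "v2.bare", "README.md"]),
		Some(PathBuf::from("v2.bare"))
	);
}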
diff --git a/sdks/rust/tunnel-protocol/src/versioned.rs b/sdks/rust/tunnel-protocol/src/versioned.rs
index d37e5d5447..f5208f27fb 100644
--- a/sdks/rust/tunnel-protocol/src/versioned.rs
+++ b/sdks/rust/tunnel-protocol/src/versioned.rs
@@ -3,20 +3,20 @@ use versioned_data_util::OwnedVersionedData;
 
 use crate::{PROTOCOL_VERSION, generated::v1};
 
-pub enum TunnelMessage {
-	V1(v1::TunnelMessage),
+pub enum RunnerMessage {
+	V1(v1::RunnerMessage),
 }
 
-impl OwnedVersionedData for TunnelMessage {
-	type Latest = v1::TunnelMessage;
+impl OwnedVersionedData for RunnerMessage {
+	type Latest = v1::RunnerMessage;
 
-	fn latest(latest: v1::TunnelMessage) -> Self {
-		TunnelMessage::V1(latest)
+	fn latest(latest: v1::RunnerMessage) -> Self {
+		RunnerMessage::V1(latest)
 	}
 
 	fn into_latest(self) -> Result<Self::Latest> {
 		#[allow(irrefutable_let_patterns)]
-		if let TunnelMessage::V1(data) = self {
+		if let RunnerMessage::V1(data) = self {
 			Ok(data)
 		} else {
 			bail!("version not latest");
@@ -25,20 +25,64 @@ impl OwnedVersionedData for TunnelMessage {
 
 	fn deserialize_version(payload: &[u8], version: u16) -> Result<Self> {
 		match version {
-			1 => Ok(TunnelMessage::V1(serde_bare::from_slice(payload)?)),
+			1 => Ok(RunnerMessage::V1(serde_bare::from_slice(payload)?)),
 			_ => bail!("invalid version: {version}"),
 		}
 	}
 
 	fn serialize_version(self, _version: u16) -> Result<Vec<u8>> {
 		match self {
-			TunnelMessage::V1(data) => serde_bare::to_vec(&data).map_err(Into::into),
+			RunnerMessage::V1(data) => serde_bare::to_vec(&data).map_err(Into::into),
 		}
 	}
 }
 
-impl TunnelMessage {
-	pub fn deserialize(buf: &[u8]) -> Result<v1::TunnelMessage> {
+impl RunnerMessage {
+	pub fn deserialize(buf: &[u8]) -> Result<v1::RunnerMessage> {
+		<Self as OwnedVersionedData>::deserialize(buf, PROTOCOL_VERSION)
+	}
+
+	pub fn serialize(self) -> Result<Vec<u8>> {
+		<Self as OwnedVersionedData>::serialize(self, PROTOCOL_VERSION)
+	}
+}
+
+pub enum PubSubMessage {
+	V1(v1::PubSubMessage),
+}
+
+impl OwnedVersionedData for PubSubMessage {
+	type Latest = v1::PubSubMessage;
+
+	fn latest(latest: v1::PubSubMessage) -> Self {
+		PubSubMessage::V1(latest)
+	}
+
+	fn into_latest(self) -> Result<Self::Latest> {
+		#[allow(irrefutable_let_patterns)]
+		if let PubSubMessage::V1(data) = self {
+			Ok(data)
+		} else {
+			bail!("version not latest");
+		}
+	}
+
+	fn deserialize_version(payload: &[u8], version: u16) -> Result<Self> {
+		match version {
+			1 => Ok(PubSubMessage::V1(serde_bare::from_slice(payload)?)),
+			_ => bail!("invalid version: {version}"),
+		}
+	}
+
+	fn serialize_version(self, _version: u16) -> Result<Vec<u8>> {
+		match self {
+			PubSubMessage::V1(data) => serde_bare::to_vec(&data).map_err(Into::into),
+		}
+	}
+}
+
+impl PubSubMessage {
+	pub fn deserialize(buf: &[u8]) -> Result<v1::PubSubMessage> {
 		<Self as OwnedVersionedData>::deserialize(buf, PROTOCOL_VERSION)
 	}
diff --git a/sdks/schemas/runner-protocol/v1.bare b/sdks/schemas/runner-protocol/v1.bare
index ed99e59afb..5db5afcea4 100644
--- a/sdks/schemas/runner-protocol/v1.bare
+++ b/sdks/schemas/runner-protocol/v1.bare
@@ -133,9 +133,7 @@ type CommandWrapper struct {
 }
 
 type ToServerInit struct {
-	runnerId: optional<str>
 	name: str
-	key: str
 	version: u32
 	totalSlots: u32
 	addressesHttp: optional<map<str><u16>>
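Tying the versioned wrappers above to the schema below: the gateway sets `replyTo` on the first message of a request, and the tunnel publishes every runner reply back to that subject. A hedged sketch of the gateway-side encoding; the crate path, re-exports, and the `[u8; 16]` id arrays are assumptions (the schema types RequestId and MessageId as `data[16]`), while the field names and the `serialize_with_embedded_version` call mirror the diff:

use anyhow::Result;
use rivet_tunnel_protocol::{versioned, MessageKind, PubSubMessage, PROTOCOL_VERSION};
use versioned_data_util::OwnedVersionedData;

fn encode_gateway_open(
	reply_to: String, // e.g. TunnelGatewayReceiverSubject::new(gateway_id).to_string()
	request_id: [u8; 16],
	message_id: [u8; 16],
	message_kind: MessageKind,
) -> Result<Vec<u8>> {
	// reply_to is only set when opening a new request from gateway -> runner;
	// the tunnel caches it per request_id and drops the entry once a terminal
	// message (non-streaming response, final chunk, abort, or close) is seen.
	versioned::PubSubMessage::latest(PubSubMessage {
		request_id,
		message_id,
		reply_to: Some(reply_to),
		message_kind,
	})
	.serialize_with_embedded_version(PROTOCOL_VERSION)
}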
diff --git a/sdks/schemas/tunnel-protocol/v1.bare b/sdks/schemas/tunnel-protocol/v1.bare
index 5405587c1b..f9e0e9c63f 100644
--- a/sdks/schemas/tunnel-protocol/v1.bare
+++ b/sdks/schemas/tunnel-protocol/v1.bare
@@ -1,15 +1,13 @@
-type RequestId u64
-type WebSocketId u64
+type RequestId data[16] # UUIDv4
+type MessageId data[16] # UUIDv4
 type Id str
 
-type StreamFinishReason enum {
-	COMPLETE
-	ABORT
-}
-# MARK: HTTP Request Forwarding
+# MARK: Ack
+type Ack void
+
+# MARK: HTTP
 
 type ToServerRequestStart struct {
-	requestId: RequestId
 	actorId: Id
 	method: str
 	path: str
@@ -19,17 +17,13 @@ type ToServerRequestStart struct {
 }
 
 type ToServerRequestChunk struct {
-	requestId: RequestId
 	body: data
+	finish: bool
 }
 
-type ToServerRequestFinish struct {
-	requestId: RequestId
-	reason: StreamFinishReason
-}
+type ToServerRequestAbort void
 
 type ToClientResponseStart struct {
-	requestId: RequestId
 	status: u16
 	headers: map<str><str>
 	body: optional<data>
@@ -37,60 +31,52 @@ type ToClientResponseStart struct {
 }
 
 type ToClientResponseChunk struct {
-	requestId: RequestId
 	body: data
+	finish: bool
 }
 
-type ToClientResponseFinish struct {
-	requestId: RequestId
-	reason: StreamFinishReason
-}
+type ToClientResponseAbort void
 
-# MARK: WebSocket Forwarding
+# MARK: WebSocket
 type ToServerWebSocketOpen struct {
 	actorId: Id
-	webSocketId: WebSocketId
 	path: str
 	headers: map<str><str>
 }
 
 type ToServerWebSocketMessage struct {
-	webSocketId: WebSocketId
 	data: data
 	binary: bool
 }
 
 type ToServerWebSocketClose struct {
-	webSocketId: WebSocketId
 	code: optional<u16>
 	reason: optional<str>
 }
 
-type ToClientWebSocketOpen struct {
-	webSocketId: WebSocketId
-}
+type ToClientWebSocketOpen void
 
 type ToClientWebSocketMessage struct {
-	webSocketId: WebSocketId
 	data: data
 	binary: bool
 }
 
 type ToClientWebSocketClose struct {
-	webSocketId: WebSocketId
 	code: optional<u16>
 	reason: optional<str>
 }
 
 # MARK: Message
-type MessageBody union {
+type MessageKind union {
+	Ack |
+
+	# HTTP
 	ToServerRequestStart |
 	ToServerRequestChunk |
-	ToServerRequestFinish |
+	ToServerRequestAbort |
 	ToClientResponseStart |
 	ToClientResponseChunk |
-	ToClientResponseFinish |
+	ToClientResponseAbort |
 
 	# WebSocket
 	ToServerWebSocketOpen |
@@ -101,7 +87,19 @@ type MessageBody union {
 }
 
-# Main tunnel message
-type TunnelMessage struct {
-	body: MessageBody
+# MARK: Message sent over tunnel WebSocket
+type RunnerMessage struct {
+	requestId: RequestId
+	messageId: MessageId
+	messageKind: MessageKind
+}
+
+# MARK: Message sent over UPS
+type PubSubMessage struct {
+	requestId: RequestId
+	messageId: MessageId
+	# Subject to send replies to. Only sent when opening a new request from gateway -> runner.
+	replyTo: optional<str>
+	messageKind: MessageKind
 }
+
diff --git a/sdks/typescript/runner-protocol/src/index.ts b/sdks/typescript/runner-protocol/src/index.ts
index 45a947c051..96c00212c6 100644
--- a/sdks/typescript/runner-protocol/src/index.ts
+++ b/sdks/typescript/runner-protocol/src/index.ts
@@ -590,18 +590,7 @@ export function writeCommandWrapper(bc: bare.ByteCursor, x: CommandWrapper): voi
 	writeCommand(bc, x.inner)
 }
 
-function read3(bc: bare.ByteCursor): Id | null {
-	return bare.readBool(bc) ? readId(bc) : null
-}
-
-function write3(bc: bare.ByteCursor, x: Id | null): void {
-	bare.writeBool(bc, x != null)
-	if (x != null) {
-		writeId(bc, x)
-	}
-}
-
-function read4(bc: bare.ByteCursor): ReadonlyMap<string, u16> {
+function read3(bc: bare.ByteCursor): ReadonlyMap<string, u16> {
 	const len = bare.readUintSafe(bc)
 	const result = new Map<string, u16>()
 	for (let i = 0; i < len; i++) {
@@ -616,7 +605,7 @@ function read3(bc: bare.ByteCursor): ReadonlyMap<string, u16> {
 	return result
 }
 
-function write4(bc: bare.ByteCursor, x: ReadonlyMap<string, u16>): void {
+function write3(bc: bare.ByteCursor, x: ReadonlyMap<string, u16>): void {
 	bare.writeUintSafe(bc, x.size)
 	for (const kv of x) {
 		bare.writeString(bc, kv[0])
@@ -624,18 +613,18 @@
 	}
 }
 
-function read5(bc: bare.ByteCursor): ReadonlyMap<string, u16> | null {
-	return bare.readBool(bc) ?
read4(bc) : null +function read4(bc: bare.ByteCursor): ReadonlyMap | null { + return bare.readBool(bc) ? read3(bc) : null } -function write5(bc: bare.ByteCursor, x: ReadonlyMap | null): void { +function write4(bc: bare.ByteCursor, x: ReadonlyMap | null): void { bare.writeBool(bc, x != null) if (x != null) { - write4(bc, x) + write3(bc, x) } } -function read6(bc: bare.ByteCursor): ReadonlyMap { +function read5(bc: bare.ByteCursor): ReadonlyMap { const len = bare.readUintSafe(bc) const result = new Map() for (let i = 0; i < len; i++) { @@ -650,7 +639,7 @@ function read6(bc: bare.ByteCursor): ReadonlyMap { return result } -function write6(bc: bare.ByteCursor, x: ReadonlyMap): void { +function write5(bc: bare.ByteCursor, x: ReadonlyMap): void { bare.writeUintSafe(bc, x.size) for (const kv of x) { bare.writeString(bc, kv[0]) @@ -658,18 +647,18 @@ function write6(bc: bare.ByteCursor, x: ReadonlyMap): } } -function read7(bc: bare.ByteCursor): ReadonlyMap | null { - return bare.readBool(bc) ? read6(bc) : null +function read6(bc: bare.ByteCursor): ReadonlyMap | null { + return bare.readBool(bc) ? read5(bc) : null } -function write7(bc: bare.ByteCursor, x: ReadonlyMap | null): void { +function write6(bc: bare.ByteCursor, x: ReadonlyMap | null): void { bare.writeBool(bc, x != null) if (x != null) { - write6(bc, x) + write5(bc, x) } } -function read8(bc: bare.ByteCursor): ReadonlyMap { +function read7(bc: bare.ByteCursor): ReadonlyMap { const len = bare.readUintSafe(bc) const result = new Map() for (let i = 0; i < len; i++) { @@ -684,7 +673,7 @@ function read8(bc: bare.ByteCursor): ReadonlyMap { return result } -function write8(bc: bare.ByteCursor, x: ReadonlyMap): void { +function write7(bc: bare.ByteCursor, x: ReadonlyMap): void { bare.writeUintSafe(bc, x.size) for (const kv of x) { bare.writeString(bc, kv[0]) @@ -692,18 +681,18 @@ function write8(bc: bare.ByteCursor, x: ReadonlyMap): } } -function read9(bc: bare.ByteCursor): ReadonlyMap | null { - return bare.readBool(bc) ? read8(bc) : null +function read8(bc: bare.ByteCursor): ReadonlyMap | null { + return bare.readBool(bc) ? read7(bc) : null } -function write9(bc: bare.ByteCursor, x: ReadonlyMap | null): void { +function write8(bc: bare.ByteCursor, x: ReadonlyMap | null): void { bare.writeBool(bc, x != null) if (x != null) { - write8(bc, x) + write7(bc, x) } } -function read10(bc: bare.ByteCursor): ReadonlyMap { +function read9(bc: bare.ByteCursor): ReadonlyMap { const len = bare.readUintSafe(bc) const result = new Map() for (let i = 0; i < len; i++) { @@ -718,7 +707,7 @@ function read10(bc: bare.ByteCursor): ReadonlyMap { return result } -function write10(bc: bare.ByteCursor, x: ReadonlyMap): void { +function write9(bc: bare.ByteCursor, x: ReadonlyMap): void { bare.writeUintSafe(bc, x.size) for (const kv of x) { bare.writeString(bc, kv[0]) @@ -726,22 +715,22 @@ function write10(bc: bare.ByteCursor, x: ReadonlyMap): void { } } -function read11(bc: bare.ByteCursor): ReadonlyMap | null { - return bare.readBool(bc) ? read10(bc) : null +function read10(bc: bare.ByteCursor): ReadonlyMap | null { + return bare.readBool(bc) ? read9(bc) : null } -function write11(bc: bare.ByteCursor, x: ReadonlyMap | null): void { +function write10(bc: bare.ByteCursor, x: ReadonlyMap | null): void { bare.writeBool(bc, x != null) if (x != null) { - write10(bc, x) + write9(bc, x) } } -function read12(bc: bare.ByteCursor): Json | null { +function read11(bc: bare.ByteCursor): Json | null { return bare.readBool(bc) ? 
readJson(bc) : null } -function write12(bc: bare.ByteCursor, x: Json | null): void { +function write11(bc: bare.ByteCursor, x: Json | null): void { bare.writeBool(bc, x != null) if (x != null) { writeJson(bc, x) @@ -749,9 +738,7 @@ function write12(bc: bare.ByteCursor, x: Json | null): void { } export type ToServerInit = { - readonly runnerId: Id | null readonly name: string - readonly key: string readonly version: u32 readonly totalSlots: u32 readonly addressesHttp: ReadonlyMap | null @@ -764,32 +751,28 @@ export type ToServerInit = { export function readToServerInit(bc: bare.ByteCursor): ToServerInit { return { - runnerId: read3(bc), name: bare.readString(bc), - key: bare.readString(bc), version: bare.readU32(bc), totalSlots: bare.readU32(bc), - addressesHttp: read5(bc), - addressesTcp: read7(bc), - addressesUdp: read9(bc), + addressesHttp: read4(bc), + addressesTcp: read6(bc), + addressesUdp: read8(bc), lastCommandIdx: read1(bc), - prepopulateActorNames: read11(bc), - metadata: read12(bc), + prepopulateActorNames: read10(bc), + metadata: read11(bc), } } export function writeToServerInit(bc: bare.ByteCursor, x: ToServerInit): void { - write3(bc, x.runnerId) bare.writeString(bc, x.name) - bare.writeString(bc, x.key) bare.writeU32(bc, x.version) bare.writeU32(bc, x.totalSlots) - write5(bc, x.addressesHttp) - write7(bc, x.addressesTcp) - write9(bc, x.addressesUdp) + write4(bc, x.addressesHttp) + write6(bc, x.addressesTcp) + write8(bc, x.addressesUdp) write1(bc, x.lastCommandIdx) - write11(bc, x.prepopulateActorNames) - write12(bc, x.metadata) + write10(bc, x.prepopulateActorNames) + write11(bc, x.metadata) } export type ToServerEvents = readonly EventWrapper[] @@ -843,7 +826,7 @@ export function writeToServerPing(bc: bare.ByteCursor, x: ToServerPing): void { bare.writeI64(bc, x.ts) } -function read13(bc: bare.ByteCursor): readonly KvKey[] { +function read12(bc: bare.ByteCursor): readonly KvKey[] { const len = bare.readUintSafe(bc) if (len === 0) { return [] @@ -855,7 +838,7 @@ function read13(bc: bare.ByteCursor): readonly KvKey[] { return result } -function write13(bc: bare.ByteCursor, x: readonly KvKey[]): void { +function write12(bc: bare.ByteCursor, x: readonly KvKey[]): void { bare.writeUintSafe(bc, x.length) for (let i = 0; i < x.length; i++) { writeKvKey(bc, x[i]) @@ -868,30 +851,30 @@ export type KvGetRequest = { export function readKvGetRequest(bc: bare.ByteCursor): KvGetRequest { return { - keys: read13(bc), + keys: read12(bc), } } export function writeKvGetRequest(bc: bare.ByteCursor, x: KvGetRequest): void { - write13(bc, x.keys) + write12(bc, x.keys) } -function read14(bc: bare.ByteCursor): boolean | null { +function read13(bc: bare.ByteCursor): boolean | null { return bare.readBool(bc) ? bare.readBool(bc) : null } -function write14(bc: bare.ByteCursor, x: boolean | null): void { +function write13(bc: bare.ByteCursor, x: boolean | null): void { bare.writeBool(bc, x != null) if (x != null) { bare.writeBool(bc, x) } } -function read15(bc: bare.ByteCursor): u64 | null { +function read14(bc: bare.ByteCursor): u64 | null { return bare.readBool(bc) ? 
bare.readU64(bc) : null } -function write15(bc: bare.ByteCursor, x: u64 | null): void { +function write14(bc: bare.ByteCursor, x: u64 | null): void { bare.writeBool(bc, x != null) if (x != null) { bare.writeU64(bc, x) @@ -907,18 +890,18 @@ export type KvListRequest = { export function readKvListRequest(bc: bare.ByteCursor): KvListRequest { return { query: readKvListQuery(bc), - reverse: read14(bc), - limit: read15(bc), + reverse: read13(bc), + limit: read14(bc), } } export function writeKvListRequest(bc: bare.ByteCursor, x: KvListRequest): void { writeKvListQuery(bc, x.query) - write14(bc, x.reverse) - write15(bc, x.limit) + write13(bc, x.reverse) + write14(bc, x.limit) } -function read16(bc: bare.ByteCursor): readonly KvValue[] { +function read15(bc: bare.ByteCursor): readonly KvValue[] { const len = bare.readUintSafe(bc) if (len === 0) { return [] @@ -930,7 +913,7 @@ function read16(bc: bare.ByteCursor): readonly KvValue[] { return result } -function write16(bc: bare.ByteCursor, x: readonly KvValue[]): void { +function write15(bc: bare.ByteCursor, x: readonly KvValue[]): void { bare.writeUintSafe(bc, x.length) for (let i = 0; i < x.length; i++) { writeKvValue(bc, x[i]) @@ -944,14 +927,14 @@ export type KvPutRequest = { export function readKvPutRequest(bc: bare.ByteCursor): KvPutRequest { return { - keys: read13(bc), - values: read16(bc), + keys: read12(bc), + values: read15(bc), } } export function writeKvPutRequest(bc: bare.ByteCursor, x: KvPutRequest): void { - write13(bc, x.keys) - write16(bc, x.values) + write12(bc, x.keys) + write15(bc, x.values) } export type KvDeleteRequest = { @@ -960,12 +943,12 @@ export type KvDeleteRequest = { export function readKvDeleteRequest(bc: bare.ByteCursor): KvDeleteRequest { return { - keys: read13(bc), + keys: read12(bc), } } export function writeKvDeleteRequest(bc: bare.ByteCursor, x: KvDeleteRequest): void { - write13(bc, x.keys) + write12(bc, x.keys) } export type KvDropRequest = null @@ -1214,7 +1197,7 @@ export function writeKvErrorResponse(bc: bare.ByteCursor, x: KvErrorResponse): v bare.writeString(bc, x.message) } -function read17(bc: bare.ByteCursor): readonly KvMetadata[] { +function read16(bc: bare.ByteCursor): readonly KvMetadata[] { const len = bare.readUintSafe(bc) if (len === 0) { return [] @@ -1226,7 +1209,7 @@ function read17(bc: bare.ByteCursor): readonly KvMetadata[] { return result } -function write17(bc: bare.ByteCursor, x: readonly KvMetadata[]): void { +function write16(bc: bare.ByteCursor, x: readonly KvMetadata[]): void { bare.writeUintSafe(bc, x.length) for (let i = 0; i < x.length; i++) { writeKvMetadata(bc, x[i]) @@ -1241,16 +1224,16 @@ export type KvGetResponse = { export function readKvGetResponse(bc: bare.ByteCursor): KvGetResponse { return { - keys: read13(bc), - values: read16(bc), - metadata: read17(bc), + keys: read12(bc), + values: read15(bc), + metadata: read16(bc), } } export function writeKvGetResponse(bc: bare.ByteCursor, x: KvGetResponse): void { - write13(bc, x.keys) - write16(bc, x.values) - write17(bc, x.metadata) + write12(bc, x.keys) + write15(bc, x.values) + write16(bc, x.metadata) } export type KvListResponse = { @@ -1261,16 +1244,16 @@ export type KvListResponse = { export function readKvListResponse(bc: bare.ByteCursor): KvListResponse { return { - keys: read13(bc), - values: read16(bc), - metadata: read17(bc), + keys: read12(bc), + values: read15(bc), + metadata: read16(bc), } } export function writeKvListResponse(bc: bare.ByteCursor, x: KvListResponse): void { - write13(bc, x.keys) - write16(bc, 
x.values) - write17(bc, x.metadata) + write12(bc, x.keys) + write15(bc, x.values) + write16(bc, x.metadata) } export type KvPutResponse = null diff --git a/sdks/typescript/runner/package.json b/sdks/typescript/runner/package.json index 0b7bfdf4c3..984f065374 100644 --- a/sdks/typescript/runner/package.json +++ b/sdks/typescript/runner/package.json @@ -22,6 +22,7 @@ "dependencies": { "@rivetkit/engine-runner-protocol": "workspace:*", "@rivetkit/engine-tunnel-protocol": "workspace:*", + "uuid": "^12.0.0", "ws": "^8.18.3" }, "devDependencies": { diff --git a/sdks/typescript/runner/src/mod.ts b/sdks/typescript/runner/src/mod.ts index 92be4cda50..7d26a37925 100644 --- a/sdks/typescript/runner/src/mod.ts +++ b/sdks/typescript/runner/src/mod.ts @@ -1,16 +1,18 @@ import WebSocket from "ws"; import { importWebSocket } from "./websocket.js"; import * as protocol from "@rivetkit/engine-runner-protocol"; -import { unreachable, calculateBackoff } from "./utils.js"; -import { Tunnel } from "./tunnel.js"; -import { WebSocketTunnelAdapter } from "./websocket-tunnel-adapter.js"; +import { unreachable, calculateBackoff } from "./utils"; +import { Tunnel } from "./tunnel"; +import { WebSocketTunnelAdapter } from "./websocket-tunnel-adapter"; const KV_EXPIRE: number = 30_000; -interface ActorInstance { +export interface ActorInstance { actorId: string; generation: number; config: ActorConfig; + requests: Set; // Track active request IDs + webSockets: Set; // Track active WebSocket IDs } export interface ActorConfig { @@ -60,6 +62,11 @@ interface KvRequestEntry { export class Runner { #config: RunnerConfig; + + get config(): RunnerConfig { + return this.#config; + } + #actors: Map = new Map(); #actorWebSockets: Map> = new Map(); @@ -110,7 +117,7 @@ export class Runner { // MARK: Manage actors sleepActor(actorId: string, generation?: number) { - const actor = this.#getActor(actorId, generation); + const actor = this.getActor(actorId, generation); if (!actor) return; // Keep the actor instance in memory during sleep @@ -126,7 +133,7 @@ export class Runner { // Unregister actor from tunnel if (this.#tunnel) { - this.#tunnel.unregisterActor(actorId); + this.#tunnel.unregisterActor(actor); } this.#sendActorStateUpdate(actorId, actor.generation, "stopped"); @@ -147,7 +154,7 @@ export class Runner { } } - #getActor(actorId: string, generation?: number): ActorInstance | undefined { + getActor(actorId: string, generation?: number): ActorInstance | undefined { const actor = this.#actors.get(actorId); if (!actor) { console.error(`Actor ${actorId} not found`); @@ -363,10 +370,10 @@ export class Runner { const wsEndpoint = endpoint .replace("http://", "ws://") .replace("https://", "wss://"); - return `${wsEndpoint}/v1?namespace=${encodeURIComponent(this.#config.namespace)}`; + return `${wsEndpoint}?protocol_version=1&namespace=${encodeURIComponent(this.#config.namespace)}&runner_key=${encodeURIComponent(this.#config.runnerKey)}`; } - get pegboardRelayUrl() { + get pegboardTunnelUrl() { const endpoint = this.#config.pegboardRelayEndpoint || this.#config.pegboardEndpoint || @@ -374,26 +381,19 @@ export class Runner { const wsEndpoint = endpoint .replace("http://", "ws://") .replace("https://", "wss://"); - // Include runner ID if we have it - if (this.runnerId) { - return `${wsEndpoint}/tunnel?namespace=${encodeURIComponent(this.#config.namespace)}&runner_id=${this.runnerId}`; - } - return `${wsEndpoint}/tunnel?namespace=${encodeURIComponent(this.#config.namespace)}`; + return 
`${wsEndpoint}?protocol_version=1&namespace=${encodeURIComponent(this.#config.namespace)}&runner_key=${this.#config.runnerKey}`; } async #openTunnelAndWait(): Promise { return new Promise((resolve, reject) => { - const url = this.pegboardRelayUrl; + const url = this.pegboardTunnelUrl; //console.log("[RUNNER] Opening tunnel to:", url); //console.log("[RUNNER] Current runner ID:", this.runnerId || "none"); //console.log("[RUNNER] Active actors count:", this.#actors.size); let connected = false; - this.#tunnel = new Tunnel(url); - this.#tunnel.setCallbacks({ - fetch: this.#config.fetch, - websocket: this.#config.websocket, + this.#tunnel = new Tunnel(this, url, { onConnected: () => { if (!connected) { connected = true; @@ -410,35 +410,9 @@ export class Runner { }, }); this.#tunnel.start(); - - // Re-register all active actors with the new tunnel - for (const actorId of this.#actors.keys()) { - //console.log("[RUNNER] Re-registering actor with tunnel:", actorId); - this.#tunnel.registerActor(actorId); - } }); } - #openTunnel() { - const url = this.pegboardRelayUrl; - //console.log("[RUNNER] Opening tunnel to:", url); - //console.log("[RUNNER] Current runner ID:", this.runnerId || "none"); - //console.log("[RUNNER] Active actors count:", this.#actors.size); - - this.#tunnel = new Tunnel(url); - this.#tunnel.setCallbacks({ - fetch: this.#config.fetch, - websocket: this.#config.websocket, - }); - this.#tunnel.start(); - - // Re-register all active actors with the new tunnel - for (const actorId of this.#actors.keys()) { - //console.log("[RUNNER] Re-registering actor with tunnel:", actorId); - this.#tunnel.registerActor(actorId); - } - } - // MARK: Runner protocol async #openPegboardWebSocket() { const WS = await importWebSocket(); @@ -469,9 +443,7 @@ export class Runner { // Send init message const init: protocol.ToServerInit = { - runnerId: this.runnerId || null, name: this.#config.runnerName, - key: this.#config.runnerKey, version: this.#config.version, totalSlots: this.#config.totalSlots, addressesHttp: new Map(), // No addresses needed with tunnel @@ -560,18 +532,6 @@ export class Runner { // runnerLostThreshold: this.#runnerLostThreshold, //}); - // Only reopen tunnel if we didn't have a runner ID before - // This happens on reconnection after losing connection - if (!hadRunnerId && this.runnerId) { - // Reopen tunnel with runner ID - //console.log("[RUNNER] Received runner ID, reopening tunnel"); - if (this.#tunnel) { - //console.log("[RUNNER] Shutting down existing tunnel"); - this.#tunnel.shutdown(); - } - this.#openTunnel(); - } - // Resend events that haven't been acknowledged this.#resendUnacknowledgedEvents(init.lastEventIdx); @@ -664,21 +624,12 @@ export class Runner { actorId, generation, config: actorConfig, + requests: new Set(), + webSockets: new Set(), }; this.#actors.set(actorId, instance); - // Register actor with tunnel - if (this.#tunnel) { - //console.log("[RUNNER] Registering new actor with tunnel:", actorId); - this.#tunnel.registerActor(actorId); - } else { - console.error( - "[RUNNER] WARNING: No tunnel available to register actor:", - actorId, - ); - } - this.#sendActorStateUpdate(actorId, generation, "running"); // TODO: Add timeout to onActorStart diff --git a/sdks/typescript/runner/src/tunnel.ts b/sdks/typescript/runner/src/tunnel.ts index 7c868c44ce..4a00c878e2 100644 --- a/sdks/typescript/runner/src/tunnel.ts +++ b/sdks/typescript/runner/src/tunnel.ts @@ -1,162 +1,269 @@ import WebSocket from "ws"; import * as tunnel from "@rivetkit/engine-tunnel-protocol"; 
-import { WebSocketTunnelAdapter } from "./websocket-tunnel-adapter.js"; -import { calculateBackoff } from "./utils.js"; +import { WebSocketTunnelAdapter } from "./websocket-tunnel-adapter"; +import { calculateBackoff } from "./utils"; +import type { Runner, ActorInstance } from "./mod"; +import { v4 as uuidv4 } from "uuid"; + +const GC_INTERVAL = 60000; // 60 seconds +const MESSAGE_ACK_TIMEOUT = 5000; // 5 seconds + +interface PendingRequest { + resolve: (response: Response) => void; + reject: (error: Error) => void; + streamController?: ReadableStreamDefaultController; + actorId?: string; +} + +interface TunnelCallbacks { + onConnected(): void; + onDisconnected(): void; +} + +interface PendingMessage { + sentAt: number; + requestIdStr: string; +} export class Tunnel { #pegboardTunnelUrl: string; - #ws?: WebSocket; - #pendingRequests: Map void; - reject: (error: Error) => void; - streamController?: ReadableStreamDefaultController; - actorId?: string; - }> = new Map(); - #webSockets: Map = new Map(); + + #runner: Runner; + + #tunnelWs?: WebSocket; #shutdown = false; #reconnectTimeout?: NodeJS.Timeout; #reconnectAttempt = 0; - - // Track actors and their connections - #activeActors: Set = new Set(); - #actorRequests: Map> = new Map(); - #actorWebSockets: Map> = new Map(); - - // Callbacks - #onConnected?: () => void; - #onDisconnected?: () => void; - #fetchHandler?: (actorId: string, request: Request) => Promise; - #websocketHandler?: (actorId: string, ws: any, request: Request) => Promise; - - constructor(pegboardTunnelUrl: string) { - this.#pegboardTunnelUrl = pegboardTunnelUrl; - } - setCallbacks(options: { - onConnected?: () => void; - onDisconnected?: () => void; - fetch?: (actorId: string, request: Request) => Promise; - websocket?: (actorId: string, ws: any, request: Request) => Promise; - }) { - this.#onConnected = options.onConnected; - this.#onDisconnected = options.onDisconnected; - this.#fetchHandler = options.fetch; - this.#websocketHandler = options.websocket; + #actorPendingRequests: Map = new Map(); + #actorWebSockets: Map = new Map(); + + #pendingMessages: Map = new Map(); + #gcInterval?: NodeJS.Timeout; + + #callbacks: TunnelCallbacks; + + constructor( + runner: Runner, + pegboardTunnelUrl: string, + callbacks: TunnelCallbacks, + ) { + this.#pegboardTunnelUrl = pegboardTunnelUrl; + this.#runner = runner; + this.#callbacks = callbacks; } start(): void { - if (this.#ws?.readyState === WebSocket.OPEN) { + if (this.#tunnelWs?.readyState === WebSocket.OPEN) { return; } - + this.#connect(); + this.#startGarbageCollector(); } shutdown() { this.#shutdown = true; - + if (this.#reconnectTimeout) { clearTimeout(this.#reconnectTimeout); this.#reconnectTimeout = undefined; } - if (this.#ws) { - this.#ws.close(); - this.#ws = undefined; + if (this.#gcInterval) { + clearInterval(this.#gcInterval); + this.#gcInterval = undefined; + } + + if (this.#tunnelWs) { + this.#tunnelWs.close(); + this.#tunnelWs = undefined; } + // TODO: Should we use unregisterActor instead + // Reject all pending requests - for (const [_, request] of this.#pendingRequests) { + for (const [_, request] of this.#actorPendingRequests) { request.reject(new Error("Tunnel shutting down")); } - this.#pendingRequests.clear(); + this.#actorPendingRequests.clear(); // Close all WebSockets - for (const [_, ws] of this.#webSockets) { + for (const [_, ws] of this.#actorWebSockets) { ws.close(); } - this.#webSockets.clear(); - - // Clear actor tracking - this.#activeActors.clear(); - this.#actorRequests.clear(); 
this.#actorWebSockets.clear(); } - registerActor(actorId: string) { - this.#activeActors.add(actorId); - this.#actorRequests.set(actorId, new Set()); - this.#actorWebSockets.set(actorId, new Set()); + #sendMessage(requestId: tunnel.RequestId, messageKind: tunnel.MessageKind) { + if (!this.#tunnelWs || this.#tunnelWs.readyState !== WebSocket.OPEN) { + console.warn("Cannot send tunnel message, WebSocket not connected"); + return; + } + + // Build message + const messageId = generateUuidBuffer(); + + const requestIdStr = bufferToString(requestId); + this.#pendingMessages.set(bufferToString(messageId), { + sentAt: Date.now(), + requestIdStr, + }); + + // Send message + const message: tunnel.RunnerMessage = { + requestId, + messageId, + messageKind, + }; + + const encoded = tunnel.encodeRunnerMessage(message); + this.#tunnelWs.send(encoded); } - unregisterActor(actorId: string) { - this.#activeActors.delete(actorId); - - // Terminate all requests for this actor - const requests = this.#actorRequests.get(actorId); - if (requests) { - for (const requestId of requests) { - const pending = this.#pendingRequests.get(requestId); - if (pending) { - pending.reject(new Error(`Actor ${actorId} stopped`)); - this.#pendingRequests.delete(requestId); + #sendAck(requestId: tunnel.RequestId, messageId: tunnel.MessageId) { + if (!this.#tunnelWs || this.#tunnelWs.readyState !== WebSocket.OPEN) { + return; + } + + const message: tunnel.RunnerMessage = { + requestId, + messageId, + messageKind: { tag: "Ack", val: null }, + }; + + const encoded = tunnel.encodeRunnerMessage(message); + this.#tunnelWs.send(encoded); + } + + #startGarbageCollector() { + if (this.#gcInterval) { + clearInterval(this.#gcInterval); + } + + this.#gcInterval = setInterval(() => { + this.#gc(); + }, GC_INTERVAL); + } + + #gc() { + const now = Date.now(); + const messagesToDelete: string[] = []; + + for (const [messageId, pendingMessage] of this.#pendingMessages) { + // Check if message is older than timeout + if ( + now - pendingMessage.sentAt > MESSAGE_ACK_TIMEOUT + ) { + messagesToDelete.push(messageId); + + const requestIdStr = pendingMessage.requestIdStr; + + // Check if this is an HTTP request + const pendingRequest = + this.#actorPendingRequests.get(requestIdStr); + if (pendingRequest) { + // Reject the pending HTTP request + pendingRequest.reject( + new Error("Message acknowledgment timeout"), + ); + + // Close stream controller if it exists + if (pendingRequest.streamController) { + pendingRequest.streamController.error( + new Error("Message acknowledgment timeout"), + ); + } + + // Clean up from actorPendingRequests map + this.#actorPendingRequests.delete(requestIdStr); + } + + // Check if this is a WebSocket + const webSocket = this.#actorWebSockets.get(requestIdStr); + if (webSocket) { + // Close the WebSocket connection + webSocket.close(1000, "Message acknowledgment timeout"); + + // Clean up from actorWebSockets map + this.#actorWebSockets.delete(requestIdStr); } } - this.#actorRequests.delete(actorId); } - + + // Remove timed out messages + for (const messageId of messagesToDelete) { + this.#pendingMessages.delete(messageId); + console.warn(`Purged unacked message: ${messageId}`); + } + } + + unregisterActor(actor: ActorInstance) { + const actorId = actor.actorId; + + // Terminate all requests for this actor + for (const requestId of actor.requests) { + const pending = this.#actorPendingRequests.get(requestId); + if (pending) { + pending.reject(new Error(`Actor ${actorId} stopped`)); + 
this.#actorPendingRequests.delete(requestId); + } + } + actor.requests.clear(); + // Close all WebSockets for this actor - const webSockets = this.#actorWebSockets.get(actorId); - if (webSockets) { - for (const webSocketId of webSockets) { - const ws = this.#webSockets.get(webSocketId); - if (ws) { - ws.close(1000, "Actor stopped"); - this.#webSockets.delete(webSocketId); - } + for (const webSocketId of actor.webSockets) { + const ws = this.#actorWebSockets.get(webSocketId); + if (ws) { + ws.close(1000, "Actor stopped"); + this.#actorWebSockets.delete(webSocketId); } - this.#actorWebSockets.delete(actorId); } + actor.webSockets.clear(); } async #fetch(actorId: string, request: Request): Promise { // Validate actor exists - if (!this.#activeActors.has(actorId)) { - console.warn(`[TUNNEL] Ignoring request for unknown actor: ${actorId}`); + if (!this.#runner.hasActor(actorId)) { + console.warn( + `[TUNNEL] Ignoring request for unknown actor: ${actorId}`, + ); return new Response("Actor not found", { status: 404 }); } - - if (!this.#fetchHandler) { + + const fetchHandler = this.#runner.config.fetch(actorId, request); + + if (!fetchHandler) { return new Response("Not Implemented", { status: 501 }); } - - return this.#fetchHandler(actorId, request); + + return fetchHandler; } #connect() { if (this.#shutdown) return; try { - this.#ws = new WebSocket(this.#pegboardTunnelUrl, { + this.#tunnelWs = new WebSocket(this.#pegboardTunnelUrl, { headers: { "x-rivet-target": "tunnel", }, }); - this.#ws.binaryType = "arraybuffer"; + this.#tunnelWs.binaryType = "arraybuffer"; - this.#ws.addEventListener("open", () => { + this.#tunnelWs.addEventListener("open", () => { this.#reconnectAttempt = 0; - + if (this.#reconnectTimeout) { clearTimeout(this.#reconnectTimeout); this.#reconnectTimeout = undefined; } - this.#onConnected?.(); + this.#callbacks.onConnected(); }); - this.#ws.addEventListener("message", async (event) => { + this.#tunnelWs.addEventListener("message", async (event) => { try { await this.#handleMessage(event.data as ArrayBuffer); } catch (error) { @@ -164,12 +271,12 @@ export class Tunnel { } }); - this.#ws.addEventListener("error", (event) => { + this.#tunnelWs.addEventListener("error", (event) => { console.error("Tunnel WebSocket error:", event); }); - this.#ws.addEventListener("close", () => { - this.#onDisconnected?.(); + this.#tunnelWs.addEventListener("close", () => { + this.#callbacks.onDisconnected(); if (!this.#shutdown) { this.#scheduleReconnect(); @@ -192,9 +299,8 @@ export class Tunnel { multiplier: 2, jitter: true, }); - - this.#reconnectAttempt++; + this.#reconnectAttempt++; this.#reconnectTimeout = setTimeout(() => { this.#connect(); @@ -202,53 +308,97 @@ export class Tunnel { } async #handleMessage(data: ArrayBuffer) { - const message = tunnel.decodeTunnelMessage(new Uint8Array(data)); - - switch (message.body.tag) { - case "ToServerRequestStart": - await this.#handleRequestStart(message.body.val); - break; - case "ToServerRequestChunk": - await this.#handleRequestChunk(message.body.val); - break; - case "ToServerRequestFinish": - await this.#handleRequestFinish(message.body.val); - break; - case "ToServerWebSocketOpen": - await this.#handleWebSocketOpen(message.body.val); - break; - case "ToServerWebSocketMessage": - await this.#handleWebSocketMessage(message.body.val); - break; - case "ToServerWebSocketClose": - await this.#handleWebSocketClose(message.body.val); - break; - case "ToClientResponseStart": - this.#handleResponseStart(message.body.val); - break; - case 
"ToClientResponseChunk": - this.#handleResponseChunk(message.body.val); - break; - case "ToClientResponseFinish": - this.#handleResponseFinish(message.body.val); - break; - case "ToClientWebSocketOpen": - this.#handleWebSocketOpenResponse(message.body.val); - break; - case "ToClientWebSocketMessage": - this.#handleWebSocketMessageResponse(message.body.val); - break; - case "ToClientWebSocketClose": - this.#handleWebSocketCloseResponse(message.body.val); - break; + const message = tunnel.decodeRunnerMessage(new Uint8Array(data)); + + if (message.messageKind.tag === "Ack") { + // Mark pending message as acknowledged and remove it + const msgIdStr = bufferToString(message.messageId); + const pending = this.#pendingMessages.get(msgIdStr); + if (pending) { + this.#pendingMessages.delete(msgIdStr); + } + } else { + this.#sendAck(message.requestId, message.messageId); + switch (message.messageKind.tag) { + case "ToServerRequestStart": + await this.#handleRequestStart( + message.requestId, + message.messageKind.val, + ); + break; + case "ToServerRequestChunk": + await this.#handleRequestChunk( + message.requestId, + message.messageKind.val, + ); + break; + case "ToServerRequestAbort": + await this.#handleRequestAbort(message.requestId); + break; + case "ToServerWebSocketOpen": + await this.#handleWebSocketOpen( + message.requestId, + message.messageKind.val, + ); + break; + case "ToServerWebSocketMessage": + await this.#handleWebSocketMessage( + message.requestId, + message.messageKind.val, + ); + break; + case "ToServerWebSocketClose": + await this.#handleWebSocketClose( + message.requestId, + message.messageKind.val, + ); + break; + case "ToClientResponseStart": + this.#handleResponseStart( + message.requestId, + message.messageKind.val, + ); + break; + case "ToClientResponseChunk": + this.#handleResponseChunk( + message.requestId, + message.messageKind.val, + ); + break; + case "ToClientResponseAbort": + this.#handleResponseAbort(message.requestId); + break; + case "ToClientWebSocketOpen": + this.#handleWebSocketOpenResponse( + message.requestId, + message.messageKind.val, + ); + break; + case "ToClientWebSocketMessage": + this.#handleWebSocketMessageResponse( + message.requestId, + message.messageKind.val, + ); + break; + case "ToClientWebSocketClose": + this.#handleWebSocketCloseResponse( + message.requestId, + message.messageKind.val, + ); + break; + } } } - async #handleRequestStart(req: tunnel.ToServerRequestStart) { + async #handleRequestStart( + requestId: ArrayBuffer, + req: tunnel.ToServerRequestStart, + ) { // Track this request for the actor - const requests = this.#actorRequests.get(req.actorId); - if (requests) { - requests.add(req.requestId); + const requestIdStr = bufferToString(requestId); + const actor = this.#runner.getActor(req.actorId); + if (actor) { + actor.requests.add(requestIdStr); } try { @@ -271,12 +421,13 @@ export class Tunnel { const stream = new ReadableStream({ start: (controller) => { // Store controller for chunks - const existing = this.#pendingRequests.get(req.requestId); + const existing = + this.#actorPendingRequests.get(requestIdStr); if (existing) { existing.streamController = controller; existing.actorId = req.actorId; } else { - this.#pendingRequests.set(req.requestId, { + this.#actorPendingRequests.set(requestIdStr, { resolve: () => {}, reject: () => {}, streamController: controller, @@ -293,194 +444,193 @@ export class Tunnel { } as any); // Call fetch handler with validation - const response = await this.#fetch(req.actorId, streamingRequest); - await 
this.#sendResponse(req.requestId, response); + const response = await this.#fetch( + req.actorId, + streamingRequest, + ); + await this.#sendResponse(requestId, response); } else { // Non-streaming request const response = await this.#fetch(req.actorId, request); - await this.#sendResponse(req.requestId, response); + await this.#sendResponse(requestId, response); } } catch (error) { console.error("Error handling request:", error); - this.#sendResponseError(req.requestId, 500, "Internal Server Error"); + this.#sendResponseError(requestId, 500, "Internal Server Error"); } finally { // Clean up request tracking - const requests = this.#actorRequests.get(req.actorId); - if (requests) { - requests.delete(req.requestId); + const actor = this.#runner.getActor(req.actorId); + if (actor) { + actor.requests.delete(requestIdStr); } } } - async #handleRequestChunk(chunk: tunnel.ToServerRequestChunk) { - const pending = this.#pendingRequests.get(chunk.requestId); + async #handleRequestChunk( + requestId: ArrayBuffer, + chunk: tunnel.ToServerRequestChunk, + ) { + const requestIdStr = bufferToString(requestId); + const pending = this.#actorPendingRequests.get(requestIdStr); if (pending?.streamController) { pending.streamController.enqueue(new Uint8Array(chunk.body)); + if (chunk.finish) { + pending.streamController.close(); + this.#actorPendingRequests.delete(requestIdStr); + } } } - async #handleRequestFinish(finish: tunnel.ToServerRequestFinish) { - const pending = this.#pendingRequests.get(finish.requestId); + async #handleRequestAbort(requestId: ArrayBuffer) { + const requestIdStr = bufferToString(requestId); + const pending = this.#actorPendingRequests.get(requestIdStr); if (pending?.streamController) { - if (finish.reason === tunnel.StreamFinishReason.Complete) { - pending.streamController.close(); - } else { - pending.streamController.error(new Error("Request aborted")); - } + pending.streamController.error(new Error("Request aborted")); } - this.#pendingRequests.delete(finish.requestId); + this.#actorPendingRequests.delete(requestIdStr); } - async #sendResponse(requestId: bigint, response: Response) { + async #sendResponse(requestId: ArrayBuffer, response: Response) { // Always treat responses as non-streaming for now // In the future, we could detect streaming responses based on: // - Transfer-Encoding: chunked // - Content-Type: text/event-stream // - Explicit stream flag from the handler - + // Read the body first to get the actual content const body = response.body ? 
await response.arrayBuffer() : null; - + // Convert headers to map and add Content-Length if not present const headers = new Map(); response.headers.forEach((value, key) => { headers.set(key, value); }); - + // Add Content-Length header if we have a body and it's not already set if (body && !headers.has("content-length")) { headers.set("content-length", String(body.byteLength)); } // Send as non-streaming response - this.#send({ - body: { - tag: "ToClientResponseStart", - val: { - requestId, - status: response.status as tunnel.u16, - headers, - body: body || null, - stream: false, - }, + this.#sendMessage(requestId, { + tag: "ToClientResponseStart", + val: { + status: response.status as tunnel.u16, + headers, + body: body || null, + stream: false, }, }); } - #sendResponseError(requestId: bigint, status: number, message: string) { + #sendResponseError( + requestId: ArrayBuffer, + status: number, + message: string, + ) { const headers = new Map(); headers.set("content-type", "text/plain"); - this.#send({ - body: { - tag: "ToClientResponseStart", - val: { - requestId, - status: status as tunnel.u16, - headers, - body: new TextEncoder().encode(message).buffer as ArrayBuffer, - stream: false, - }, + this.#sendMessage(requestId, { + tag: "ToClientResponseStart", + val: { + status: status as tunnel.u16, + headers, + body: new TextEncoder().encode(message).buffer as ArrayBuffer, + stream: false, }, }); } - async #handleWebSocketOpen(open: tunnel.ToServerWebSocketOpen) { + async #handleWebSocketOpen( + requestId: ArrayBuffer, + open: tunnel.ToServerWebSocketOpen, + ) { + const webSocketId = bufferToString(requestId); // Validate actor exists - if (!this.#activeActors.has(open.actorId)) { - console.warn(`Ignoring WebSocket for unknown actor: ${open.actorId}`); + const actor = this.#runner.getActor(open.actorId); + if (!actor) { + console.warn( + `Ignoring WebSocket for unknown actor: ${open.actorId}`, + ); // Send close immediately - this.#send({ - body: { - tag: "ToClientWebSocketClose", - val: { - webSocketId: open.webSocketId, - code: 1011, - reason: "Actor not found", - }, + this.#sendMessage(requestId, { + tag: "ToClientWebSocketClose", + val: { + code: 1011, + reason: "Actor not found", }, }); return; } - if (!this.#websocketHandler) { + const websocketHandler = this.#runner.config.websocket; + + if (!websocketHandler) { console.error("No websocket handler configured for tunnel"); // Send close immediately - this.#send({ - body: { - tag: "ToClientWebSocketClose", - val: { - webSocketId: open.webSocketId, - code: 1011, - reason: "Not Implemented", - }, + this.#sendMessage(requestId, { + tag: "ToClientWebSocketClose", + val: { + code: 1011, + reason: "Not Implemented", }, }); return; } // Track this WebSocket for the actor - const webSockets = this.#actorWebSockets.get(open.actorId); - if (webSockets) { - webSockets.add(open.webSocketId); + if (actor) { + actor.webSockets.add(webSocketId); } try { // Create WebSocket adapter const adapter = new WebSocketTunnelAdapter( - open.webSocketId, + webSocketId, (data: ArrayBuffer | string, isBinary: boolean) => { // Send message through tunnel - const dataBuffer = typeof data === "string" - ? new TextEncoder().encode(data).buffer as ArrayBuffer - : data; - - this.#send({ - body: { - tag: "ToClientWebSocketMessage", - val: { - webSocketId: open.webSocketId, - data: dataBuffer, - binary: isBinary, - }, + const dataBuffer = + typeof data === "string" + ? 
(new TextEncoder().encode(data) + .buffer as ArrayBuffer) + : data; + + this.#sendMessage(requestId, { + tag: "ToClientWebSocketMessage", + val: { + data: dataBuffer, + binary: isBinary, }, }); }, (code?: number, reason?: string) => { // Send close through tunnel - this.#send({ - body: { - tag: "ToClientWebSocketClose", - val: { - webSocketId: open.webSocketId, - code: code || null, - reason: reason || null, - }, + this.#sendMessage(requestId, { + tag: "ToClientWebSocketClose", + val: { + code: code || null, + reason: reason || null, }, }); - + // Remove from map - this.#webSockets.delete(open.webSocketId); - + this.#actorWebSockets.delete(webSocketId); + // Clean up actor tracking - const webSockets = this.#actorWebSockets.get(open.actorId); - if (webSockets) { - webSockets.delete(open.webSocketId); + if (actor) { + actor.webSockets.delete(webSocketId); } - } + }, ); // Store adapter - this.#webSockets.set(open.webSocketId, adapter); + this.#actorWebSockets.set(webSocketId, adapter); // Send open confirmation - this.#send({ - body: { - tag: "ToClientWebSocketOpen", - val: { - webSocketId: open.webSocketId, - }, - }, + this.#sendMessage(requestId, { + tag: "ToClientWebSocketOpen", + val: null, }); // Notify adapter that connection is open @@ -490,7 +640,10 @@ export class Tunnel { // Include original headers from the open message const headerInit: Record = {}; if (open.headers) { - for (const [k, v] of open.headers as ReadonlyMap) { + for (const [k, v] of open.headers as ReadonlyMap< + string, + string + >) { headerInit[k] = v; } } @@ -504,55 +657,68 @@ export class Tunnel { }); // Call websocket handler - await this.#websocketHandler(open.actorId, adapter, request); + await websocketHandler(open.actorId, adapter, request); } catch (error) { console.error("Error handling WebSocket open:", error); - + // Send close on error - this.#send({ - body: { - tag: "ToClientWebSocketClose", - val: { - webSocketId: open.webSocketId, - code: 1011, - reason: "Server Error", - }, + this.#sendMessage(requestId, { + tag: "ToClientWebSocketClose", + val: { + code: 1011, + reason: "Server Error", }, }); - - this.#webSockets.delete(open.webSocketId); - + + this.#actorWebSockets.delete(webSocketId); + // Clean up actor tracking - const webSockets = this.#actorWebSockets.get(open.actorId); - if (webSockets) { - webSockets.delete(open.webSocketId); + if (actor) { + actor.webSockets.delete(webSocketId); } } } - async #handleWebSocketMessage(msg: tunnel.ToServerWebSocketMessage) { - const adapter = this.#webSockets.get(msg.webSocketId); + async #handleWebSocketMessage( + requestId: ArrayBuffer, + msg: tunnel.ToServerWebSocketMessage, + ) { + const webSocketId = bufferToString(requestId); + const adapter = this.#actorWebSockets.get(webSocketId); if (adapter) { const data = msg.binary ? 
new Uint8Array(msg.data) : new TextDecoder().decode(new Uint8Array(msg.data)); - + adapter._handleMessage(data, msg.binary); } } - async #handleWebSocketClose(close: tunnel.ToServerWebSocketClose) { - const adapter = this.#webSockets.get(close.webSocketId); + async #handleWebSocketClose( + requestId: ArrayBuffer, + close: tunnel.ToServerWebSocketClose, + ) { + const webSocketId = bufferToString(requestId); + const adapter = this.#actorWebSockets.get(webSocketId); if (adapter) { - adapter._handleClose(close.code || undefined, close.reason || undefined); - this.#webSockets.delete(close.webSocketId); + adapter._handleClose( + close.code || undefined, + close.reason || undefined, + ); + this.#actorWebSockets.delete(webSocketId); } } - #handleResponseStart(resp: tunnel.ToClientResponseStart) { - const pending = this.#pendingRequests.get(resp.requestId); + #handleResponseStart( + requestId: ArrayBuffer, + resp: tunnel.ToClientResponseStart, + ) { + const requestIdStr = bufferToString(requestId); + const pending = this.#actorPendingRequests.get(requestIdStr); if (!pending) { - console.warn(`Received response for unknown request ${resp.requestId}`); + console.warn( + `Received response for unknown request ${requestIdStr}`, + ); return; } @@ -585,62 +751,84 @@ export class Tunnel { }); pending.resolve(response); - this.#pendingRequests.delete(resp.requestId); + this.#actorPendingRequests.delete(requestIdStr); } } - #handleResponseChunk(chunk: tunnel.ToClientResponseChunk) { - const pending = this.#pendingRequests.get(chunk.requestId); + #handleResponseChunk( + requestId: ArrayBuffer, + chunk: tunnel.ToClientResponseChunk, + ) { + const requestIdStr = bufferToString(requestId); + const pending = this.#actorPendingRequests.get(requestIdStr); if (pending?.streamController) { pending.streamController.enqueue(new Uint8Array(chunk.body)); + if (chunk.finish) { + pending.streamController.close(); + this.#actorPendingRequests.delete(requestIdStr); + } } } - #handleResponseFinish(finish: tunnel.ToClientResponseFinish) { - const pending = this.#pendingRequests.get(finish.requestId); + #handleResponseAbort(requestId: ArrayBuffer) { + const requestIdStr = bufferToString(requestId); + const pending = this.#actorPendingRequests.get(requestIdStr); if (pending?.streamController) { - if (finish.reason === tunnel.StreamFinishReason.Complete) { - pending.streamController.close(); - } else { - pending.streamController.error(new Error("Response aborted")); - } + pending.streamController.error(new Error("Response aborted")); } - this.#pendingRequests.delete(finish.requestId); + this.#actorPendingRequests.delete(requestIdStr); } - #handleWebSocketOpenResponse(open: tunnel.ToClientWebSocketOpen) { - const adapter = this.#webSockets.get(open.webSocketId); + #handleWebSocketOpenResponse( + requestId: ArrayBuffer, + open: tunnel.ToClientWebSocketOpen, + ) { + const webSocketId = bufferToString(requestId); + const adapter = this.#actorWebSockets.get(webSocketId); if (adapter) { adapter._handleOpen(); } } - #handleWebSocketMessageResponse(msg: tunnel.ToClientWebSocketMessage) { - const adapter = this.#webSockets.get(msg.webSocketId); + #handleWebSocketMessageResponse( + requestId: ArrayBuffer, + msg: tunnel.ToClientWebSocketMessage, + ) { + const webSocketId = bufferToString(requestId); + const adapter = this.#actorWebSockets.get(webSocketId); if (adapter) { const data = msg.binary ? 
new Uint8Array(msg.data) : new TextDecoder().decode(new Uint8Array(msg.data)); - + adapter._handleMessage(data, msg.binary); } } - #handleWebSocketCloseResponse(close: tunnel.ToClientWebSocketClose) { - const adapter = this.#webSockets.get(close.webSocketId); + #handleWebSocketCloseResponse( + requestId: ArrayBuffer, + close: tunnel.ToClientWebSocketClose, + ) { + const webSocketId = bufferToString(requestId); + const adapter = this.#actorWebSockets.get(webSocketId); if (adapter) { - adapter._handleClose(close.code || undefined, close.reason || undefined); - this.#webSockets.delete(close.webSocketId); + adapter._handleClose( + close.code || undefined, + close.reason || undefined, + ); + this.#actorWebSockets.delete(webSocketId); } } +} - #send(message: tunnel.TunnelMessage) { - if (!this.#ws || this.#ws.readyState !== WebSocket.OPEN) { - console.warn("Cannot send tunnel message, WebSocket not connected"); - return; - } +/** Converts a buffer to a string. Used for storing strings in a lookup map. */ +function bufferToString(buffer: ArrayBuffer): string { + return Buffer.from(buffer).toString("base64"); +} - const encoded = tunnel.encodeTunnelMessage(message); - this.#ws.send(encoded); - } +/** Generates a UUID as bytes. */ +function generateUuidBuffer(): ArrayBuffer { + const buffer = new Uint8Array(16); + uuidv4(undefined, buffer); + return buffer.buffer; } diff --git a/sdks/typescript/runner/src/websocket-tunnel-adapter.ts b/sdks/typescript/runner/src/websocket-tunnel-adapter.ts index bf097f8a5c..b2430d3238 100644 --- a/sdks/typescript/runner/src/websocket-tunnel-adapter.ts +++ b/sdks/typescript/runner/src/websocket-tunnel-adapter.ts @@ -2,7 +2,7 @@ // Implements a subset of the WebSocket interface for compatibility with runner code export class WebSocketTunnelAdapter { - #webSocketId: bigint; + #webSocketId: string; #readyState: number = 0; // CONNECTING #eventListeners: Map void>> = new Map(); #onopen: ((this: any, ev: any) => any) | null = null; @@ -24,7 +24,7 @@ export class WebSocketTunnelAdapter { }> = []; constructor( - webSocketId: bigint, + webSocketId: string, sendCallback: (data: ArrayBuffer | string, isBinary: boolean) => void, closeCallback: (code?: number, reason?: string) => void ) { diff --git a/sdks/typescript/tunnel-protocol/src/index.ts b/sdks/typescript/tunnel-protocol/src/index.ts index 0ea9a5b8ce..9e4b6ac609 100644 --- a/sdks/typescript/tunnel-protocol/src/index.ts +++ b/sdks/typescript/tunnel-protocol/src/index.ts @@ -1,30 +1,38 @@ +import assert from "node:assert" import * as bare from "@bare-ts/lib" const DEFAULT_CONFIG = /* @__PURE__ */ bare.Config({}) export type u16 = number -export type u64 = bigint -export type RequestId = u64 +export type RequestId = ArrayBuffer export function readRequestId(bc: bare.ByteCursor): RequestId { - return bare.readU64(bc) + return bare.readFixedData(bc, 16) } export function writeRequestId(bc: bare.ByteCursor, x: RequestId): void { - bare.writeU64(bc, x) + assert(x.byteLength === 16) + bare.writeFixedData(bc, x) } -export type WebSocketId = u64 +/** + * UUIDv4 + */ +export type MessageId = ArrayBuffer -export function readWebSocketId(bc: bare.ByteCursor): WebSocketId { - return bare.readU64(bc) +export function readMessageId(bc: bare.ByteCursor): MessageId { + return bare.readFixedData(bc, 16) } -export function writeWebSocketId(bc: bare.ByteCursor, x: WebSocketId): void { - bare.writeU64(bc, x) +export function writeMessageId(bc: bare.ByteCursor, x: MessageId): void { + assert(x.byteLength === 16) + bare.writeFixedData(bc, x) } 
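+// RequestId and MessageId are raw 16-byte identifiers (UUIDs) rather than
+// strings, encoded as BARE fixed data. A minimal round-trip sketch,
+// illustrative only -- crypto.getRandomValues stands in for a real UUIDv4
+// generator here:
+//
+//   const id: MessageId = crypto.getRandomValues(new Uint8Array(16)).buffer
+//   const bc = new bare.ByteCursor(new Uint8Array(16), DEFAULT_CONFIG)
+//   writeMessageId(bc, id)
+//   const decoded = readMessageId(
+//       new bare.ByteCursor(
+//           new Uint8Array(bc.view.buffer, bc.view.byteOffset, bc.offset),
+//           DEFAULT_CONFIG,
+//       ),
+//   )
+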
+/** + * UUIDv4 + */ export type Id = string export function readId(bc: bare.ByteCursor): Id { @@ -35,38 +43,10 @@ export function writeId(bc: bare.ByteCursor, x: Id): void { bare.writeString(bc, x) } -export enum StreamFinishReason { - Complete = "Complete", - Abort = "Abort", -} - -export function readStreamFinishReason(bc: bare.ByteCursor): StreamFinishReason { - const offset = bc.offset - const tag = bare.readU8(bc) - switch (tag) { - case 0: - return StreamFinishReason.Complete - case 1: - return StreamFinishReason.Abort - default: { - bc.offset = offset - throw new bare.BareError(offset, "invalid tag") - } - } -} - -export function writeStreamFinishReason(bc: bare.ByteCursor, x: StreamFinishReason): void { - switch (x) { - case StreamFinishReason.Complete: { - bare.writeU8(bc, 0) - break - } - case StreamFinishReason.Abort: { - bare.writeU8(bc, 1) - break - } - } -} +/** + * MARK: Ack + */ +export type Ack = null function read0(bc: bare.ByteCursor): ReadonlyMap { const len = bare.readUintSafe(bc) @@ -103,10 +83,9 @@ function write1(bc: bare.ByteCursor, x: ArrayBuffer | null): void { } /** - * MARK: HTTP Request Forwarding + * MARK: HTTP */ export type ToServerRequestStart = { - readonly requestId: RequestId readonly actorId: Id readonly method: string readonly path: string @@ -117,7 +96,6 @@ export type ToServerRequestStart = { export function readToServerRequestStart(bc: bare.ByteCursor): ToServerRequestStart { return { - requestId: readRequestId(bc), actorId: readId(bc), method: bare.readString(bc), path: bare.readString(bc), @@ -128,7 +106,6 @@ export function readToServerRequestStart(bc: bare.ByteCursor): ToServerRequestSt } export function writeToServerRequestStart(bc: bare.ByteCursor, x: ToServerRequestStart): void { - writeRequestId(bc, x.requestId) writeId(bc, x.actorId) bare.writeString(bc, x.method) bare.writeString(bc, x.path) @@ -138,41 +115,25 @@ export function writeToServerRequestStart(bc: bare.ByteCursor, x: ToServerReques } export type ToServerRequestChunk = { - readonly requestId: RequestId readonly body: ArrayBuffer + readonly finish: boolean } export function readToServerRequestChunk(bc: bare.ByteCursor): ToServerRequestChunk { return { - requestId: readRequestId(bc), body: bare.readData(bc), + finish: bare.readBool(bc), } } export function writeToServerRequestChunk(bc: bare.ByteCursor, x: ToServerRequestChunk): void { - writeRequestId(bc, x.requestId) bare.writeData(bc, x.body) + bare.writeBool(bc, x.finish) } -export type ToServerRequestFinish = { - readonly requestId: RequestId - readonly reason: StreamFinishReason -} - -export function readToServerRequestFinish(bc: bare.ByteCursor): ToServerRequestFinish { - return { - requestId: readRequestId(bc), - reason: readStreamFinishReason(bc), - } -} - -export function writeToServerRequestFinish(bc: bare.ByteCursor, x: ToServerRequestFinish): void { - writeRequestId(bc, x.requestId) - writeStreamFinishReason(bc, x.reason) -} +export type ToServerRequestAbort = null export type ToClientResponseStart = { - readonly requestId: RequestId readonly status: u16 readonly headers: ReadonlyMap readonly body: ArrayBuffer | null @@ -181,7 +142,6 @@ export type ToClientResponseStart = { export function readToClientResponseStart(bc: bare.ByteCursor): ToClientResponseStart { return { - requestId: readRequestId(bc), status: bare.readU16(bc), headers: read0(bc), body: read1(bc), @@ -190,7 +150,6 @@ export function readToClientResponseStart(bc: bare.ByteCursor): ToClientResponse } export function writeToClientResponseStart(bc: 
bare.ByteCursor, x: ToClientResponseStart): void { - writeRequestId(bc, x.requestId) bare.writeU16(bc, x.status) write0(bc, x.headers) write1(bc, x.body) @@ -198,45 +157,29 @@ export function writeToClientResponseStart(bc: bare.ByteCursor, x: ToClientRespo } export type ToClientResponseChunk = { - readonly requestId: RequestId readonly body: ArrayBuffer + readonly finish: boolean } export function readToClientResponseChunk(bc: bare.ByteCursor): ToClientResponseChunk { return { - requestId: readRequestId(bc), body: bare.readData(bc), + finish: bare.readBool(bc), } } export function writeToClientResponseChunk(bc: bare.ByteCursor, x: ToClientResponseChunk): void { - writeRequestId(bc, x.requestId) bare.writeData(bc, x.body) + bare.writeBool(bc, x.finish) } -export type ToClientResponseFinish = { - readonly requestId: RequestId - readonly reason: StreamFinishReason -} - -export function readToClientResponseFinish(bc: bare.ByteCursor): ToClientResponseFinish { - return { - requestId: readRequestId(bc), - reason: readStreamFinishReason(bc), - } -} - -export function writeToClientResponseFinish(bc: bare.ByteCursor, x: ToClientResponseFinish): void { - writeRequestId(bc, x.requestId) - writeStreamFinishReason(bc, x.reason) -} +export type ToClientResponseAbort = null /** - * MARK: WebSocket Forwarding + * MARK: WebSocket */ export type ToServerWebSocketOpen = { readonly actorId: Id - readonly webSocketId: WebSocketId readonly path: string readonly headers: ReadonlyMap } @@ -244,7 +187,6 @@ export type ToServerWebSocketOpen = { export function readToServerWebSocketOpen(bc: bare.ByteCursor): ToServerWebSocketOpen { return { actorId: readId(bc), - webSocketId: readWebSocketId(bc), path: bare.readString(bc), headers: read0(bc), } @@ -252,27 +194,23 @@ export function readToServerWebSocketOpen(bc: bare.ByteCursor): ToServerWebSocke export function writeToServerWebSocketOpen(bc: bare.ByteCursor, x: ToServerWebSocketOpen): void { writeId(bc, x.actorId) - writeWebSocketId(bc, x.webSocketId) bare.writeString(bc, x.path) write0(bc, x.headers) } export type ToServerWebSocketMessage = { - readonly webSocketId: WebSocketId readonly data: ArrayBuffer readonly binary: boolean } export function readToServerWebSocketMessage(bc: bare.ByteCursor): ToServerWebSocketMessage { return { - webSocketId: readWebSocketId(bc), data: bare.readData(bc), binary: bare.readBool(bc), } } export function writeToServerWebSocketMessage(bc: bare.ByteCursor, x: ToServerWebSocketMessage): void { - writeWebSocketId(bc, x.webSocketId) bare.writeData(bc, x.data) bare.writeBool(bc, x.binary) } @@ -300,75 +238,54 @@ function write3(bc: bare.ByteCursor, x: string | null): void { } export type ToServerWebSocketClose = { - readonly webSocketId: WebSocketId readonly code: u16 | null readonly reason: string | null } export function readToServerWebSocketClose(bc: bare.ByteCursor): ToServerWebSocketClose { return { - webSocketId: readWebSocketId(bc), code: read2(bc), reason: read3(bc), } } export function writeToServerWebSocketClose(bc: bare.ByteCursor, x: ToServerWebSocketClose): void { - writeWebSocketId(bc, x.webSocketId) write2(bc, x.code) write3(bc, x.reason) } -export type ToClientWebSocketOpen = { - readonly webSocketId: WebSocketId -} - -export function readToClientWebSocketOpen(bc: bare.ByteCursor): ToClientWebSocketOpen { - return { - webSocketId: readWebSocketId(bc), - } -} - -export function writeToClientWebSocketOpen(bc: bare.ByteCursor, x: ToClientWebSocketOpen): void { - writeWebSocketId(bc, x.webSocketId) -} +export type 
ToClientWebSocketOpen = null export type ToClientWebSocketMessage = { - readonly webSocketId: WebSocketId readonly data: ArrayBuffer readonly binary: boolean } export function readToClientWebSocketMessage(bc: bare.ByteCursor): ToClientWebSocketMessage { return { - webSocketId: readWebSocketId(bc), data: bare.readData(bc), binary: bare.readBool(bc), } } export function writeToClientWebSocketMessage(bc: bare.ByteCursor, x: ToClientWebSocketMessage): void { - writeWebSocketId(bc, x.webSocketId) bare.writeData(bc, x.data) bare.writeBool(bc, x.binary) } export type ToClientWebSocketClose = { - readonly webSocketId: WebSocketId readonly code: u16 | null readonly reason: string | null } export function readToClientWebSocketClose(bc: bare.ByteCursor): ToClientWebSocketClose { return { - webSocketId: readWebSocketId(bc), code: read2(bc), reason: read3(bc), } } export function writeToClientWebSocketClose(bc: bare.ByteCursor, x: ToClientWebSocketClose): void { - writeWebSocketId(bc, x.webSocketId) write2(bc, x.code) write3(bc, x.reason) } @@ -376,16 +293,17 @@ export function writeToClientWebSocketClose(bc: bare.ByteCursor, x: ToClientWebS /** * MARK: Message */ -export type MessageBody = +export type MessageKind = + | { readonly tag: "Ack"; readonly val: Ack } /** * HTTP */ | { readonly tag: "ToServerRequestStart"; readonly val: ToServerRequestStart } | { readonly tag: "ToServerRequestChunk"; readonly val: ToServerRequestChunk } - | { readonly tag: "ToServerRequestFinish"; readonly val: ToServerRequestFinish } + | { readonly tag: "ToServerRequestAbort"; readonly val: ToServerRequestAbort } | { readonly tag: "ToClientResponseStart"; readonly val: ToClientResponseStart } | { readonly tag: "ToClientResponseChunk"; readonly val: ToClientResponseChunk } - | { readonly tag: "ToClientResponseFinish"; readonly val: ToClientResponseFinish } + | { readonly tag: "ToClientResponseAbort"; readonly val: ToClientResponseAbort } /** * WebSocket */ @@ -396,33 +314,35 @@ export type MessageBody = | { readonly tag: "ToClientWebSocketMessage"; readonly val: ToClientWebSocketMessage } | { readonly tag: "ToClientWebSocketClose"; readonly val: ToClientWebSocketClose } -export function readMessageBody(bc: bare.ByteCursor): MessageBody { +export function readMessageKind(bc: bare.ByteCursor): MessageKind { const offset = bc.offset const tag = bare.readU8(bc) switch (tag) { case 0: - return { tag: "ToServerRequestStart", val: readToServerRequestStart(bc) } + return { tag: "Ack", val: null } case 1: - return { tag: "ToServerRequestChunk", val: readToServerRequestChunk(bc) } + return { tag: "ToServerRequestStart", val: readToServerRequestStart(bc) } case 2: - return { tag: "ToServerRequestFinish", val: readToServerRequestFinish(bc) } + return { tag: "ToServerRequestChunk", val: readToServerRequestChunk(bc) } case 3: - return { tag: "ToClientResponseStart", val: readToClientResponseStart(bc) } + return { tag: "ToServerRequestAbort", val: null } case 4: - return { tag: "ToClientResponseChunk", val: readToClientResponseChunk(bc) } + return { tag: "ToClientResponseStart", val: readToClientResponseStart(bc) } case 5: - return { tag: "ToClientResponseFinish", val: readToClientResponseFinish(bc) } + return { tag: "ToClientResponseChunk", val: readToClientResponseChunk(bc) } case 6: - return { tag: "ToServerWebSocketOpen", val: readToServerWebSocketOpen(bc) } + return { tag: "ToClientResponseAbort", val: null } case 7: - return { tag: "ToServerWebSocketMessage", val: readToServerWebSocketMessage(bc) } + return { tag: 
"ToServerWebSocketOpen", val: readToServerWebSocketOpen(bc) } case 8: - return { tag: "ToServerWebSocketClose", val: readToServerWebSocketClose(bc) } + return { tag: "ToServerWebSocketMessage", val: readToServerWebSocketMessage(bc) } case 9: - return { tag: "ToClientWebSocketOpen", val: readToClientWebSocketOpen(bc) } + return { tag: "ToServerWebSocketClose", val: readToServerWebSocketClose(bc) } case 10: - return { tag: "ToClientWebSocketMessage", val: readToClientWebSocketMessage(bc) } + return { tag: "ToClientWebSocketOpen", val: null } case 11: + return { tag: "ToClientWebSocketMessage", val: readToClientWebSocketMessage(bc) } + case 12: return { tag: "ToClientWebSocketClose", val: readToClientWebSocketClose(bc) } default: { bc.offset = offset @@ -431,65 +351,66 @@ export function readMessageBody(bc: bare.ByteCursor): MessageBody { } } -export function writeMessageBody(bc: bare.ByteCursor, x: MessageBody): void { +export function writeMessageKind(bc: bare.ByteCursor, x: MessageKind): void { switch (x.tag) { - case "ToServerRequestStart": { + case "Ack": { bare.writeU8(bc, 0) + break + } + case "ToServerRequestStart": { + bare.writeU8(bc, 1) writeToServerRequestStart(bc, x.val) break } case "ToServerRequestChunk": { - bare.writeU8(bc, 1) + bare.writeU8(bc, 2) writeToServerRequestChunk(bc, x.val) break } - case "ToServerRequestFinish": { - bare.writeU8(bc, 2) - writeToServerRequestFinish(bc, x.val) + case "ToServerRequestAbort": { + bare.writeU8(bc, 3) break } case "ToClientResponseStart": { - bare.writeU8(bc, 3) + bare.writeU8(bc, 4) writeToClientResponseStart(bc, x.val) break } case "ToClientResponseChunk": { - bare.writeU8(bc, 4) + bare.writeU8(bc, 5) writeToClientResponseChunk(bc, x.val) break } - case "ToClientResponseFinish": { - bare.writeU8(bc, 5) - writeToClientResponseFinish(bc, x.val) + case "ToClientResponseAbort": { + bare.writeU8(bc, 6) break } case "ToServerWebSocketOpen": { - bare.writeU8(bc, 6) + bare.writeU8(bc, 7) writeToServerWebSocketOpen(bc, x.val) break } case "ToServerWebSocketMessage": { - bare.writeU8(bc, 7) + bare.writeU8(bc, 8) writeToServerWebSocketMessage(bc, x.val) break } case "ToServerWebSocketClose": { - bare.writeU8(bc, 8) + bare.writeU8(bc, 9) writeToServerWebSocketClose(bc, x.val) break } case "ToClientWebSocketOpen": { - bare.writeU8(bc, 9) - writeToClientWebSocketOpen(bc, x.val) + bare.writeU8(bc, 10) break } case "ToClientWebSocketMessage": { - bare.writeU8(bc, 10) + bare.writeU8(bc, 11) writeToClientWebSocketMessage(bc, x.val) break } case "ToClientWebSocketClose": { - bare.writeU8(bc, 11) + bare.writeU8(bc, 12) writeToClientWebSocketClose(bc, x.val) break } @@ -497,35 +418,89 @@ export function writeMessageBody(bc: bare.ByteCursor, x: MessageBody): void { } /** - * Main tunnel message + * MARK: Message sent over tunnel WebSocket */ -export type TunnelMessage = { - readonly body: MessageBody +export type RunnerMessage = { + readonly requestId: RequestId + readonly messageId: MessageId + readonly messageKind: MessageKind } -export function readTunnelMessage(bc: bare.ByteCursor): TunnelMessage { +export function readRunnerMessage(bc: bare.ByteCursor): RunnerMessage { return { - body: readMessageBody(bc), + requestId: readRequestId(bc), + messageId: readMessageId(bc), + messageKind: readMessageKind(bc), } } -export function writeTunnelMessage(bc: bare.ByteCursor, x: TunnelMessage): void { - writeMessageBody(bc, x.body) +export function writeRunnerMessage(bc: bare.ByteCursor, x: RunnerMessage): void { + writeRequestId(bc, x.requestId) + 
writeMessageId(bc, x.messageId)
+    writeMessageKind(bc, x.messageKind)
+}
+
+export function encodeRunnerMessage(x: RunnerMessage, config?: Partial<bare.Config>): Uint8Array {
+    const fullConfig = config != null ? bare.Config(config) : DEFAULT_CONFIG
+    const bc = new bare.ByteCursor(
+        new Uint8Array(fullConfig.initialBufferLength),
+        fullConfig,
+    )
+    writeRunnerMessage(bc, x)
+    return new Uint8Array(bc.view.buffer, bc.view.byteOffset, bc.offset)
+}
+
+export function decodeRunnerMessage(bytes: Uint8Array): RunnerMessage {
+    const bc = new bare.ByteCursor(bytes, DEFAULT_CONFIG)
+    const result = readRunnerMessage(bc)
+    if (bc.offset < bc.view.byteLength) {
+        throw new bare.BareError(bc.offset, "remaining bytes")
+    }
+    return result
+}
+
+/**
+ * MARK: Message sent over UPS
+ */
+export type PubSubMessage = {
+    readonly requestId: RequestId
+    readonly messageId: MessageId
+    /**
+     * Subject to send replies to. Only sent when opening a new request from gateway -> runner.
+     */
+    readonly replyTo: string | null
+    readonly messageKind: MessageKind
+}
+
+export function readPubSubMessage(bc: bare.ByteCursor): PubSubMessage {
+    return {
+        requestId: readRequestId(bc),
+        messageId: readMessageId(bc),
+        replyTo: read3(bc),
+        messageKind: readMessageKind(bc),
+    }
+}
+
+export function writePubSubMessage(bc: bare.ByteCursor, x: PubSubMessage): void {
+    writeRequestId(bc, x.requestId)
+    writeMessageId(bc, x.messageId)
+    write3(bc, x.replyTo)
+    writeMessageKind(bc, x.messageKind)
 }
 
-export function encodeTunnelMessage(x: TunnelMessage, config?: Partial<bare.Config>): Uint8Array {
+export function encodePubSubMessage(x: PubSubMessage, config?: Partial<bare.Config>): Uint8Array {
     const fullConfig = config != null ? bare.Config(config) : DEFAULT_CONFIG
     const bc = new bare.ByteCursor(
         new Uint8Array(fullConfig.initialBufferLength),
         fullConfig,
     )
-    writeTunnelMessage(bc, x)
+    writePubSubMessage(bc, x)
     return new Uint8Array(bc.view.buffer, bc.view.byteOffset, bc.offset)
 }
 
-export function decodeTunnelMessage(bytes: Uint8Array): TunnelMessage {
+export function decodePubSubMessage(bytes: Uint8Array): PubSubMessage {
     const bc = new bare.ByteCursor(bytes, DEFAULT_CONFIG)
-    const result = readTunnelMessage(bc)
+    const result = readPubSubMessage(bc)
     if (bc.offset < bc.view.byteLength) {
         throw new bare.BareError(bc.offset, "remaining bytes")
     }

From 74e5db56c2c121c60091e27807b8a85f1f56d07c Mon Sep 17 00:00:00 2001
From: Nathan Flurry 
Date: Mon, 8 Sep 2025 17:58:08 -0700
Subject: [PATCH 2/8] chore(ups): add postgres auto-reconnect

---
 Cargo.lock                                    |   1 +
 packages/common/test-deps-docker/src/lib.rs   |  36 +++
 packages/common/universalpubsub/Cargo.toml    |   1 +
 .../src/driver/postgres/mod.rs                | 222 ++++++++++++++----
 .../common/universalpubsub/tests/reconnect.rs | 217 +++++++++++++++++
 5 files changed, 430 insertions(+), 47 deletions(-)
 create mode 100644 packages/common/universalpubsub/tests/reconnect.rs

diff --git a/Cargo.lock b/Cargo.lock
index 800d4ca464..44bd7cd6e0 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -6338,6 +6338,7 @@ dependencies = [
  "rivet-error",
  "rivet-test-deps-docker",
  "rivet-ups-protocol",
+ "rivet-util",
  "serde",
  "serde_json",
  "sha2",
diff --git a/packages/common/test-deps-docker/src/lib.rs b/packages/common/test-deps-docker/src/lib.rs
index efa46a3152..2ed73f8600 100644
--- a/packages/common/test-deps-docker/src/lib.rs
+++ b/packages/common/test-deps-docker/src/lib.rs
@@ -73,6 +73,42 @@ impl DockerRunConfig {
 		Ok(true)
 	}
 
+	pub async fn restart(&self) -> Result<()> {
+		let container_id = self
+			.container_id
+			.as_ref()
+			.ok_or_else(|| 
anyhow!("No container ID found, container not started"))?;
+
+		tracing::debug!(
+			container_name = %self.container_name,
+			container_id = %container_id,
+			"restarting docker container"
+		);
+
+		let output = Command::new("docker")
+			.arg("restart")
+			.arg(container_id)
+			.output()
+			.await?;
+
+		if !output.status.success() {
+			let stderr = String::from_utf8_lossy(&output.stderr);
+			anyhow::bail!(
+				"Failed to restart container {}: {}",
+				self.container_name,
+				stderr
+			);
+		}
+
+		tracing::debug!(
+			container_name = %self.container_name,
+			container_id = %container_id,
+			"container restarted successfully"
+		);
+
+		Ok(())
+	}
+
 	pub fn container_id(&self) -> Option<&str> {
 		self.container_id.as_deref()
 	}
diff --git a/packages/common/universalpubsub/Cargo.toml b/packages/common/universalpubsub/Cargo.toml
index 3cdb8509b6..4ffddeb77c 100644
--- a/packages/common/universalpubsub/Cargo.toml
+++ b/packages/common/universalpubsub/Cargo.toml
@@ -14,6 +14,7 @@ deadpool-postgres.workspace = true
 futures-util.workspace = true
 rivet-error.workspace = true
 rivet-ups-protocol.workspace = true
+rivet-util.workspace = true
 serde_json.workspace = true
 versioned-data-util.workspace = true
 serde.workspace = true
diff --git a/packages/common/universalpubsub/src/driver/postgres/mod.rs b/packages/common/universalpubsub/src/driver/postgres/mod.rs
index c2df9a36b4..21cf96f5f5 100644
--- a/packages/common/universalpubsub/src/driver/postgres/mod.rs
+++ b/packages/common/universalpubsub/src/driver/postgres/mod.rs
@@ -1,6 +1,6 @@
 use std::collections::HashMap;
 use std::hash::{DefaultHasher, Hash, Hasher};
-use std::sync::{Arc, Mutex};
+use std::sync::Arc;
 
 use anyhow::*;
 use async_trait::async_trait;
@@ -8,6 +8,8 @@ use base64::Engine;
 use base64::engine::general_purpose::STANDARD_NO_PAD as BASE64;
 use deadpool_postgres::{Config, ManagerConfig, Pool, PoolConfig, RecyclingMethod, Runtime};
 use futures_util::future::poll_fn;
+use rivet_util::backoff::Backoff;
+use tokio::sync::{Mutex, broadcast};
 use tokio_postgres::{AsyncMessage, NoTls};
 use tracing::Instrument;
 
@@ -17,13 +19,13 @@ use crate::pubsub::DriverOutput;
 #[derive(Clone)]
 struct Subscription {
 	// Channel to send messages to this subscription
-	tx: tokio::sync::broadcast::Sender<Vec<u8>>,
+	tx: broadcast::Sender<Vec<u8>>,
 	// Cancellation token shared by all subscribers of this subject
 	token: tokio_util::sync::CancellationToken,
 }
 
 impl Subscription {
-	fn new(tx: tokio::sync::broadcast::Sender<Vec<u8>>) -> Self {
+	fn new(tx: broadcast::Sender<Vec<u8>>) -> Self {
 		let token = tokio_util::sync::CancellationToken::new();
 		Self { tx, token }
 	}
@@ -48,8 +50,9 @@ pub const POSTGRES_MAX_MESSAGE_SIZE: usize =
 #[derive(Clone)]
 pub struct PostgresDriver {
 	pool: Arc<Pool>,
-	client: Arc<tokio_postgres::Client>,
+	client: Arc<Mutex<Option<Arc<tokio_postgres::Client>>>>,
 	subscriptions: Arc<Mutex<HashMap<String, Subscription>>>,
+	client_ready: tokio::sync::watch::Receiver<bool>,
 }
 
 impl PostgresDriver {
@@ -76,48 +79,168 @@ impl PostgresDriver {
 
 		let subscriptions: Arc<Mutex<HashMap<String, Subscription>>> =
 			Arc::new(Mutex::new(HashMap::new()));
-		let subscriptions2 = subscriptions.clone();
+		let client: Arc<Mutex<Option<Arc<tokio_postgres::Client>>>> = Arc::new(Mutex::new(None));
 
-		let (client, mut conn) = tokio_postgres::connect(&conn_str, tokio_postgres::NoTls).await?;
-		tokio::spawn(async move {
-			// NOTE: This loop will stop automatically when client is dropped
-			loop {
-				match poll_fn(|cx| conn.poll_message(cx)).await {
-					Some(std::result::Result::Ok(AsyncMessage::Notification(note))) => {
-						if let Some(sub) =
-							subscriptions2.lock().unwrap().get(note.channel()).cloned()
-						{
-							let bytes = match BASE64.decode(note.payload()) {
-								std::result::Result::Ok(b) => b,
-								std::result::Result::Err(err) 
=> { - tracing::error!(?err, "failed decoding base64"); - break; - } - }; - let _ = sub.tx.send(bytes); - } - } - Some(std::result::Result::Ok(_)) => { - // Ignore other async messages + // Create channel for client ready notifications + let (ready_tx, client_ready) = tokio::sync::watch::channel(false); + + // Spawn connection lifecycle task + tokio::spawn(Self::spawn_connection_lifecycle( + conn_str.clone(), + subscriptions.clone(), + client.clone(), + ready_tx, + )); + + let driver = Self { + pool: Arc::new(pool), + client, + subscriptions, + client_ready, + }; + + // Wait for initial connection to be established + driver.wait_for_client().await?; + + Ok(driver) + } + + /// Manages the connection lifecycle with automatic reconnection + async fn spawn_connection_lifecycle( + conn_str: String, + subscriptions: Arc>>, + client: Arc>>>, + ready_tx: tokio::sync::watch::Sender, + ) { + let mut backoff = Backoff::new(8, None, 1_000, 1_000); + + loop { + match tokio_postgres::connect(&conn_str, tokio_postgres::NoTls).await { + Result::Ok((new_client, conn)) => { + tracing::info!("postgres listen connection established"); + // Reset backoff on successful connection + backoff = Backoff::new(8, None, 1_000, 1_000); + + let new_client = Arc::new(new_client); + + // Update the client reference immediately + *client.lock().await = Some(new_client.clone()); + // Notify that client is ready + let _ = ready_tx.send(true); + + // Get channels to re-subscribe to + let channels: Vec = + subscriptions.lock().await.keys().cloned().collect(); + let needs_resubscribe = !channels.is_empty(); + + if needs_resubscribe { + tracing::debug!( + ?channels, + "will re-subscribe to channels after connection starts" + ); } - Some(std::result::Result::Err(err)) => { - tracing::error!(?err, "async postgres error"); - break; + + // Spawn a task to re-subscribe after a short delay + if needs_resubscribe { + let client_for_resub = new_client.clone(); + let channels_clone = channels.clone(); + tokio::spawn(async move { + tracing::debug!( + ?channels_clone, + "re-subscribing to channels after reconnection" + ); + for channel in &channels_clone { + if let Result::Err(e) = client_for_resub + .execute(&format!("LISTEN \"{}\"", channel), &[]) + .await + { + tracing::error!(?e, %channel, "failed to re-subscribe to channel"); + } else { + tracing::debug!(%channel, "successfully re-subscribed to channel"); + } + } + }); } - None => { - tracing::debug!("async postgres connection closed"); - break; + + // Poll the connection until it closes + Self::poll_connection(conn, subscriptions.clone()).await; + + // Clear the client reference on disconnect + *client.lock().await = None; + // Notify that client is disconnected + let _ = ready_tx.send(false); + } + Result::Err(e) => { + tracing::error!(?e, "failed to connect to postgres, retrying"); + backoff.tick().await; + } + } + } + } + + /// Polls the connection for notifications until it closes or errors + async fn poll_connection( + mut conn: tokio_postgres::Connection< + tokio_postgres::Socket, + tokio_postgres::tls::NoTlsStream, + >, + subscriptions: Arc>>, + ) { + loop { + match poll_fn(|cx| conn.poll_message(cx)).await { + Some(std::result::Result::Ok(AsyncMessage::Notification(note))) => { + tracing::trace!(channel = %note.channel(), "received notification"); + if let Some(sub) = subscriptions.lock().await.get(note.channel()).cloned() { + let bytes = match BASE64.decode(note.payload()) { + std::result::Result::Ok(b) => b, + std::result::Result::Err(err) => { + tracing::error!(?err, 
"failed decoding base64"); + continue; + } + }; + tracing::trace!(channel = %note.channel(), bytes_len = bytes.len(), "sending to broadcast channel"); + let _ = sub.tx.send(bytes); + } else { + tracing::warn!(channel = %note.channel(), "received notification for unknown channel"); } } + Some(std::result::Result::Ok(_)) => { + // Ignore other async messages + } + Some(std::result::Result::Err(err)) => { + tracing::error!(?err, "postgres connection error, reconnecting"); + break; // Exit loop to reconnect + } + None => { + tracing::warn!("postgres connection closed, reconnecting"); + break; // Exit loop to reconnect + } } - tracing::debug!("listen connection closed"); - }); + } + } - Ok(Self { - pool: Arc::new(pool), - client: Arc::new(client), - subscriptions, + /// Wait for the client to be connected + async fn wait_for_client(&self) -> Result> { + let mut ready_rx = self.client_ready.clone(); + tokio::time::timeout(tokio::time::Duration::from_secs(5), async { + loop { + // Subscribe to changed before attempting to access client + let changed_fut = ready_rx.changed(); + + // Check if client is already available + if let Some(client) = self.client.lock().await.clone() { + return Ok(client); + } + + // Wait for change, will return client if exists on next iteration + changed_fut + .await + .map_err(|_| anyhow!("connection lifecycle task ended"))?; + tracing::debug!("client does not exist immediately after receive ready"); + } }) + .await + .map_err(|_| anyhow!("timeout waiting for postgres client connection"))? } fn hash_subject(&self, subject: &str) -> String { @@ -147,7 +270,7 @@ impl PubSubDriver for PostgresDriver { // Check if we already have a subscription for this channel let (rx, drop_guard) = - if let Some(existing_sub) = self.subscriptions.lock().unwrap().get(&hashed).cloned() { + if let Some(existing_sub) = self.subscriptions.lock().await.get(&hashed).cloned() { // Reuse the existing broadcast channel let rx = existing_sub.tx.subscribe(); let drop_guard = existing_sub.token.clone().drop_guard(); @@ -160,13 +283,15 @@ impl PubSubDriver for PostgresDriver { // Register subscription self.subscriptions .lock() - .unwrap() + .await .insert(hashed.clone(), subscription.clone()); // Execute LISTEN command on the async client (for receiving notifications) // This only needs to be done once per channel + // Wait for client to be connected with retry logic + let client = self.wait_for_client().await?; let span = tracing::trace_span!("pg_listen"); - self.client + client .execute(&format!("LISTEN \"{hashed}\""), &[]) .instrument(span) .await?; @@ -179,13 +304,16 @@ impl PubSubDriver for PostgresDriver { tokio::spawn(async move { token_clone.cancelled().await; if tx_clone.receiver_count() == 0 { - let sql = format!("UNLISTEN \"{}\"", hashed_clone); - if let Err(err) = driver.client.execute(sql.as_str(), &[]).await { - tracing::warn!(?err, %hashed_clone, "failed to UNLISTEN channel"); - } else { - tracing::trace!(%hashed_clone, "unlistened channel"); + let client = driver.client.lock().await.clone(); + if let Some(client) = client { + let sql = format!("UNLISTEN \"{}\"", hashed_clone); + if let Err(err) = client.execute(sql.as_str(), &[]).await { + tracing::warn!(?err, %hashed_clone, "failed to UNLISTEN channel"); + } else { + tracing::trace!(%hashed_clone, "unlistened channel"); + } } - driver.subscriptions.lock().unwrap().remove(&hashed_clone); + driver.subscriptions.lock().await.remove(&hashed_clone); } }); diff --git a/packages/common/universalpubsub/tests/reconnect.rs 
b/packages/common/universalpubsub/tests/reconnect.rs
new file mode 100644
index 0000000000..fa98767e95
--- /dev/null
+++ b/packages/common/universalpubsub/tests/reconnect.rs
@@ -0,0 +1,217 @@
+use anyhow::*;
+use rivet_test_deps_docker::{TestDatabase, TestPubSub};
+use std::{sync::Arc, time::Duration};
+use universalpubsub::{NextOutput, PubSub, PublishOpts};
+use uuid::Uuid;
+
+fn setup_logging() {
+	let _ = tracing_subscriber::fmt()
+		.with_env_filter("debug")
+		.with_ansi(false)
+		.with_test_writer()
+		.try_init();
+}
+
+#[tokio::test]
+async fn test_nats_driver_with_memory_reconnect() {
+	setup_logging();
+
+	let test_id = Uuid::new_v4();
+	let (pubsub_config, docker_config) = TestPubSub::Nats.config(test_id, 1).await.unwrap();
+	let mut docker = docker_config.unwrap();
+	docker.start().await.unwrap();
+	tokio::time::sleep(Duration::from_secs(1)).await;
+
+	let rivet_config::config::PubSub::Nats(nats) = pubsub_config else {
+		unreachable!();
+	};
+
+	use std::str::FromStr;
+	let server_addrs = nats
+		.addresses
+		.iter()
+		.map(|addr| format!("nats://{addr}"))
+		.map(|url| async_nats::ServerAddr::from_str(url.as_ref()))
+		.collect::<Result<Vec<_>, _>>()
+		.unwrap();
+
+	let driver = universalpubsub::driver::nats::NatsDriver::connect(
+		async_nats::ConnectOptions::new(),
+		&server_addrs[..],
+	)
+	.await
+	.unwrap();
+	let pubsub = PubSub::new_with_memory_optimization(Arc::new(driver), true);
+
+	test_reconnect_inner(&pubsub, &docker).await;
+}
+
+#[tokio::test]
+async fn test_nats_driver_without_memory_reconnect() {
+	setup_logging();
+
+	let test_id = Uuid::new_v4();
+	let (pubsub_config, docker_config) = TestPubSub::Nats.config(test_id, 1).await.unwrap();
+	let mut docker = docker_config.unwrap();
+	docker.start().await.unwrap();
+	tokio::time::sleep(Duration::from_secs(1)).await;
+
+	let rivet_config::config::PubSub::Nats(nats) = pubsub_config else {
+		unreachable!();
+	};
+
+	use std::str::FromStr;
+	let server_addrs = nats
+		.addresses
+		.iter()
+		.map(|addr| format!("nats://{addr}"))
+		.map(|url| async_nats::ServerAddr::from_str(url.as_ref()))
+		.collect::<Result<Vec<_>, _>>()
+		.unwrap();
+
+	let driver = universalpubsub::driver::nats::NatsDriver::connect(
+		async_nats::ConnectOptions::new(),
+		&server_addrs[..],
+	)
+	.await
+	.unwrap();
+	let pubsub = PubSub::new_with_memory_optimization(Arc::new(driver), false);
+
+	test_reconnect_inner(&pubsub, &docker).await;
+}
+
+#[tokio::test]
+async fn test_postgres_driver_with_memory_reconnect() {
+	setup_logging();
+
+	let test_id = Uuid::new_v4();
+	let (db_config, docker_config) = TestDatabase::Postgres.config(test_id, 1).await.unwrap();
+	let mut docker = docker_config.unwrap();
+	docker.start().await.unwrap();
+	tokio::time::sleep(Duration::from_secs(5)).await;
+
+	let rivet_config::config::Database::Postgres(pg) = db_config else {
+		unreachable!();
+	};
+	let url = pg.url.read().clone();
+
+	let driver = universalpubsub::driver::postgres::PostgresDriver::connect(url, true)
+		.await
+		.unwrap();
+	let pubsub = PubSub::new_with_memory_optimization(Arc::new(driver), true);
+
+	test_reconnect_inner(&pubsub, &docker).await;
+}
+
+#[tokio::test]
+async fn test_postgres_driver_without_memory_reconnect() {
+	setup_logging();
+
+	let test_id = Uuid::new_v4();
+	let (db_config, docker_config) = TestDatabase::Postgres.config(test_id, 1).await.unwrap();
+	let mut docker = docker_config.unwrap();
+	docker.start().await.unwrap();
+	tokio::time::sleep(Duration::from_secs(5)).await;
+
+	let rivet_config::config::Database::Postgres(pg) = db_config else {
+		unreachable!();
+	
}; + let url = pg.url.read().clone(); + + let driver = universalpubsub::driver::postgres::PostgresDriver::connect(url, false) + .await + .unwrap(); + let pubsub = PubSub::new_with_memory_optimization(Arc::new(driver), false); + + test_reconnect_inner(&pubsub, &docker).await; +} + +async fn test_reconnect_inner(pubsub: &PubSub, docker: &rivet_test_deps_docker::DockerRunConfig) { + tracing::info!("testing reconnect functionality"); + + // Open subscription + let mut subscriber = pubsub.subscribe("test.reconnect").await.unwrap(); + tracing::info!("opened initial subscription"); + + // Test publish/receive message before restart + let message_before = b"message before restart"; + pubsub + .publish("test.reconnect", message_before, PublishOpts::broadcast()) + .await + .unwrap(); + pubsub.flush().await.unwrap(); + + match subscriber.next().await.unwrap() { + NextOutput::Message(msg) => { + assert_eq!( + msg.payload, message_before, + "message before restart should match" + ); + tracing::info!("received message before restart"); + } + NextOutput::Unsubscribed => { + panic!("unexpected unsubscribe before restart"); + } + } + + // Restart container + tracing::info!("restarting docker container"); + docker.restart().await.unwrap(); + + // Give the service time to come back up + tokio::time::sleep(Duration::from_secs(3)).await; + tracing::info!("docker container restarted"); + + // Test publish/receive message after restart + let message_after = b"message after restart"; + + // Retry logic for publish after restart since connection might need to reconnect + let mut retries = 0; + const MAX_RETRIES: u32 = 10; + loop { + match pubsub + .publish("test.reconnect", message_after, PublishOpts::broadcast()) + .await + { + Result::Ok(_) => { + tracing::info!("published message after restart"); + break; + } + Result::Err(e) if retries < MAX_RETRIES => { + retries += 1; + tracing::debug!(?e, retries, "failed to publish after restart, retrying"); + tokio::time::sleep(Duration::from_millis(500)).await; + } + Result::Err(e) => { + panic!("failed to publish after {} retries: {}", MAX_RETRIES, e); + } + } + } + + pubsub.flush().await.unwrap(); + + // Try to receive with timeout to handle reconnection delays + let receive_timeout = Duration::from_secs(10); + let receive_result = tokio::time::timeout(receive_timeout, subscriber.next()).await; + + match receive_result { + Result::Ok(Result::Ok(NextOutput::Message(msg))) => { + assert_eq!( + msg.payload, message_after, + "message after restart should match" + ); + tracing::info!("received message after restart - reconnection successful"); + } + Result::Ok(Result::Ok(NextOutput::Unsubscribed)) => { + panic!("unexpected unsubscribe after restart"); + } + Result::Ok(Result::Err(e)) => { + panic!("error receiving message after restart: {}", e); + } + Result::Err(_) => { + panic!("timeout receiving message after restart"); + } + } + + tracing::info!("reconnect test completed successfully"); +} From e29523a2c49904f573638386a1038549be9e368d Mon Sep 17 00:00:00 2001 From: Nathan Flurry Date: Mon, 8 Sep 2025 20:18:37 -0700 Subject: [PATCH 3/8] chore(ups): handle edge cases with postgres listen/unlisten/notify when disconnected/reconnecting --- out/errors/ups.publish_failed.json | 5 + packages/common/test-deps-docker/src/lib.rs | 72 +++++++ .../src/driver/postgres/mod.rs | 166 +++++++++----- packages/common/universalpubsub/src/errors.rs | 2 + packages/common/universalpubsub/src/pubsub.rs | 31 ++- .../common/universalpubsub/tests/reconnect.rs | 202 +++++++++++++++--- 6 files 
changed, 395 insertions(+), 83 deletions(-) create mode 100644 out/errors/ups.publish_failed.json diff --git a/out/errors/ups.publish_failed.json b/out/errors/ups.publish_failed.json new file mode 100644 index 0000000000..b4c1ac0b42 --- /dev/null +++ b/out/errors/ups.publish_failed.json @@ -0,0 +1,5 @@ +{ + "code": "publish_failed", + "group": "ups", + "message": "Failed to publish message after retries" +} \ No newline at end of file diff --git a/packages/common/test-deps-docker/src/lib.rs b/packages/common/test-deps-docker/src/lib.rs index 2ed73f8600..644e57ad53 100644 --- a/packages/common/test-deps-docker/src/lib.rs +++ b/packages/common/test-deps-docker/src/lib.rs @@ -109,6 +109,78 @@ impl DockerRunConfig { Ok(()) } + pub async fn stop_container(&self) -> Result<()> { + let container_id = self + .container_id + .as_ref() + .ok_or_else(|| anyhow!("No container ID found, container not started"))?; + + tracing::debug!( + container_name = %self.container_name, + container_id = %container_id, + "stopping docker container" + ); + + let output = Command::new("docker") + .arg("stop") + .arg(container_id) + .output() + .await?; + + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + anyhow::bail!( + "Failed to stop container {}: {}", + self.container_name, + stderr + ); + } + + tracing::debug!( + container_name = %self.container_name, + container_id = %container_id, + "container stopped successfully" + ); + + Ok(()) + } + + pub async fn start_container(&self) -> Result<()> { + let container_id = self + .container_id + .as_ref() + .ok_or_else(|| anyhow!("No container ID found, container not started"))?; + + tracing::debug!( + container_name = %self.container_name, + container_id = %container_id, + "starting docker container" + ); + + let output = Command::new("docker") + .arg("start") + .arg(container_id) + .output() + .await?; + + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + anyhow::bail!( + "Failed to start container {}: {}", + self.container_name, + stderr + ); + } + + tracing::debug!( + container_name = %self.container_name, + container_id = %container_id, + "container started successfully" + ); + + Ok(()) + } + pub fn container_id(&self) -> Option<&str> { self.container_id.as_deref() } diff --git a/packages/common/universalpubsub/src/driver/postgres/mod.rs b/packages/common/universalpubsub/src/driver/postgres/mod.rs index 21cf96f5f5..3115a7ac72 100644 --- a/packages/common/universalpubsub/src/driver/postgres/mod.rs +++ b/packages/common/universalpubsub/src/driver/postgres/mod.rs @@ -112,21 +112,23 @@ impl PostgresDriver { client: Arc>>>, ready_tx: tokio::sync::watch::Sender, ) { - let mut backoff = Backoff::new(8, None, 1_000, 1_000); + let mut backoff = Backoff::default(); loop { match tokio_postgres::connect(&conn_str, tokio_postgres::NoTls).await { Result::Ok((new_client, conn)) => { tracing::info!("postgres listen connection established"); // Reset backoff on successful connection - backoff = Backoff::new(8, None, 1_000, 1_000); + backoff = Backoff::default(); let new_client = Arc::new(new_client); - // Update the client reference immediately - *client.lock().await = Some(new_client.clone()); - // Notify that client is ready - let _ = ready_tx.send(true); + // Spawn the polling task immediately + // This must be done before any operations on the client + let subscriptions_clone = subscriptions.clone(); + let poll_handle = tokio::spawn(async move { + Self::poll_connection(conn, subscriptions_clone).await; + 
}); // Get channels to re-subscribe to let channels: Vec = @@ -135,38 +137,41 @@ impl PostgresDriver { if needs_resubscribe { tracing::debug!( - ?channels, + channels=?channels.len(), "will re-subscribe to channels after connection starts" ); } - // Spawn a task to re-subscribe after a short delay + // Re-subscribe to channels if needs_resubscribe { - let client_for_resub = new_client.clone(); - let channels_clone = channels.clone(); - tokio::spawn(async move { - tracing::debug!( - ?channels_clone, - "re-subscribing to channels after reconnection" - ); - for channel in &channels_clone { - if let Result::Err(e) = client_for_resub - .execute(&format!("LISTEN \"{}\"", channel), &[]) - .await - { - tracing::error!(?e, %channel, "failed to re-subscribe to channel"); - } else { - tracing::debug!(%channel, "successfully re-subscribed to channel"); - } + tracing::debug!( + channels=?channels.len(), + "re-subscribing to channels after reconnection" + ); + for channel in &channels { + tracing::info!(?channel, "re-subscribing to channel"); + if let Result::Err(e) = new_client + .execute(&format!("LISTEN \"{}\"", channel), &[]) + .await + { + tracing::error!(?e, %channel, "failed to re-subscribe to channel"); + } else { + tracing::debug!(%channel, "successfully re-subscribed to channel"); } - }); + } } - // Poll the connection until it closes - Self::poll_connection(conn, subscriptions.clone()).await; + // Update the client reference and signal ready + // Do this AFTER re-subscribing to ensure LISTEN is complete + *client.lock().await = Some(new_client.clone()); + let _ = ready_tx.send(true); + + // Wait for the polling task to complete (when the connection closes) + let _ = poll_handle.await; // Clear the client reference on disconnect *client.lock().await = None; + // Notify that client is disconnected let _ = ready_tx.send(false); } @@ -208,12 +213,12 @@ impl PostgresDriver { // Ignore other async messages } Some(std::result::Result::Err(err)) => { - tracing::error!(?err, "postgres connection error, reconnecting"); - break; // Exit loop to reconnect + tracing::error!(?err, "postgres connection error"); + break; } None => { - tracing::warn!("postgres connection closed, reconnecting"); - break; // Exit loop to reconnect + tracing::warn!("postgres connection closed"); + break; } } } @@ -224,19 +229,16 @@ impl PostgresDriver { let mut ready_rx = self.client_ready.clone(); tokio::time::timeout(tokio::time::Duration::from_secs(5), async { loop { - // Subscribe to changed before attempting to access client - let changed_fut = ready_rx.changed(); - // Check if client is already available if let Some(client) = self.client.lock().await.clone() { return Ok(client); } - // Wait for change, will return client if exists on next iteration - changed_fut + // Wait for the ready signal to change + ready_rx + .changed() .await .map_err(|_| anyhow!("connection lifecycle task ended"))?; - tracing::debug!("client does not exist immediately after receive ready"); } }) .await @@ -288,13 +290,25 @@ impl PubSubDriver for PostgresDriver { // Execute LISTEN command on the async client (for receiving notifications) // This only needs to be done once per channel - // Wait for client to be connected with retry logic - let client = self.wait_for_client().await?; - let span = tracing::trace_span!("pg_listen"); - client - .execute(&format!("LISTEN \"{hashed}\""), &[]) - .instrument(span) - .await?; + // Try to LISTEN if client is available, but don't fail if disconnected + // The reconnection logic will handle re-subscribing + if 
let Some(client) = self.client.lock().await.clone() { + let span = tracing::trace_span!("pg_listen"); + match client + .execute(&format!("LISTEN \"{hashed}\""), &[]) + .instrument(span) + .await + { + Result::Ok(_) => { + tracing::debug!(%hashed, "successfully subscribed to channel"); + } + Result::Err(e) => { + tracing::warn!(?e, %hashed, "failed to LISTEN, will retry on reconnection"); + } + } + } else { + tracing::debug!(%hashed, "client not connected, will LISTEN on reconnection"); + } // Spawn a single cleanup task for this subscription waiting on its token let driver = self.clone(); @@ -333,14 +347,66 @@ impl PubSubDriver for PostgresDriver { // Encode payload to base64 and send NOTIFY let encoded = BASE64.encode(payload); - let conn = self.pool.get().await?; let hashed = self.hash_subject(subject); - let span = tracing::trace_span!("pg_notify"); - conn.execute(&format!("NOTIFY \"{hashed}\", '{encoded}'"), &[]) - .instrument(span) - .await?; - Ok(()) + tracing::debug!("attempting to get connection for publish"); + + // Wait for listen connection to be ready first if this channel has subscribers + // This ensures that if we're reconnecting, the LISTEN is re-registered before NOTIFY + if self.subscriptions.lock().await.contains_key(&hashed) { + self.wait_for_client().await?; + } + + // Retry getting a connection from the pool with backoff in case the connection is + // currently disconnected + let mut backoff = Backoff::default(); + let mut last_error = None; + + loop { + match self.pool.get().await { + Result::Ok(conn) => { + // Test the connection with a simple query before using it + match conn.execute("SELECT 1", &[]).await { + Result::Ok(_) => { + // Connection is good, use it for NOTIFY + let span = tracing::trace_span!("pg_notify"); + match conn + .execute(&format!("NOTIFY \"{hashed}\", '{encoded}'"), &[]) + .instrument(span) + .await + { + Result::Ok(_) => return Ok(()), + Result::Err(e) => { + tracing::debug!( + ?e, + "NOTIFY failed, retrying with new connection" + ); + last_error = Some(e.into()); + } + } + } + Result::Err(e) => { + tracing::debug!( + ?e, + "connection test failed, retrying with new connection" + ); + last_error = Some(e.into()); + } + } + } + Result::Err(e) => { + tracing::debug!(?e, "failed to get connection from pool, retrying"); + last_error = Some(e.into()); + } + } + + // Check if we should continue retrying + if !backoff.tick().await { + return Err( + last_error.unwrap_or_else(|| anyhow!("failed to publish after retries")) + ); + } + } } async fn flush(&self) -> Result<()> { diff --git a/packages/common/universalpubsub/src/errors.rs b/packages/common/universalpubsub/src/errors.rs index afab64db4a..69849d4592 100644 --- a/packages/common/universalpubsub/src/errors.rs +++ b/packages/common/universalpubsub/src/errors.rs @@ -6,4 +6,6 @@ use serde::{Deserialize, Serialize}; pub enum Ups { #[error("request_timeout", "Request timeout.")] RequestTimeout, + #[error("publish_failed", "Failed to publish message after retries")] + PublishFailed, } diff --git a/packages/common/universalpubsub/src/pubsub.rs b/packages/common/universalpubsub/src/pubsub.rs index fd24e41ff2..ac6ff27696 100644 --- a/packages/common/universalpubsub/src/pubsub.rs +++ b/packages/common/universalpubsub/src/pubsub.rs @@ -8,6 +8,8 @@ use tokio::sync::broadcast; use tokio::sync::{RwLock, oneshot}; use uuid::Uuid; +use rivet_util::backoff::Backoff; + use crate::chunking::{ChunkTracker, encode_chunk, split_payload_into_chunks}; use crate::driver::{PubSubDriverHandle, PublishOpts, 
SubscriberDriverHandle}; @@ -131,7 +133,8 @@ impl PubSub { break; } } else { - self.driver.publish(subject, &encoded).await?; + // Use backoff when publishing through the driver + self.publish_with_backoff(subject, &encoded).await?; } } Ok(()) @@ -174,7 +177,26 @@ impl PubSub { break; } } else { - self.driver.publish(subject, &encoded).await?; + // Use backoff when publishing through the driver + self.publish_with_backoff(subject, &encoded).await?; + } + } + Ok(()) + } + + async fn publish_with_backoff(&self, subject: &str, encoded: &[u8]) -> Result<()> { + let mut backoff = Backoff::default(); + loop { + match self.driver.publish(subject, encoded).await { + Result::Ok(_) => break, + Err(err) if !backoff.tick().await => { + tracing::info!(?err, "error publishing, retries exhausted"); + return Err(crate::errors::Ups::PublishFailed.build().into()); + } + Err(err) => { + tracing::info!(?err, "error publishing, retrying"); + // Continue retrying + } } } Ok(()) @@ -293,7 +315,10 @@ impl Subscriber { pub async fn next(&mut self) -> Result<NextOutput> { loop { match self.driver.next().await? { - DriverOutput::Message { subject, payload } => { + DriverOutput::Message { + subject: _, + payload, + } => { // Process chunks let mut tracker = self.pubsub.chunk_tracker.lock().unwrap(); match tracker.process_chunk(&payload) { diff --git a/packages/common/universalpubsub/tests/reconnect.rs b/packages/common/universalpubsub/tests/reconnect.rs index fa98767e95..230110789b 100644 --- a/packages/common/universalpubsub/tests/reconnect.rs +++ b/packages/common/universalpubsub/tests/reconnect.rs @@ -43,7 +43,7 @@ async fn test_nats_driver_with_memory_reconnect() { .unwrap(); let pubsub = PubSub::new_with_memory_optimization(Arc::new(driver), true); - test_reconnect_inner(&pubsub, &docker).await; + test_all_inner(&pubsub, &docker).await; } #[tokio::test] @@ -77,7 +77,7 @@ async fn test_nats_driver_without_memory_reconnect() { .unwrap(); let pubsub = PubSub::new_with_memory_optimization(Arc::new(driver), false); - test_reconnect_inner(&pubsub, &docker).await; + test_all_inner(&pubsub, &docker).await; } #[tokio::test] @@ -100,7 +100,7 @@ async fn test_postgres_driver_with_memory_reconnect() { .unwrap(); let pubsub = PubSub::new_with_memory_optimization(Arc::new(driver), true); - test_reconnect_inner(&pubsub, &docker).await; + test_all_inner(&pubsub, &docker).await; } #[tokio::test] @@ -123,7 +123,13 @@ async fn test_postgres_driver_without_memory_reconnect() { .unwrap(); let pubsub = PubSub::new_with_memory_optimization(Arc::new(driver), false); + test_all_inner(&pubsub, &docker).await; +} + +async fn test_all_inner(pubsub: &PubSub, docker: &rivet_test_deps_docker::DockerRunConfig) { test_reconnect_inner(&pubsub, &docker).await; + test_publish_while_stopped(&pubsub, &docker).await; + test_subscribe_while_stopped(&pubsub, &docker).await; } async fn test_reconnect_inner(pubsub: &PubSub, docker: &rivet_test_deps_docker::DockerRunConfig) { @@ -158,35 +164,14 @@ tracing::info!("restarting docker container"); docker.restart().await.unwrap(); - // Give the service time to come back up - tokio::time::sleep(Duration::from_secs(3)).await; - tracing::info!("docker container restarted"); - // Test publish/receive message after restart + // + // This should retry under the hood, since the container will still be starting let message_after = b"message after restart"; - - // Retry logic for publish after restart since connection might need to reconnect - let
mut retries = 0; - const MAX_RETRIES: u32 = 10; - loop { - match pubsub - .publish("test.reconnect", message_after, PublishOpts::broadcast()) - .await - { - Result::Ok(_) => { - tracing::info!("published message after restart"); - break; - } - Result::Err(e) if retries < MAX_RETRIES => { - retries += 1; - tracing::debug!(?e, retries, "failed to publish after restart, retrying"); - tokio::time::sleep(Duration::from_millis(500)).await; - } - Result::Err(e) => { - panic!("failed to publish after {} retries: {}", MAX_RETRIES, e); - } - } - } + pubsub + .publish("test.reconnect", message_after, PublishOpts::broadcast()) + .await + .unwrap(); pubsub.flush().await.unwrap(); @@ -215,3 +200,160 @@ async fn test_reconnect_inner(pubsub: &PubSub, docker: &rivet_test_deps_docker:: tracing::info!("reconnect test completed successfully"); } + +async fn test_publish_while_stopped( + pubsub: &PubSub, + docker: &rivet_test_deps_docker::DockerRunConfig, +) { + tracing::info!("testing publish while container stopped"); + + // 1. Subscribe + let mut subscriber = pubsub.subscribe("test.publish_stopped").await.unwrap(); + tracing::info!("opened subscription"); + + // 2. Stop container + tracing::info!("stopping docker container"); + docker.stop_container().await.unwrap(); + tokio::time::sleep(Duration::from_secs(2)).await; + + // 3. Publish while stopped (should queue/retry) + let message = b"message while stopped"; + let publish_handle = tokio::spawn({ + let pubsub = pubsub.clone(); + let message = message.to_vec(); + async move { + pubsub + .publish("test.publish_stopped", &message, PublishOpts::broadcast()) + .await + } + }); + + // 4. Start container + tokio::time::sleep(Duration::from_secs(3)).await; + tracing::info!("starting docker container"); + docker.start_container().await.unwrap(); + tokio::time::sleep(Duration::from_secs(5)).await; + + // Wait for publish to complete + publish_handle.await.unwrap().unwrap(); + pubsub.flush().await.unwrap(); + + // 5. Receive message + tracing::info!("waiting for message"); + let receive_timeout = Duration::from_secs(5); + let receive_result = tokio::time::timeout(receive_timeout, subscriber.next()).await; + + match receive_result { + Result::Ok(Result::Ok(NextOutput::Message(msg))) => { + assert_eq!( + msg.payload, message, + "message published while stopped should be received" + ); + tracing::info!("received message published while stopped - reconnection successful"); + } + Result::Ok(Result::Ok(NextOutput::Unsubscribed)) => { + panic!("unexpected unsubscribe"); + } + Result::Ok(Result::Err(e)) => { + panic!("error receiving message: {}", e); + } + Result::Err(_) => { + panic!("timeout receiving message"); + } + } + + tracing::info!("publish while stopped test completed successfully"); +} + +async fn test_subscribe_while_stopped( + pubsub: &PubSub, + docker: &rivet_test_deps_docker::DockerRunConfig, +) { + tracing::info!("testing subscribe while container stopped"); + + // 1. 
Subscribe & test publish & unsubscribe + let mut subscriber = pubsub.subscribe("test.subscribe_stopped").await.unwrap(); + tracing::info!("opened initial subscription"); + + let test_message = b"test message"; + pubsub + .publish( + "test.subscribe_stopped", + test_message, + PublishOpts::broadcast(), + ) + .await + .unwrap(); + pubsub.flush().await.unwrap(); + + match subscriber.next().await.unwrap() { + NextOutput::Message(msg) => { + assert_eq!(msg.payload, test_message, "initial message should match"); + tracing::info!("received initial test message"); + } + NextOutput::Unsubscribed => { + panic!("unexpected unsubscribe"); + } + } + + drop(subscriber); // Drop to unsubscribe + tracing::info!("unsubscribed from initial subscription"); + + // 2. Stop container + tracing::info!("stopping docker container"); + docker.stop_container().await.unwrap(); + tokio::time::sleep(Duration::from_secs(2)).await; + + // 3. Subscribe while stopped + let subscribe_handle = tokio::spawn({ + let pubsub = pubsub.clone(); + async move { pubsub.subscribe("test.subscribe_stopped").await } + }); + + // 4. Start container + tokio::time::sleep(Duration::from_secs(3)).await; + tracing::info!("starting docker container"); + docker.start_container().await.unwrap(); + tokio::time::sleep(Duration::from_secs(5)).await; + + // Wait for subscription to complete + let mut new_subscriber = subscribe_handle.await.unwrap().unwrap(); + tracing::info!("new subscription established after reconnect"); + + // 5. Publish message + let final_message = b"message after reconnect"; + pubsub + .publish( + "test.subscribe_stopped", + final_message, + PublishOpts::broadcast(), + ) + .await + .unwrap(); + pubsub.flush().await.unwrap(); + + // 6. Receive + let receive_timeout = Duration::from_secs(10); + let receive_result = tokio::time::timeout(receive_timeout, new_subscriber.next()).await; + + match receive_result { + Result::Ok(Result::Ok(NextOutput::Message(msg))) => { + assert_eq!( + msg.payload, final_message, + "message after reconnect should match" + ); + tracing::info!("received message after reconnect - subscription successful"); + } + Result::Ok(Result::Ok(NextOutput::Unsubscribed)) => { + panic!("unexpected unsubscribe"); + } + Result::Ok(Result::Err(e)) => { + panic!("error receiving message: {}", e); + } + Result::Err(_) => { + panic!("timeout receiving message"); + } + } + + tracing::info!("subscribe while stopped test completed successfully"); +} From d4b39b34edd2b6e9d6703c602c51930520931e67 Mon Sep 17 00:00:00 2001 From: Nathan Flurry Date: Mon, 8 Sep 2025 20:42:18 -0700 Subject: [PATCH 4/8] ci: switch to depot --- .github/workflows/release.yaml | 14 +++++++------- .github/workflows/rust.yml | 6 +++--- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 961fbcbfb1..ce35afb4d2 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -79,22 +79,22 @@ jobs: matrix: include: - platform: linux - runner: ubuntu-latest + runner: depot-ubuntu-24.04 target: x86_64-unknown-linux-musl binary_ext: "" arch: x86_64 - platform: windows - runner: ubuntu-latest + runner: depot-ubuntu-24.04 target: x86_64-pc-windows-gnu binary_ext: ".exe" arch: x86_64 - platform: macos - runner: ubuntu-latest + runner: depot-ubuntu-24.04 target: x86_64-apple-darwin binary_ext: "" arch: x86_64 - platform: macos - runner: ubuntu-latest + runner: depot-ubuntu-24.04 target: aarch64-apple-darwin binary_ext: "" arch: aarch64 @@ -155,10 +155,10 @@ jobs: 
include: # TODO(RVT-4479): Add back ARM builder once manifest generation fixed # - platform: linux/arm64 - # runner: ubuntu-latest + # runner: depot-ubuntu-24.04 # arch_suffix: -arm64 - platform: linux/x86_64 - runner: ubuntu-latest + runner: depot-ubuntu-24.04 # TODO: Replace with appropriate arch_suffix when needed # arch_suffix: -amd64 arch_suffix: '' @@ -246,4 +246,4 @@ jobs: ./scripts/release/main.ts --version "${{ github.event.inputs.version }}" --completeCi else ./scripts/release/main.ts --version "${{ github.event.inputs.version }}" --no-latest --completeCi - fi \ No newline at end of file + fi diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 1671a3fc41..745f5ae9cd 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -43,7 +43,7 @@ jobs: # clippy: # name: Clippy - # runs-on: ubuntu-latest + # runs-on: depot-ubuntu-24.04 # steps: # - uses: actions/checkout@v4 @@ -59,7 +59,7 @@ jobs: check: name: Check - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 steps: - uses: actions/checkout@v4 @@ -77,7 +77,7 @@ jobs: test: name: Test - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 steps: - uses: actions/checkout@v4 From e0afd30b73ad2137e587f0412a6f80c30e19281d Mon Sep 17 00:00:00 2001 From: Nathan Flurry Date: Tue, 9 Sep 2025 09:41:07 -0700 Subject: [PATCH 5/8] chore(runner): handle async actor stop --- sdks/typescript/runner/src/mod.ts | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/sdks/typescript/runner/src/mod.ts b/sdks/typescript/runner/src/mod.ts index 7d26a37925..2623d63848 100644 --- a/sdks/typescript/runner/src/mod.ts +++ b/sdks/typescript/runner/src/mod.ts @@ -127,7 +127,7 @@ export class Runner { // The server will send a StopActor command if it wants to fully stop } - stopActor(actorId: string, generation?: number) { + async stopActor(actorId: string, generation?: number) { const actor = this.#removeActor(actorId, generation); if (!actor) return; @@ -136,11 +136,14 @@ export class Runner { this.#tunnel.unregisterActor(actor); } - this.#sendActorStateUpdate(actorId, actor.generation, "stopped"); - - this.#config.onActorStop(actorId, actor.generation).catch((err) => { + // If onActorStop times out, Pegboard will handle this timeout with ACTOR_STOP_THRESHOLD_DURATION_MS + try { + await this.#config.onActorStop(actorId, actor.generation); + } catch (err) { console.error(`Error in onActorStop for actor ${actorId}:`, err); - }); + } + + this.#sendActorStateUpdate(actorId, actor.generation, "stopped"); } #stopAllActors() { From c70a66726194a3488bdf36fdcbb3d3f988e9a4c5 Mon Sep 17 00:00:00 2001 From: Nathan Flurry Date: Fri, 12 Sep 2025 13:22:24 -0700 Subject: [PATCH 6/8] fix(epoxy): fix Any quorum type not reaching any node --- packages/services/epoxy/src/http_client.rs | 15 ++++++++++++--- packages/services/epoxy/src/utils.rs | 2 +- packages/services/epoxy/tests/proposal.rs | 1 - 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/packages/services/epoxy/src/http_client.rs b/packages/services/epoxy/src/http_client.rs index bfac730d8a..51864305c6 100644 --- a/packages/services/epoxy/src/http_client.rs +++ b/packages/services/epoxy/src/http_client.rs @@ -54,12 +54,21 @@ where ) .collect::<Vec<_>>() .await; + tracing::info!(?quorum_size, len = ?responses.len(), ?quorum_type, "fanout quorum size"); + + // Choose how many successful responses we need before considering it a success + let target_responses = match quorum_type { + // Only require 1 response + utils::QuorumType::Any => 1, + // Require
all responses + utils::QuorumType::All => responses.len(), + // Subtract 1 from quorum size since we're not counting ourselves + utils::QuorumType::Fast | utils::QuorumType::Slow => quorum_size - 1, + }; // Collect responses until we reach quorum or all futures complete - // - // Subtract 1 from quorum size since we're not counting ourselves let mut successful_responses = Vec::new(); - while successful_responses.len() < quorum_size - 1 { + while successful_responses.len() < target_responses { if let Some(response) = responses.next().await { match response { std::result::Result::Ok(result) => match result { diff --git a/packages/services/epoxy/src/utils.rs b/packages/services/epoxy/src/utils.rs index d7a6cf44e7..b6896f1a08 100644 --- a/packages/services/epoxy/src/utils.rs +++ b/packages/services/epoxy/src/utils.rs @@ -2,7 +2,7 @@ use anyhow::*; use epoxy_protocol::protocol::{self, ReplicaId}; use universaldb::Transaction; -#[derive(Clone, Copy)] +#[derive(Clone, Copy, Debug)] pub enum QuorumType { Fast, Slow, diff --git a/packages/services/epoxy/tests/proposal.rs b/packages/services/epoxy/tests/proposal.rs index f14ab6cd98..f297dfbc40 100644 --- a/packages/services/epoxy/tests/proposal.rs +++ b/packages/services/epoxy/tests/proposal.rs @@ -4,7 +4,6 @@ use common::THREE_REPLICAS; use epoxy::ops::propose::ProposalResult; use epoxy_protocol::protocol; use gas::prelude::*; -use rivet_acl::{Verifier, config::AclConfig}; use rivet_api_builder::{ApiCtx, GlobalApiCtx}; use rivet_util::Id; From 9549b26deb5716f56661b6ef7af2cca1d829e8a1 Mon Sep 17 00:00:00 2001 From: Kacper Wojciechowski <39823706+jog1t@users.noreply.github.com> Date: Thu, 11 Sep 2025 23:28:49 +0200 Subject: [PATCH 7/8] fix(engine/fe): remove addresses --- .../routes/_layout/ns.$namespace/runners.tsx | 39 ------------------- 1 file changed, 39 deletions(-) diff --git a/frontend/src/routes/_layout/ns.$namespace/runners.tsx b/frontend/src/routes/_layout/ns.$namespace/runners.tsx index 21474beec0..9969231100 100644 --- a/frontend/src/routes/_layout/ns.$namespace/runners.tsx +++ b/frontend/src/routes/_layout/ns.$namespace/runners.tsx @@ -71,7 +71,6 @@ function RouteComponent() { ID Name - HTTP Slots Last ping Created @@ -155,19 +154,11 @@ function RowSkeleton() { - - - ); } -const MAX_TO_SHOW = 2; - function Row(runner: Rivet.Runner) { - const [isExpanded, setExpanded] = useState(false); - const addresses = Object.values(runner.addressesHttp); - return ( @@ -186,36 +177,6 @@ function Row(runner: Rivet.Runner) { - -
- {addresses - .slice(0, isExpanded ? addresses.length : MAX_TO_SHOW) - .map((http) => { - const address = `${http.hostname}:${http.port}`; - return ( - - {address} - - ); - })} - - {addresses.length > MAX_TO_SHOW && !isExpanded ? ( - - ) : null} -
-
- {runner.remainingSlots}/{runner.totalSlots} From 5ddcc8537f73ebac4e53f3a61182f16c54f2fc0f Mon Sep 17 00:00:00 2001 From: MasterPtato Date: Fri, 5 Sep 2025 13:09:49 -0700 Subject: [PATCH 8/8] feat(pegboard): outbound runners --- Cargo.lock | 143 +++++-- Cargo.toml | 16 +- docker/dev/docker-compose.yml | 1 + out/openapi.json | 70 +++- packages/common/api-builder/Cargo.toml | 1 + packages/common/api-builder/src/wrappers.rs | 3 +- .../common/gasoline/core/src/utils/tags.rs | 10 + packages/common/pools/src/reqwest.rs | 7 + packages/common/types/Cargo.toml | 2 +- packages/common/types/src/runners.rs | 2 +- packages/common/udb-util/src/keys.rs | 4 +- packages/core/api-peer/src/namespaces.rs | 16 +- packages/core/guard/server/Cargo.toml | 2 +- packages/core/pegboard-outbound/Cargo.toml | 20 + packages/core/pegboard-outbound/src/lib.rs | 288 ++++++++++++++ packages/infra/engine/Cargo.toml | 9 +- packages/infra/engine/src/run_config.rs | 5 + packages/services/namespace/Cargo.toml | 6 +- packages/services/namespace/src/keys.rs | 50 +++ .../services/namespace/src/ops/get_global.rs | 58 ++- .../services/namespace/src/ops/get_local.rs | 39 +- packages/services/namespace/src/types.rs | 54 +++ .../namespace/src/workflows/namespace.rs | 57 +-- packages/services/pegboard/Cargo.toml | 2 +- .../services/pegboard/src/keys/datacenter.rs | 249 ------------ packages/services/pegboard/src/keys/mod.rs | 1 - packages/services/pegboard/src/keys/ns.rs | 357 +++++++++++++++++- packages/services/pegboard/src/keys/runner.rs | 16 +- packages/services/pegboard/src/lib.rs | 1 + packages/services/pegboard/src/messages.rs | 4 + .../services/pegboard/src/ops/actor/create.rs | 6 +- .../pegboard/src/ops/actor/get_for_key.rs | 4 +- .../pegboard/src/ops/actor/list_names.rs | 2 +- .../services/pegboard/src/ops/runner/get.rs | 2 +- .../src/ops/runner/update_alloc_idx.rs | 8 +- .../src/workflows/actor/actor_keys.rs | 2 +- .../pegboard/src/workflows/actor/destroy.rs | 21 +- .../pegboard/src/workflows/actor/mod.rs | 46 ++- .../pegboard/src/workflows/actor/runtime.rs | 95 +++-- .../pegboard/src/workflows/actor/setup.rs | 22 +- .../services/pegboard/src/workflows/runner.rs | 32 +- pnpm-lock.yaml | 3 + sdks/rust/{key-data => data}/Cargo.toml | 2 +- sdks/rust/{key-data => data}/build.rs | 2 +- sdks/rust/{key-data => data}/src/converted.rs | 8 +- sdks/rust/{key-data => data}/src/generated.rs | 0 sdks/rust/{key-data => data}/src/lib.rs | 2 +- sdks/rust/{key-data => data}/src/versioned.rs | 40 +- .../data/namespace.runner_kind.v1.bare | 14 + .../pegboard.namespace.actor_by_key.v1.bare | 0 .../pegboard.namespace.actor_name.v1.bare | 0 ...gboard.namespace.runner_alloc_idx.v1.bare} | 0 .../pegboard.namespace.runner_by_key.v1.bare | 0 .../pegboard.runner.address.v1.bare | 0 .../pegboard.runner.metadata.v1.bare | 0 sdks/typescript/runner/src/mod.ts | 7 +- sdks/typescript/test-runner/package.json | 3 +- sdks/typescript/test-runner/src/main.ts | 240 +++++++----- 58 files changed, 1432 insertions(+), 622 deletions(-) create mode 100644 packages/core/pegboard-outbound/Cargo.toml create mode 100644 packages/core/pegboard-outbound/src/lib.rs delete mode 100644 packages/services/pegboard/src/keys/datacenter.rs create mode 100644 packages/services/pegboard/src/messages.rs rename sdks/rust/{key-data => data}/Cargo.toml (95%) rename sdks/rust/{key-data => data}/build.rs (99%) rename sdks/rust/{key-data => data}/src/converted.rs (94%) rename sdks/rust/{key-data => data}/src/generated.rs (100%) rename sdks/rust/{key-data => data}/src/lib.rs (84%) rename 
sdks/rust/{key-data => data}/src/versioned.rs (83%) create mode 100644 sdks/schemas/data/namespace.runner_kind.v1.bare rename sdks/schemas/{key-data => data}/pegboard.namespace.actor_by_key.v1.bare (100%) rename sdks/schemas/{key-data => data}/pegboard.namespace.actor_name.v1.bare (100%) rename sdks/schemas/{key-data/pegboard.datacenter.runner_alloc_idx.v1.bare => data/pegboard.namespace.runner_alloc_idx.v1.bare} (100%) rename sdks/schemas/{key-data => data}/pegboard.namespace.runner_by_key.v1.bare (100%) rename sdks/schemas/{key-data => data}/pegboard.runner.address.v1.bare (100%) rename sdks/schemas/{key-data => data}/pegboard.runner.metadata.v1.bare (100%) diff --git a/Cargo.lock b/Cargo.lock index 44bd7cd6e0..f1f2a7a52d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -377,6 +377,31 @@ dependencies = [ "tracing", ] +[[package]] +name = "axum-extra" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45bf463831f5131b7d3c756525b305d40f1185b688565648a92e1392ca35713d" +dependencies = [ + "axum 0.8.4", + "axum-core 0.5.2", + "bytes", + "form_urlencoded", + "futures-util", + "http 1.3.1", + "http-body 1.0.1", + "http-body-util", + "mime", + "pin-project-lite", + "rustversion", + "serde", + "serde_html_form", + "serde_path_to_error", + "tower 0.5.2", + "tower-layer", + "tower-service", +] + [[package]] name = "axum-test" version = "17.3.0" @@ -1454,6 +1479,17 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "eventsource-stream" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74fef4569247a5f429d9156b9d0a2599914385dd189c539334c625d8099d90ab" +dependencies = [ + "futures-core", + "nom", + "pin-project-lite", +] + [[package]] name = "fallible-iterator" version = "0.2.0" @@ -2737,6 +2773,7 @@ dependencies = [ "gasoline", "rivet-api-builder", "rivet-api-util", + "rivet-data", "rivet-error", "rivet-util", "serde", @@ -2744,6 +2781,7 @@ dependencies = [ "udb-util", "universaldb", "utoipa", + "versioned-data-util", ] [[package]] @@ -3217,8 +3255,8 @@ dependencies = [ "rivet-api-client", "rivet-api-types", "rivet-api-util", + "rivet-data", "rivet-error", - "rivet-key-data", "rivet-metrics", "rivet-runner-protocol", "rivet-types", @@ -3277,6 +3315,23 @@ dependencies = [ "versioned-data-util", ] +[[package]] +name = "pegboard-outbound" +version = "0.0.1" +dependencies = [ + "anyhow", + "epoxy", + "gasoline", + "namespace", + "pegboard", + "reqwest-eventsource", + "rivet-config", + "rivet-runner-protocol", + "tracing", + "udb-util", + "universaldb", +] + [[package]] name = "pegboard-runner-ws" version = "0.0.1" @@ -3960,16 +4015,34 @@ dependencies = [ "tokio", "tokio-native-tls", "tokio-rustls 0.26.2", + "tokio-util", "tower 0.5.2", "tower-http", "tower-service", "url", "wasm-bindgen", "wasm-bindgen-futures", + "wasm-streams", "web-sys", "webpki-roots 1.0.2", ] +[[package]] +name = "reqwest-eventsource" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "632c55746dbb44275691640e7b40c907c16a2dc1a5842aa98aaec90da6ec6bde" +dependencies = [ + "eventsource-stream", + "futures-core", + "futures-timer", + "mime", + "nom", + "pin-project-lite", + "reqwest", + "thiserror 1.0.69", +] + [[package]] name = "reserve-port" version = "2.3.0" @@ -3999,6 +4072,7 @@ version = "0.0.1" dependencies = [ "anyhow", "axum 0.8.4", + "axum-extra", "axum-test", "chrono", "gasoline", @@ -4204,6 +4278,24 @@ dependencies = [ "uuid", ] +[[package]] +name = "rivet-data" +version = 
"0.0.1" +dependencies = [ + "anyhow", + "bare_gen", + "gasoline", + "indoc", + "prettyplease", + "rivet-runner-protocol", + "rivet-util", + "serde", + "serde_bare", + "serde_json", + "syn 2.0.104", + "versioned-data-util", +] + [[package]] name = "rivet-dump-openapi" version = "0.0.1" @@ -4231,6 +4323,7 @@ dependencies = [ "lz4_flex", "namespace", "pegboard", + "pegboard-outbound", "pegboard-runner-ws", "portpicker", "rand 0.8.5", @@ -4327,9 +4420,9 @@ dependencies = [ "rivet-api-public", "rivet-cache", "rivet-config", + "rivet-data", "rivet-error", "rivet-guard-core", - "rivet-key-data", "rivet-logs", "rivet-metrics", "rivet-pools", @@ -4392,24 +4485,6 @@ dependencies = [ "uuid", ] -[[package]] -name = "rivet-key-data" -version = "0.0.1" -dependencies = [ - "anyhow", - "bare_gen", - "gasoline", - "indoc", - "prettyplease", - "rivet-runner-protocol", - "rivet-util", - "serde", - "serde_bare", - "serde_json", - "syn 2.0.104", - "versioned-data-util", -] - [[package]] name = "rivet-logs" version = "0.0.1" @@ -4615,7 +4690,7 @@ dependencies = [ "anyhow", "gasoline", "rivet-api-builder", - "rivet-key-data", + "rivet-data", "rivet-runner-protocol", "rivet-util", "serde", @@ -5215,6 +5290,19 @@ dependencies = [ "syn 2.0.104", ] +[[package]] +name = "serde_html_form" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d2de91cf02bbc07cde38891769ccd5d4f073d22a40683aa4bc7a95781aaa2c4" +dependencies = [ + "form_urlencoded", + "indexmap 2.10.0", + "itoa 1.0.15", + "ryu", + "serde", +] + [[package]] name = "serde_json" version = "1.0.141" @@ -6618,6 +6706,19 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "wasm-streams" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65" +dependencies = [ + "futures-util", + "js-sys", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + [[package]] name = "web-sys" version = "0.3.77" diff --git a/Cargo.toml b/Cargo.toml index 08114056cb..cc32db3c24 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [workspace] resolver = "2" -members = 
["packages/common/api-builder","packages/common/api-client","packages/common/api-types","packages/common/api-util","packages/common/cache/build","packages/common/cache/result","packages/common/clickhouse-inserter","packages/common/clickhouse-user-query","packages/common/config","packages/common/env","packages/common/error/core","packages/common/error/macros","packages/common/gasoline/core","packages/common/gasoline/macros","packages/common/logs","packages/common/metrics","packages/common/pools","packages/common/runtime","packages/common/service-manager","packages/common/telemetry","packages/common/test-deps","packages/common/test-deps-docker","packages/common/types","packages/common/udb-util","packages/common/universaldb","packages/common/universalpubsub","packages/common/util/core","packages/common/util/id","packages/common/versioned-data-util","packages/core/actor-kv","packages/core/api-peer","packages/core/api-public","packages/core/bootstrap","packages/core/dump-openapi","packages/core/guard/core","packages/core/guard/server","packages/core/pegboard-gateway","packages/core/pegboard-runner-ws","packages/core/pegboard-tunnel","packages/core/workflow-worker","packages/infra/engine","packages/services/epoxy","packages/services/namespace","packages/services/pegboard","sdks/rust/api-full","sdks/rust/bare_gen","sdks/rust/epoxy-protocol","sdks/rust/key-data","sdks/rust/runner-protocol","sdks/rust/tunnel-protocol","sdks/rust/ups-protocol"] +members = ["packages/common/api-builder","packages/common/api-client","packages/common/api-types","packages/common/api-util","packages/common/cache/build","packages/common/cache/result","packages/common/clickhouse-inserter","packages/common/clickhouse-user-query","packages/common/config","packages/common/env","packages/common/error/core","packages/common/error/macros","packages/common/gasoline/core","packages/common/gasoline/macros","packages/common/logs","packages/common/metrics","packages/common/pools","packages/common/runtime","packages/common/service-manager","packages/common/telemetry","packages/common/test-deps","packages/common/test-deps-docker","packages/common/types","packages/common/udb-util","packages/common/universaldb","packages/common/universalpubsub","packages/common/util/core","packages/common/util/id","packages/common/versioned-data-util","packages/core/actor-kv","packages/core/api-peer","packages/core/api-public","packages/core/bootstrap","packages/core/dump-openapi","packages/core/guard/core","packages/core/guard/server","packages/core/pegboard-gateway","packages/core/pegboard-outbound","packages/core/pegboard-runner-ws","packages/core/pegboard-tunnel","packages/core/workflow-worker","packages/infra/engine","packages/services/epoxy","packages/services/namespace","packages/services/pegboard","sdks/rust/api-full","sdks/rust/bare_gen","sdks/rust/data","sdks/rust/epoxy-protocol","sdks/rust/runner-protocol","sdks/rust/tunnel-protocol","sdks/rust/ups-protocol"] [workspace.package] version = "0.0.1" @@ -79,6 +79,7 @@ tracing-core = "0.1" tracing-opentelemetry = "0.29" tracing-slog = "0.2" vergen = "9.0.4" +reqwest-eventsource = "0.6.0" [workspace.dependencies.sentry] version = "0.37.0" @@ -118,6 +119,10 @@ features = ["uuid"] version = "0.8" features = ["http2"] +[workspace.dependencies.axum-extra] +version = "0.10.1" +features = ["query"] + [workspace.dependencies.tower-http] version = "0.6" features = ["cors","trace"] @@ -359,6 +364,9 @@ path = "packages/core/guard/server" [workspace.dependencies.pegboard-gateway] path = 
"packages/core/pegboard-gateway" +[workspace.dependencies.pegboard-outbound] +path = "packages/core/pegboard-outbound" + [workspace.dependencies.pegboard-runner-ws] path = "packages/core/pegboard-runner-ws" @@ -386,12 +394,12 @@ path = "sdks/rust/api-full" [workspace.dependencies.bare_gen] path = "sdks/rust/bare_gen" +[workspace.dependencies.rivet-data] +path = "sdks/rust/data" + [workspace.dependencies.epoxy-protocol] path = "sdks/rust/epoxy-protocol" -[workspace.dependencies.rivet-key-data] -path = "sdks/rust/key-data" - [workspace.dependencies.rivet-runner-protocol] path = "sdks/rust/runner-protocol" diff --git a/docker/dev/docker-compose.yml b/docker/dev/docker-compose.yml index 6803272ebe..64de5f178d 100644 --- a/docker/dev/docker-compose.yml +++ b/docker/dev/docker-compose.yml @@ -187,6 +187,7 @@ services: environment: - RIVET_ENDPOINT=http://rivet-engine:6420 - RUNNER_HOST=runner + # - NO_AUTOSTART=1 stop_grace_period: 4s ports: - '5050:5050' diff --git a/out/openapi.json b/out/openapi.json index 87571399cc..ff97997a91 100644 --- a/out/openapi.json +++ b/out/openapi.json @@ -465,6 +465,17 @@ "schema": { "type": "string" } + }, + { + "name": "namespace_id", + "in": "query", + "required": true, + "schema": { + "type": "array", + "items": { + "$ref": "#/components/schemas/RivetId" + } + } } ], "responses": { @@ -1042,7 +1053,8 @@ "namespace_id", "name", "display_name", - "create_ts" + "create_ts", + "runner_kind" ], "properties": { "create_ts": { @@ -1057,6 +1069,9 @@ }, "namespace_id": { "$ref": "#/components/schemas/RivetId" + }, + "runner_kind": { + "$ref": "#/components/schemas/RunnerKind" } } }, @@ -1238,6 +1253,59 @@ }, "additionalProperties": false }, + "RunnerKind": { + "oneOf": [ + { + "type": "object", + "required": [ + "outbound" + ], + "properties": { + "outbound": { + "type": "object", + "required": [ + "url", + "slots_per_runner", + "min_runners", + "max_runners", + "runners_margin" + ], + "properties": { + "max_runners": { + "type": "integer", + "format": "int32", + "minimum": 0 + }, + "min_runners": { + "type": "integer", + "format": "int32", + "minimum": 0 + }, + "runners_margin": { + "type": "integer", + "format": "int32", + "minimum": 0 + }, + "slots_per_runner": { + "type": "integer", + "format": "int32", + "minimum": 0 + }, + "url": { + "type": "string" + } + } + } + } + }, + { + "type": "string", + "enum": [ + "custom" + ] + } + ] + }, "RunnersGetResponse": { "type": "object", "required": [ diff --git a/packages/common/api-builder/Cargo.toml b/packages/common/api-builder/Cargo.toml index f9ce906458..5655703737 100644 --- a/packages/common/api-builder/Cargo.toml +++ b/packages/common/api-builder/Cargo.toml @@ -8,6 +8,7 @@ license.workspace = true [dependencies] anyhow.workspace = true axum.workspace = true +axum-extra.workspace = true gas.workspace = true chrono.workspace = true hyper = { workspace = true, features = ["full"] } diff --git a/packages/common/api-builder/src/wrappers.rs b/packages/common/api-builder/src/wrappers.rs index e6d8193c2d..926b6ece48 100644 --- a/packages/common/api-builder/src/wrappers.rs +++ b/packages/common/api-builder/src/wrappers.rs @@ -1,13 +1,14 @@ use anyhow::Result; use axum::{ body::Bytes, - extract::{Extension, Path, Query}, + extract::{Extension, Path}, response::{IntoResponse, Json}, routing::{ delete as axum_delete, get as axum_get, patch as axum_patch, post as axum_post, put as axum_put, }, }; +use axum_extra::extract::Query; use serde::{Serialize, de::DeserializeOwned}; use std::future::Future; diff --git 
a/packages/common/gasoline/core/src/utils/tags.rs b/packages/common/gasoline/core/src/utils/tags.rs index 967c59d10f..92326b65bd 100644 --- a/packages/common/gasoline/core/src/utils/tags.rs +++ b/packages/common/gasoline/core/src/utils/tags.rs @@ -61,6 +61,16 @@ impl AsTags for serde_json::Value { } } +impl AsTags for () { + fn as_tags(&self) -> WorkflowResult<serde_json::Value> { + Ok(serde_json::Value::Object(serde_json::Map::new())) + } + + fn as_cjson_tags(&self) -> WorkflowResult<String> { + Ok(String::new()) + } +} + impl<T: AsTags> AsTags for &T { fn as_tags(&self) -> WorkflowResult<serde_json::Value> { (*self).as_tags() diff --git a/packages/common/pools/src/reqwest.rs b/packages/common/pools/src/reqwest.rs index b3044041af..78f1f2e7cb 100644 --- a/packages/common/pools/src/reqwest.rs +++ b/packages/common/pools/src/reqwest.rs @@ -13,3 +13,13 @@ pub async fn client() -> Result<Client> { .await .cloned() } + +static CLIENT_NO_TIMEOUT: OnceCell<Client> = OnceCell::const_new(); + +pub async fn client_no_timeout() -> Result<Client> { + // Use a dedicated cell so this never returns the shared default client + CLIENT_NO_TIMEOUT + .get_or_try_init(|| async { Client::builder().build() }) + .await + .cloned() +} diff --git a/packages/common/types/Cargo.toml b/packages/common/types/Cargo.toml index 6e429cbacf..9ae35c64aa 100644 --- a/packages/common/types/Cargo.toml +++ b/packages/common/types/Cargo.toml @@ -10,7 +10,7 @@ anyhow.workspace = true gas.workspace = true rivet-api-builder.workspace = true rivet-runner-protocol.workspace = true -rivet-key-data.workspace = true +rivet-data.workspace = true rivet-util.workspace = true serde.workspace = true utoipa.workspace = true diff --git a/packages/common/types/src/runners.rs b/packages/common/types/src/runners.rs index f56db3f670..d70ed49eda 100644 --- a/packages/common/types/src/runners.rs +++ b/packages/common/types/src/runners.rs @@ -1,5 +1,5 @@ use gas::prelude::*; -use rivet_key_data::generated::pegboard_runner_address_v1; +use rivet_data::generated::pegboard_runner_address_v1; use rivet_runner_protocol::protocol; use rivet_util::Id; use serde::{Deserialize, Serialize}; diff --git a/packages/common/udb-util/src/keys.rs b/packages/common/udb-util/src/keys.rs index 397c727c28..c0406fe67f 100644 --- a/packages/common/udb-util/src/keys.rs +++ b/packages/common/udb-util/src/keys.rs @@ -59,7 +59,7 @@ define_keys! { (31, DBS, "dbs"), (32, ACTOR, "actor"), (33, BY_NAME, "by_name"), - (34, DATACENTER, "datacenter"), + // 34 (35, REMAINING_MEMORY, "remaining_memory"), (36, REMAINING_CPU, "remaining_cpu"), (37, TOTAL_MEMORY, "total_memory"), @@ -119,4 +119,6 @@ define_keys!
{ (91, METRIC, "metric"), (92, CURRENT_BALLOT, "current_ballot"), (93, INSTANCE_BALLOT, "instance_ballot"), + (94, OUTBOUND, "outbound"), + (95, DESIRED_SLOTS, "desired_slots"), } diff --git a/packages/core/api-peer/src/namespaces.rs b/packages/core/api-peer/src/namespaces.rs index 196764252f..36b0a3c9fc 100644 --- a/packages/core/api-peer/src/namespaces.rs +++ b/packages/core/api-peer/src/namespaces.rs @@ -73,6 +73,7 @@ pub struct ListQuery { pub limit: Option, pub cursor: Option, pub name: Option, + pub namespace_id: Vec, } #[derive(Serialize, Deserialize, ToSchema)] @@ -85,7 +86,7 @@ pub struct ListResponse { #[utoipa::path( get, - operation_id = "actors_list", + operation_id = "namespaces_list", path = "/namespaces", params(ListQuery), responses( @@ -105,6 +106,17 @@ pub async fn list(ctx: ApiCtx, _path: (), query: ListQuery) -> Result, + shutdown_tx: oneshot::Sender<()>, + draining: Arc, +} + +#[tracing::instrument(skip_all)] +pub async fn start(config: rivet_config::Config, pools: rivet_pools::Pools) -> Result<()> { + let cache = rivet_cache::CacheInner::from_env(&config, pools.clone())?; + let ctx = StandaloneCtx::new( + db::DatabaseKv::from_pools(pools.clone()).await?, + config.clone(), + pools, + cache, + "pegboard-outbound", + Id::new_v1(config.dc_label()), + Id::new_v1(config.dc_label()), + )?; + + let mut sub = ctx + .subscribe::(()) + .await?; + let mut outbound_connections = HashMap::new(); + + loop { + tick(&ctx, &mut outbound_connections).await?; + + sub.next().await?; + } +} + +async fn tick( + ctx: &StandaloneCtx, + outbound_connections: &mut HashMap<(Id, String), Vec>, +) -> Result<()> { + let outbound_data = ctx + .udb()? + .run(|tx, _mc| async move { + let txs = tx.subspace(keys::subspace()); + let outbound_desired_subspace = + txs.subspace(&keys::ns::OutboundDesiredSlotsKey::subspace()); + + txs.get_ranges_keyvalues( + udb::RangeOption { + mode: StreamingMode::WantAll, + ..(&outbound_desired_subspace).into() + }, + // NOTE: This is a snapshot to prevent conflict with updates to this subspace + SNAPSHOT, + ) + .map(|res| match res { + Ok(entry) => { + let (key, desired_slots) = + txs.read_entry::(&entry)?; + + Ok((key.namespace_id, key.runner_name_selector, desired_slots)) + } + Err(err) => Err(err.into()), + }) + .try_collect::>() + .await + + // outbound/{ns_id}/{runner_name_selector}/desired_slots + }) + .await?; + + let mut namespace_ids = outbound_data + .iter() + .map(|(ns_id, _, _)| *ns_id) + .collect::>(); + namespace_ids.dedup(); + + let namespaces = ctx + .op(namespace::ops::get_global::Input { namespace_ids }) + .await?; + + for (ns_id, runner_name_selector, desired_slots) in &outbound_data { + let namespace = namespaces + .iter() + .find(|ns| ns.namespace_id == *ns_id) + .context("ns not found")?; + + let RunnerKind::Outbound { + url, + slots_per_runner, + min_runners, + max_runners, + runners_margin, + } = &namespace.runner_kind + else { + tracing::warn!( + ?ns_id, + "this namespace should not be in the outbound subspace (wrong runner kind)" + ); + continue; + }; + + let curr = outbound_connections + .entry((*ns_id, runner_name_selector.clone())) + .or_insert_with(Vec::new); + + // Remove finished and draining connections from list + curr.retain(|conn| !conn.handle.is_finished() && !conn.draining.load(Ordering::SeqCst)); + + let desired_count = (desired_slots + .div_ceil(*slots_per_runner) + .max(*min_runners) + .min(*max_runners) + + runners_margin) + .try_into()?; + + // Calculate diff + let drain_count = curr.len().saturating_sub(desired_count); + let 
start_count = desired_count.saturating_sub(curr.len()); + + if drain_count != 0 { + // TODO: Implement smart logic of draining runners with the lowest allocated actors + let draining_connections = curr.split_off(desired_count); + + for conn in draining_connections { + if conn.shutdown_tx.send(()).is_err() { + tracing::warn!( + "outbound connection shutdown channel dropped, likely already stopped" + ); + } + } + } + + let starting_connections = + std::iter::repeat_with(|| spawn_connection(ctx.clone(), url.clone())).take(start_count); + curr.extend(starting_connections); + } + + // Remove entries that aren't returned from udb + outbound_connections.retain(|(ns_id, runner_name_selector), _| { + outbound_data + .iter() + .any(|(ns_id2, runner_name_selector2, _)| { + ns_id == ns_id2 && runner_name_selector == runner_name_selector2 + }) + }); + + Ok(()) +} + +fn spawn_connection(ctx: StandaloneCtx, url: String) -> OutboundConnection { + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + let draining = Arc::new(AtomicBool::new(false)); + + let draining2 = draining.clone(); + let handle = tokio::spawn(async move { + if let Err(err) = outbound_handler(&ctx, url, shutdown_rx, draining2).await { + tracing::error!(?err, "outbound req failed"); + + // TODO: Add backoff + tokio::time::sleep(Duration::from_secs(1)).await; + + // On error, bump the autoscaler loop again + let _ = ctx + .msg(pegboard::messages::BumpOutboundAutoscaler {}) + .send() + .await; + } + }); + + OutboundConnection { + handle, + shutdown_tx, + draining, + } +} + +async fn outbound_handler( + ctx: &StandaloneCtx, + url: String, + shutdown_rx: oneshot::Receiver<()>, + draining: Arc, +) -> Result<()> { + let client = rivet_pools::reqwest::client_no_timeout().await?; + let mut es = sse::EventSource::new(client.get(url))?; + let mut runner_id = None; + + let stream_handler = async { + while let Some(event) = es.next().await { + match event { + Ok(sse::Event::Open) => {} + Ok(sse::Event::Message(msg)) => { + tracing::debug!(%msg.data, "received outbound req message"); + + if runner_id.is_none() { + runner_id = Some(Id::parse(&msg.data)?); + } + } + Err(sse::Error::StreamEnded) => { + tracing::debug!("outbound req stopped early"); + + return Ok(()); + } + Err(err) => return Err(err.into()), + } + } + + anyhow::Ok(()) + }; + + tokio::select! 
{ + res = stream_handler => return res.map_err(Into::into), + _ = tokio::time::sleep(OUTBOUND_REQUEST_LIFESPAN) => {} + _ = shutdown_rx => {} + } + + draining.store(true, Ordering::SeqCst); + + ctx.msg(pegboard::messages::BumpOutboundAutoscaler {}) + .send() + .await?; + + if let Some(runner_id) = runner_id { + stop_runner(ctx, runner_id).await?; + } + + // Continue waiting on req while draining + while let Some(event) = es.next().await { + match event { + Ok(sse::Event::Open) => {} + Ok(sse::Event::Message(msg)) => { + tracing::debug!(%msg.data, "received outbound req message"); + + // If runner_id is none at this point it means we did not send the stopping signal yet, so + // send it now + if runner_id.is_none() { + stop_runner(ctx, Id::parse(&msg.data)?).await?; + } + } + Err(sse::Error::StreamEnded) => break, + Err(err) => return Err(err.into()), + } + } + + tracing::info!("outbound req stopped"); + + Ok(()) +} + +async fn stop_runner(ctx: &StandaloneCtx, runner_id: Id) -> Result<()> { + let res = ctx + .signal(protocol::ToServer::Stopping) + .to_workflow::() + .tag("runner_id", runner_id) + .send() + .await; + + if let Some(WorkflowError::WorkflowNotFound) = res + .as_ref() + .err() + .and_then(|x| x.chain().find_map(|x| x.downcast_ref::())) + { + tracing::warn!( + ?runner_id, + "runner workflow not found, likely already stopped" + ); + } else { + res?; + } + + Ok(()) +} diff --git a/packages/infra/engine/Cargo.toml b/packages/infra/engine/Cargo.toml index 2556c213f3..975ea3d788 100644 --- a/packages/infra/engine/Cargo.toml +++ b/packages/infra/engine/Cargo.toml @@ -11,16 +11,15 @@ path = "src/main.rs" [dependencies] anyhow.workspace = true -gas.workspace = true chrono.workspace = true clap.workspace = true colored_json.workspace = true -udb-util.workspace = true -universaldb.workspace = true futures-util.workspace = true +gas.workspace = true hex.workspace = true include_dir.workspace = true lz4_flex.workspace = true +pegboard-outbound.workspace = true pegboard-runner-ws.workspace = true reqwest.workspace = true rivet-api-peer.workspace = true @@ -38,15 +37,17 @@ rivet-term.workspace = true rivet-util.workspace = true rivet-workflow-worker.workspace = true rustyline.workspace = true -serde.workspace = true serde_json.workspace = true serde_yaml.workspace = true +serde.workspace = true strum.workspace = true tabled.workspace = true tempfile.workspace = true thiserror.workspace = true tokio.workspace = true tracing.workspace = true +udb-util.workspace = true +universaldb.workspace = true url.workspace = true uuid.workspace = true diff --git a/packages/infra/engine/src/run_config.rs b/packages/infra/engine/src/run_config.rs index ee683efc15..890b12842b 100644 --- a/packages/infra/engine/src/run_config.rs +++ b/packages/infra/engine/src/run_config.rs @@ -25,6 +25,11 @@ pub fn config(_rivet_config: rivet_config::Config) -> Result { Service::new("bootstrap", ServiceKind::Oneshot, |config, pools| { Box::pin(rivet_bootstrap::start(config, pools)) }), + Service::new( + "pegboard_outbound", + ServiceKind::Standalone, + |config, pools| Box::pin(pegboard_outbound::start(config, pools)), + ), ]; Ok(RunConfigData { services }) diff --git a/packages/services/namespace/Cargo.toml b/packages/services/namespace/Cargo.toml index fef88ecd1c..32d46a6735 100644 --- a/packages/services/namespace/Cargo.toml +++ b/packages/services/namespace/Cargo.toml @@ -8,12 +8,14 @@ edition.workspace = true [dependencies] anyhow.workspace = true gas.workspace = true -udb-util.workspace = true -universaldb.workspace = 
true rivet-api-builder.workspace = true rivet-api-util.workspace = true +rivet-data.workspace = true rivet-error.workspace = true rivet-util.workspace = true serde.workspace = true tracing.workspace = true +udb-util.workspace = true +universaldb.workspace = true utoipa.workspace = true +versioned-data-util.workspace = true diff --git a/packages/services/namespace/src/keys.rs b/packages/services/namespace/src/keys.rs index 3803feb6f6..c4e13afde4 100644 --- a/packages/services/namespace/src/keys.rs +++ b/packages/services/namespace/src/keys.rs @@ -3,6 +3,7 @@ use std::result::Result::Ok; use anyhow::*; use gas::prelude::*; use udb_util::prelude::*; +use versioned_data_util::OwnedVersionedData; pub fn subspace() -> udb_util::Subspace { udb_util::Subspace::new(&(RIVET, NAMESPACE)) @@ -144,6 +145,55 @@ impl<'de> TupleUnpack<'de> for CreateTsKey { } } +#[derive(Debug)] +pub struct RunnerKindKey { + namespace_id: Id, +} + +impl RunnerKindKey { + pub fn new(namespace_id: Id) -> Self { + RunnerKindKey { namespace_id } + } +} + +impl FormalKey for RunnerKindKey { + type Value = crate::types::RunnerKind; + + fn deserialize(&self, raw: &[u8]) -> Result { + Ok( + rivet_data::versioned::NamespaceRunnerKind::deserialize_with_embedded_version(raw)? + .into(), + ) + } + + fn serialize(&self, value: Self::Value) -> Result> { + rivet_data::versioned::NamespaceRunnerKind::latest(value.into()) + .serialize_with_embedded_version( + rivet_data::PEGBOARD_NAMESPACE_RUNNER_ALLOC_IDX_VERSION, + ) + } +} + +impl TuplePack for RunnerKindKey { + fn pack( + &self, + w: &mut W, + tuple_depth: TupleDepth, + ) -> std::io::Result { + let t = (DATA, self.namespace_id, CREATE_TS); + t.pack(w, tuple_depth) + } +} + +impl<'de> TupleUnpack<'de> for RunnerKindKey { + fn unpack(input: &[u8], tuple_depth: TupleDepth) -> PackResult<(&[u8], Self)> { + let (input, (_, namespace_id, _)) = <(usize, Id, usize)>::unpack(input, tuple_depth)?; + let v = RunnerKindKey { namespace_id }; + + Ok((input, v)) + } +} + #[derive(Debug)] pub struct ByNameKey { name: String, diff --git a/packages/services/namespace/src/ops/get_global.rs b/packages/services/namespace/src/ops/get_global.rs index 5dc5a58b08..a62eeda288 100644 --- a/packages/services/namespace/src/ops/get_global.rs +++ b/packages/services/namespace/src/ops/get_global.rs @@ -4,20 +4,19 @@ use crate::types::Namespace; #[derive(Debug)] pub struct Input { - // TODO: Accept vec - pub namespace_id: Id, + pub namespace_ids: Vec, } #[operation] -pub async fn namespace_get_global(ctx: &OperationCtx, input: &Input) -> Result> { +pub async fn namespace_get_global(ctx: &OperationCtx, input: &Input) -> Result> { if ctx.config().is_leader() { let namespaces_res = ctx .op(crate::ops::get_local::Input { - namespace_ids: vec![input.namespace_id], + namespace_ids: input.namespace_ids.clone(), }) .await?; - Ok(namespaces_res.namespaces.into_iter().next()) + Ok(namespaces_res.namespaces) } else { let leader_dc = ctx.config().leader_dc()?; let client = rivet_pools::reqwest::client().await?; @@ -25,51 +24,42 @@ pub async fn namespace_get_global(ctx: &OperationCtx, input: &Input) -> Result>(), + ) + .send() + .await?; - let res = rivet_api_util::parse_response::(res).await; + let res = rivet_api_util::parse_response::(res).await?; - let res = match res { - Ok(res) => Ok(Some(res.namespace)), - Err(err) => { - // Explicitly handle namespace not found error - if let Some(error) = err.chain().find_map(|x| { - x.downcast_ref::() - }) { - if error.1.group == "namespace" && error.1.code == "not_found" { - Ok(None) - 
} else { - Err(err) - } - } - } - }; - - cache.resolve(&key, res?); + for ns in res.namespaces { + let namespace_id = ns.namespace_id; + cache.resolve(&namespace_id, ns); + } Ok(cache) } } }) .await - .map(|x| x.flatten()) } } // TODO: Cyclical dependency with api_peer #[derive(Deserialize)] -struct GetResponse { - namespace: Namespace, +struct ListResponse { + namespaces: Vec<Namespace>, } diff --git a/packages/services/namespace/src/ops/get_local.rs b/packages/services/namespace/src/ops/get_local.rs index 913579092f..ed6663d589 100644 --- a/packages/services/namespace/src/ops/get_local.rs +++ b/packages/services/namespace/src/ops/get_local.rs @@ -1,6 +1,6 @@ use futures_util::{StreamExt, TryStreamExt}; use gas::prelude::*; -use udb_util::{FormalKey, SERIALIZABLE}; +use udb_util::{SERIALIZABLE, TxnExt}; use universaldb as udb; use crate::{errors, keys, types::Namespace}; @@ -45,39 +45,40 @@ pub(crate) async fn get_inner( namespace_id: Id, tx: &udb::RetryableTransaction, ) -> std::result::Result<Option<Namespace>, udb::FdbBindingError> { + let txs = tx.subspace(keys::subspace()); + let name_key = keys::NameKey::new(namespace_id); let display_name_key = keys::DisplayNameKey::new(namespace_id); let create_ts_key = keys::CreateTsKey::new(namespace_id); + let runner_kind_key = keys::RunnerKindKey::new(namespace_id); - let (name_entry, display_name_entry, create_ts_entry) = tokio::try_join!( - tx.get(&keys::subspace().pack(&name_key), SERIALIZABLE), - tx.get(&keys::subspace().pack(&display_name_key), SERIALIZABLE), - tx.get(&keys::subspace().pack(&create_ts_key), SERIALIZABLE), + let (name, display_name, create_ts, runner_kind) = tokio::try_join!( + txs.read_opt(&name_key, SERIALIZABLE), + txs.read_opt(&display_name_key, SERIALIZABLE), + txs.read_opt(&create_ts_key, SERIALIZABLE), + txs.read_opt(&runner_kind_key, SERIALIZABLE), )?; // Namespace not found - let Some(name_entry) = name_entry else { + let Some(name) = name else { return Ok(None); }; - let name = name_key - .deserialize(&name_entry) - .map_err(|x| udb::FdbBindingError::CustomError(x.into()))?; - let display_name = display_name_key - .deserialize(&display_name_entry.ok_or(udb::FdbBindingError::CustomError( - format!("key should exist: {display_name_key:?}").into(), - ))?) - .map_err(|x| udb::FdbBindingError::CustomError(x.into()))?; - let create_ts = create_ts_key - .deserialize(&create_ts_entry.ok_or(udb::FdbBindingError::CustomError( - format!("key should exist: {create_ts_key:?}").into(), - ))?)
- .map_err(|x| udb::FdbBindingError::CustomError(x.into()))?; + let display_name = display_name.ok_or(udb::FdbBindingError::CustomError( + format!("key should exist: {display_name_key:?}").into(), + ))?; + let create_ts = create_ts.ok_or(udb::FdbBindingError::CustomError( + format!("key should exist: {create_ts_key:?}").into(), + ))?; + let runner_kind = runner_kind.ok_or(udb::FdbBindingError::CustomError( + format!("key should exist: {runner_kind_key:?}").into(), + ))?; Ok(Some(Namespace { namespace_id, name, display_name, create_ts, + runner_kind, })) } diff --git a/packages/services/namespace/src/types.rs b/packages/services/namespace/src/types.rs index c7e6f71d59..05e924ad34 100644 --- a/packages/services/namespace/src/types.rs +++ b/packages/services/namespace/src/types.rs @@ -7,4 +7,58 @@ pub struct Namespace { pub name: String, pub display_name: String, pub create_ts: i64, + pub runner_kind: RunnerKind, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Hash, ToSchema)] +#[serde(rename_all = "snake_case")] +pub enum RunnerKind { + Outbound { + url: String, + slots_per_runner: u32, + min_runners: u32, + max_runners: u32, + runners_margin: u32, + }, + Custom, +} + +impl From for rivet_data::generated::namespace_runner_kind_v1::Data { + fn from(value: RunnerKind) -> Self { + match value { + RunnerKind::Outbound { + url, + slots_per_runner, + min_runners, + max_runners, + runners_margin, + } => rivet_data::generated::namespace_runner_kind_v1::Data::Outbound( + rivet_data::generated::namespace_runner_kind_v1::Outbound { + url, + slots_per_runner, + min_runners, + max_runners, + runners_margin, + }, + ), + RunnerKind::Custom => rivet_data::generated::namespace_runner_kind_v1::Data::Custom, + } + } +} + +impl From for RunnerKind { + fn from(value: rivet_data::generated::namespace_runner_kind_v1::Data) -> Self { + match value { + rivet_data::generated::namespace_runner_kind_v1::Data::Outbound(o) => { + RunnerKind::Outbound { + url: o.url, + slots_per_runner: o.slots_per_runner, + min_runners: o.min_runners, + max_runners: o.max_runners, + runners_margin: o.runners_margin, + } + } + rivet_data::generated::namespace_runner_kind_v1::Data::Custom => RunnerKind::Custom, + } + } } diff --git a/packages/services/namespace/src/workflows/namespace.rs b/packages/services/namespace/src/workflows/namespace.rs index 16575347fe..90078b23e6 100644 --- a/packages/services/namespace/src/workflows/namespace.rs +++ b/packages/services/namespace/src/workflows/namespace.rs @@ -1,10 +1,9 @@ use futures_util::FutureExt; use gas::prelude::*; use serde::{Deserialize, Serialize}; -use udb_util::{FormalKey, SERIALIZABLE}; -use universaldb as udb; +use udb_util::{SERIALIZABLE, TxnExt}; -use crate::{errors, keys}; +use crate::{errors, keys, types::RunnerKind}; #[derive(Debug, Deserialize, Serialize)] pub struct Input { @@ -59,7 +58,7 @@ pub async fn namespace(ctx: &mut WorkflowCtx, input: &Input) -> Result<()> { // Does nothing yet ctx.repeat(|ctx| { async move { - ctx.listen::().await?; + ctx.listen::().await?; Ok(Loop::<()>::Continue) } @@ -79,7 +78,7 @@ pub struct Failed { } #[signal("namespace_update")] -pub struct NamespaceUpdate {} +pub struct Update {} #[derive(Debug, Clone, Serialize, Deserialize, Hash)] pub struct ValidateInput { @@ -156,45 +155,29 @@ async fn insert_fdb( let display_name = input.display_name.clone(); async move { - let name_key = keys::NameKey::new(namespace_id); - let name_idx_key = keys::ByNameKey::new(name.clone()); - let display_name_key = keys::DisplayNameKey::new(namespace_id); - let 
create_ts_key = keys::CreateTsKey::new(namespace_id); + let txs = tx.subspace(keys::subspace()); - let name_idx_entry = tx - .get(&keys::subspace().pack(&name_idx_key), SERIALIZABLE) - .await?; + let name_idx_key = keys::ByNameKey::new(name.clone()); - if name_idx_entry.is_some() { + if txs.exists(&name_idx_key, SERIALIZABLE).await? { return Ok(Err(errors::Namespace::NameNotUnique)); } - tx.set( - &keys::subspace().pack(&name_key), - &name_key - .serialize(name) - .map_err(|x| udb::FdbBindingError::CustomError(x.into()))?, - ); - tx.set( - &keys::subspace().pack(&display_name_key), - &display_name_key - .serialize(display_name) - .map_err(|x| udb::FdbBindingError::CustomError(x.into()))?, - ); - tx.set( - &keys::subspace().pack(&create_ts_key), - &create_ts_key - .serialize(input.create_ts) - .map_err(|x| udb::FdbBindingError::CustomError(x.into()))?, - ); + txs.write(&keys::NameKey::new(namespace_id), name)?; + txs.write(&keys::DisplayNameKey::new(namespace_id), display_name)?; + txs.write(&keys::CreateTsKey::new(namespace_id), input.create_ts)?; + txs.write(&keys::RunnerKindKey::new(namespace_id), RunnerKind::Custom)?; + + // RunnerKind::Outbound { + // url: "http://runner:5051/start".to_string(), + // slots_per_runner: 10, + // min_runners: 1, + // max_runners: 1, + // runners_margin: 0, + // } // Insert idx - tx.set( - &keys::subspace().pack(&name_idx_key), - &name_idx_key - .serialize(namespace_id) - .map_err(|x| udb::FdbBindingError::CustomError(x.into()))?, - ); + txs.write(&name_idx_key, namespace_id)?; Ok(Ok(())) } diff --git a/packages/services/pegboard/Cargo.toml b/packages/services/pegboard/Cargo.toml index 02e85d00e6..945f36bdfa 100644 --- a/packages/services/pegboard/Cargo.toml +++ b/packages/services/pegboard/Cargo.toml @@ -16,7 +16,7 @@ rivet-api-client.workspace = true rivet-api-types.workspace = true rivet-api-util.workspace = true rivet-error.workspace = true -rivet-key-data.workspace = true +rivet-data.workspace = true rivet-metrics.workspace = true rivet-runner-protocol.workspace = true rivet-types.workspace = true diff --git a/packages/services/pegboard/src/keys/datacenter.rs b/packages/services/pegboard/src/keys/datacenter.rs deleted file mode 100644 index bd89e1db22..0000000000 --- a/packages/services/pegboard/src/keys/datacenter.rs +++ /dev/null @@ -1,249 +0,0 @@ -use std::result::Result::Ok; - -use anyhow::*; -use gas::prelude::*; -use udb_util::prelude::*; -use versioned_data_util::OwnedVersionedData; - -#[derive(Debug)] -pub struct RunnerAllocIdxKey { - pub namespace_id: Id, - pub name: String, - pub version: u32, - pub remaining_millislots: u32, - pub last_ping_ts: i64, - pub runner_id: Id, -} - -impl RunnerAllocIdxKey { - pub fn new( - namespace_id: Id, - name: String, - version: u32, - remaining_millislots: u32, - last_ping_ts: i64, - runner_id: Id, - ) -> Self { - RunnerAllocIdxKey { - namespace_id, - name, - version, - remaining_millislots, - last_ping_ts, - runner_id, - } - } - - pub fn subspace(namespace_id: Id, name: String) -> RunnerAllocIdxSubspaceKey { - RunnerAllocIdxSubspaceKey::new(namespace_id, name) - } - - pub fn entire_subspace() -> RunnerAllocIdxSubspaceKey { - RunnerAllocIdxSubspaceKey::entire() - } -} - -impl FormalKey for RunnerAllocIdxKey { - type Value = rivet_key_data::converted::RunnerAllocIdxKeyData; - - fn deserialize(&self, raw: &[u8]) -> Result { - rivet_key_data::versioned::RunnerAllocIdxKeyData::deserialize_with_embedded_version(raw)? 
- .try_into() - } - - fn serialize(&self, value: Self::Value) -> Result> { - rivet_key_data::versioned::RunnerAllocIdxKeyData::latest(value.try_into()?) - .serialize_with_embedded_version( - rivet_key_data::PEGBOARD_DATACENTER_RUNNER_ALLOC_IDX_VERSION, - ) - } -} - -impl TuplePack for RunnerAllocIdxKey { - fn pack( - &self, - w: &mut W, - tuple_depth: TupleDepth, - ) -> std::io::Result { - let t = ( - DATACENTER, - RUNNER_ALLOC_IDX, - self.namespace_id, - &self.name, - // Stored in reverse order (higher versions are first) - -(self.version as i32), - // Stored in reverse order (higher remaining slots are first) - -(self.remaining_millislots as i32), - self.last_ping_ts, - self.runner_id, - ); - t.pack(w, tuple_depth) - } -} - -impl<'de> TupleUnpack<'de> for RunnerAllocIdxKey { - fn unpack(input: &[u8], tuple_depth: TupleDepth) -> PackResult<(&[u8], Self)> { - let ( - input, - (_, _, namespace_id, name, version, remaining_millislots, last_ping_ts, runner_id), - ) = <(usize, usize, Id, String, i32, i32, i64, Id)>::unpack(input, tuple_depth)?; - - let v = RunnerAllocIdxKey { - namespace_id, - name, - version: -version as u32, - remaining_millislots: -remaining_millislots as u32, - last_ping_ts, - runner_id, - }; - - Ok((input, v)) - } -} - -pub struct RunnerAllocIdxSubspaceKey { - pub namespace_id: Option, - pub name: Option, -} - -impl RunnerAllocIdxSubspaceKey { - pub fn new(namespace_id: Id, name: String) -> Self { - RunnerAllocIdxSubspaceKey { - namespace_id: Some(namespace_id), - name: Some(name), - } - } - - pub fn entire() -> Self { - RunnerAllocIdxSubspaceKey { - namespace_id: None, - name: None, - } - } -} - -impl TuplePack for RunnerAllocIdxSubspaceKey { - fn pack( - &self, - w: &mut W, - tuple_depth: TupleDepth, - ) -> std::io::Result { - let mut offset = VersionstampOffset::None { size: 0 }; - - let t = (DATACENTER, RUNNER_ALLOC_IDX); - offset += t.pack(w, tuple_depth)?; - - if let Some(namespace_id) = &self.namespace_id { - offset += namespace_id.pack(w, tuple_depth)?; - - if let Some(name) = &self.name { - offset += name.pack(w, tuple_depth)?; - } - } - - Ok(offset) - } -} - -#[derive(Debug)] -pub struct PendingActorByRunnerNameSelectorKey { - pub namespace_id: Id, - pub runner_name_selector: String, - pub ts: i64, - pub actor_id: Id, -} - -impl PendingActorByRunnerNameSelectorKey { - pub fn new(namespace_id: Id, runner_name_selector: String, ts: i64, actor_id: Id) -> Self { - PendingActorByRunnerNameSelectorKey { - namespace_id, - runner_name_selector, - ts, - actor_id, - } - } - - pub fn subspace( - namespace_id: Id, - runner_name_selector: String, - ) -> PendingActorByRunnerNameSelectorSubspaceKey { - PendingActorByRunnerNameSelectorSubspaceKey::new(namespace_id, runner_name_selector) - } -} - -impl FormalKey for PendingActorByRunnerNameSelectorKey { - /// Generation. 
- type Value = u32; - - fn deserialize(&self, raw: &[u8]) -> Result { - Ok(u32::from_be_bytes(raw.try_into()?)) - } - - fn serialize(&self, value: Self::Value) -> Result> { - Ok(value.to_be_bytes().to_vec()) - } -} - -impl TuplePack for PendingActorByRunnerNameSelectorKey { - fn pack( - &self, - w: &mut W, - tuple_depth: TupleDepth, - ) -> std::io::Result { - let t = ( - DATACENTER, - PENDING_ACTOR_BY_RUNNER_NAME_SELECTOR, - self.namespace_id, - &self.runner_name_selector, - self.ts, - self.actor_id, - ); - t.pack(w, tuple_depth) - } -} - -impl<'de> TupleUnpack<'de> for PendingActorByRunnerNameSelectorKey { - fn unpack(input: &[u8], tuple_depth: TupleDepth) -> PackResult<(&[u8], Self)> { - let (input, (_, _, namespace_id, runner_name_selector, ts, actor_id)) = - <(usize, usize, Id, String, i64, Id)>::unpack(input, tuple_depth)?; - - let v = PendingActorByRunnerNameSelectorKey { - namespace_id, - runner_name_selector, - ts, - actor_id, - }; - - Ok((input, v)) - } -} - -pub struct PendingActorByRunnerNameSelectorSubspaceKey { - pub namespace_id: Id, - pub runner_name_selector: String, -} - -impl PendingActorByRunnerNameSelectorSubspaceKey { - pub fn new(namespace_id: Id, runner_name_selector: String) -> Self { - PendingActorByRunnerNameSelectorSubspaceKey { - namespace_id, - runner_name_selector, - } - } -} - -impl TuplePack for PendingActorByRunnerNameSelectorSubspaceKey { - fn pack( - &self, - w: &mut W, - tuple_depth: TupleDepth, - ) -> std::io::Result { - let t = ( - DATACENTER, - PENDING_ACTOR_BY_RUNNER_NAME_SELECTOR, - self.namespace_id, - &self.runner_name_selector, - ); - t.pack(w, tuple_depth) - } -} diff --git a/packages/services/pegboard/src/keys/mod.rs b/packages/services/pegboard/src/keys/mod.rs index 9e93b8983a..402214f8a0 100644 --- a/packages/services/pegboard/src/keys/mod.rs +++ b/packages/services/pegboard/src/keys/mod.rs @@ -1,7 +1,6 @@ use udb_util::prelude::*; pub mod actor; -pub mod datacenter; pub mod epoxy; pub mod ns; pub mod runner; diff --git a/packages/services/pegboard/src/keys/ns.rs b/packages/services/pegboard/src/keys/ns.rs index dd1aea42a9..236c61bc9c 100644 --- a/packages/services/pegboard/src/keys/ns.rs +++ b/packages/services/pegboard/src/keys/ns.rs @@ -5,6 +5,249 @@ use gas::prelude::*; use udb_util::prelude::*; use versioned_data_util::OwnedVersionedData; +#[derive(Debug)] +pub struct RunnerAllocIdxKey { + pub namespace_id: Id, + pub name: String, + pub version: u32, + pub remaining_millislots: u32, + pub last_ping_ts: i64, + pub runner_id: Id, +} + +impl RunnerAllocIdxKey { + pub fn new( + namespace_id: Id, + name: String, + version: u32, + remaining_millislots: u32, + last_ping_ts: i64, + runner_id: Id, + ) -> Self { + RunnerAllocIdxKey { + namespace_id, + name, + version, + remaining_millislots, + last_ping_ts, + runner_id, + } + } + + pub fn subspace(namespace_id: Id, name: String) -> RunnerAllocIdxSubspaceKey { + RunnerAllocIdxSubspaceKey::new(namespace_id, name) + } + + pub fn entire_subspace() -> RunnerAllocIdxSubspaceKey { + RunnerAllocIdxSubspaceKey::entire() + } +} + +impl FormalKey for RunnerAllocIdxKey { + type Value = rivet_data::converted::RunnerAllocIdxKeyData; + + fn deserialize(&self, raw: &[u8]) -> Result { + rivet_data::versioned::RunnerAllocIdxKeyData::deserialize_with_embedded_version(raw)? + .try_into() + } + + fn serialize(&self, value: Self::Value) -> Result> { + rivet_data::versioned::RunnerAllocIdxKeyData::latest(value.try_into()?) 
+ .serialize_with_embedded_version( + rivet_data::PEGBOARD_NAMESPACE_RUNNER_ALLOC_IDX_VERSION, + ) + } +} + +impl TuplePack for RunnerAllocIdxKey { + fn pack( + &self, + w: &mut W, + tuple_depth: TupleDepth, + ) -> std::io::Result { + let t = ( + NAMESPACE, + RUNNER_ALLOC_IDX, + self.namespace_id, + &self.name, + // Stored in reverse order (higher versions are first) + -(self.version as i32), + // Stored in reverse order (higher remaining slots are first) + -(self.remaining_millislots as i32), + self.last_ping_ts, + self.runner_id, + ); + t.pack(w, tuple_depth) + } +} + +impl<'de> TupleUnpack<'de> for RunnerAllocIdxKey { + fn unpack(input: &[u8], tuple_depth: TupleDepth) -> PackResult<(&[u8], Self)> { + let ( + input, + (_, _, namespace_id, name, version, remaining_millislots, last_ping_ts, runner_id), + ) = <(usize, usize, Id, String, i32, i32, i64, Id)>::unpack(input, tuple_depth)?; + + let v = RunnerAllocIdxKey { + namespace_id, + name, + version: -version as u32, + remaining_millislots: -remaining_millislots as u32, + last_ping_ts, + runner_id, + }; + + Ok((input, v)) + } +} + +pub struct RunnerAllocIdxSubspaceKey { + pub namespace_id: Option, + pub name: Option, +} + +impl RunnerAllocIdxSubspaceKey { + pub fn new(namespace_id: Id, name: String) -> Self { + RunnerAllocIdxSubspaceKey { + namespace_id: Some(namespace_id), + name: Some(name), + } + } + + pub fn entire() -> Self { + RunnerAllocIdxSubspaceKey { + namespace_id: None, + name: None, + } + } +} + +impl TuplePack for RunnerAllocIdxSubspaceKey { + fn pack( + &self, + w: &mut W, + tuple_depth: TupleDepth, + ) -> std::io::Result { + let mut offset = VersionstampOffset::None { size: 0 }; + + let t = (NAMESPACE, RUNNER_ALLOC_IDX); + offset += t.pack(w, tuple_depth)?; + + if let Some(namespace_id) = &self.namespace_id { + offset += namespace_id.pack(w, tuple_depth)?; + + if let Some(name) = &self.name { + offset += name.pack(w, tuple_depth)?; + } + } + + Ok(offset) + } +} + +#[derive(Debug)] +pub struct PendingActorByRunnerNameSelectorKey { + pub namespace_id: Id, + pub runner_name_selector: String, + pub ts: i64, + pub actor_id: Id, +} + +impl PendingActorByRunnerNameSelectorKey { + pub fn new(namespace_id: Id, runner_name_selector: String, ts: i64, actor_id: Id) -> Self { + PendingActorByRunnerNameSelectorKey { + namespace_id, + runner_name_selector, + ts, + actor_id, + } + } + + pub fn subspace( + namespace_id: Id, + runner_name_selector: String, + ) -> PendingActorByRunnerNameSelectorSubspaceKey { + PendingActorByRunnerNameSelectorSubspaceKey::new(namespace_id, runner_name_selector) + } +} + +impl FormalKey for PendingActorByRunnerNameSelectorKey { + /// Generation. 
+ type Value = u32; + + fn deserialize(&self, raw: &[u8]) -> Result { + Ok(u32::from_be_bytes(raw.try_into()?)) + } + + fn serialize(&self, value: Self::Value) -> Result> { + Ok(value.to_be_bytes().to_vec()) + } +} + +impl TuplePack for PendingActorByRunnerNameSelectorKey { + fn pack( + &self, + w: &mut W, + tuple_depth: TupleDepth, + ) -> std::io::Result { + let t = ( + NAMESPACE, + PENDING_ACTOR_BY_RUNNER_NAME_SELECTOR, + self.namespace_id, + &self.runner_name_selector, + self.ts, + self.actor_id, + ); + t.pack(w, tuple_depth) + } +} + +impl<'de> TupleUnpack<'de> for PendingActorByRunnerNameSelectorKey { + fn unpack(input: &[u8], tuple_depth: TupleDepth) -> PackResult<(&[u8], Self)> { + let (input, (_, _, namespace_id, runner_name_selector, ts, actor_id)) = + <(usize, usize, Id, String, i64, Id)>::unpack(input, tuple_depth)?; + + let v = PendingActorByRunnerNameSelectorKey { + namespace_id, + runner_name_selector, + ts, + actor_id, + }; + + Ok((input, v)) + } +} + +pub struct PendingActorByRunnerNameSelectorSubspaceKey { + pub namespace_id: Id, + pub runner_name_selector: String, +} + +impl PendingActorByRunnerNameSelectorSubspaceKey { + pub fn new(namespace_id: Id, runner_name_selector: String) -> Self { + PendingActorByRunnerNameSelectorSubspaceKey { + namespace_id, + runner_name_selector, + } + } +} + +impl TuplePack for PendingActorByRunnerNameSelectorSubspaceKey { + fn pack( + &self, + w: &mut W, + tuple_depth: TupleDepth, + ) -> std::io::Result { + let t = ( + NAMESPACE, + PENDING_ACTOR_BY_RUNNER_NAME_SELECTOR, + self.namespace_id, + &self.runner_name_selector, + ); + t.pack(w, tuple_depth) + } +} + #[derive(Debug)] pub struct ActiveActorKey { namespace_id: Id, @@ -320,18 +563,15 @@ impl ActorByKeyKey { } impl FormalKey for ActorByKeyKey { - type Value = rivet_key_data::converted::ActorByKeyKeyData; + type Value = rivet_data::converted::ActorByKeyKeyData; fn deserialize(&self, raw: &[u8]) -> Result { - rivet_key_data::versioned::ActorByKeyKeyData::deserialize_with_embedded_version(raw)? - .try_into() + rivet_data::versioned::ActorByKeyKeyData::deserialize_with_embedded_version(raw)?.try_into() } fn serialize(&self, value: Self::Value) -> Result> { - rivet_key_data::versioned::ActorByKeyKeyData::latest(value.try_into()?) - .serialize_with_embedded_version( - rivet_key_data::PEGBOARD_NAMESPACE_ACTOR_BY_KEY_VERSION, - ) + rivet_data::versioned::ActorByKeyKeyData::latest(value.try_into()?) + .serialize_with_embedded_version(rivet_data::PEGBOARD_NAMESPACE_ACTOR_BY_KEY_VERSION) } } @@ -938,18 +1178,16 @@ impl RunnerByKeyKey { } impl FormalKey for RunnerByKeyKey { - type Value = rivet_key_data::converted::RunnerByKeyKeyData; + type Value = rivet_data::converted::RunnerByKeyKeyData; fn deserialize(&self, raw: &[u8]) -> Result { - rivet_key_data::versioned::RunnerByKeyKeyData::deserialize_with_embedded_version(raw)? + rivet_data::versioned::RunnerByKeyKeyData::deserialize_with_embedded_version(raw)? .try_into() } fn serialize(&self, value: Self::Value) -> Result> { - rivet_key_data::versioned::RunnerByKeyKeyData::latest(value.try_into()?) - .serialize_with_embedded_version( - rivet_key_data::PEGBOARD_NAMESPACE_RUNNER_BY_KEY_VERSION, - ) + rivet_data::versioned::RunnerByKeyKeyData::latest(value.try_into()?) 
+ .serialize_with_embedded_version(rivet_data::PEGBOARD_NAMESPACE_RUNNER_BY_KEY_VERSION) } } @@ -1002,16 +1240,15 @@ impl ActorNameKey { } impl FormalKey for ActorNameKey { - type Value = rivet_key_data::converted::ActorNameKeyData; + type Value = rivet_data::converted::ActorNameKeyData; fn deserialize(&self, raw: &[u8]) -> Result { - rivet_key_data::versioned::ActorNameKeyData::deserialize_with_embedded_version(raw)? - .try_into() + rivet_data::versioned::ActorNameKeyData::deserialize_with_embedded_version(raw)?.try_into() } fn serialize(&self, value: Self::Value) -> Result> { - rivet_key_data::versioned::ActorNameKeyData::latest(value.try_into()?) - .serialize_with_embedded_version(rivet_key_data::PEGBOARD_NAMESPACE_ACTOR_NAME_VERSION) + rivet_data::versioned::ActorNameKeyData::latest(value.try_into()?) + .serialize_with_embedded_version(rivet_data::PEGBOARD_NAMESPACE_ACTOR_NAME_VERSION) } } @@ -1128,3 +1365,87 @@ impl TuplePack for RunnerNameSubspaceKey { t.pack(w, tuple_depth) } } + +#[derive(Debug)] +pub struct OutboundDesiredSlotsKey { + pub namespace_id: Id, + pub runner_name_selector: String, +} + +impl OutboundDesiredSlotsKey { + pub fn new(namespace_id: Id, runner_name_selector: String) -> Self { + OutboundDesiredSlotsKey { + namespace_id, + runner_name_selector, + } + } + + pub fn subspace() -> OutboundDesiredSlotsSubspaceKey { + OutboundDesiredSlotsSubspaceKey::new() + } +} + +impl FormalKey for OutboundDesiredSlotsKey { + /// Count. + type Value = u32; + + fn deserialize(&self, raw: &[u8]) -> Result { + // NOTE: Atomic ops use little endian + Ok(u32::from_le_bytes(raw.try_into()?)) + } + + fn serialize(&self, value: Self::Value) -> Result> { + // NOTE: Atomic ops use little endian + Ok(value.to_le_bytes().to_vec()) + } +} + +impl TuplePack for OutboundDesiredSlotsKey { + fn pack( + &self, + w: &mut W, + tuple_depth: TupleDepth, + ) -> std::io::Result { + let t = ( + NAMESPACE, + OUTBOUND, + DESIRED_SLOTS, + self.namespace_id, + &self.runner_name_selector, + ); + t.pack(w, tuple_depth) + } +} + +impl<'de> TupleUnpack<'de> for OutboundDesiredSlotsKey { + fn unpack(input: &[u8], tuple_depth: TupleDepth) -> PackResult<(&[u8], Self)> { + let (input, (_, _, namespace_id, runner_name_selector)) = + <(usize, usize, Id, String)>::unpack(input, tuple_depth)?; + + let v = OutboundDesiredSlotsKey { + namespace_id, + runner_name_selector, + }; + + Ok((input, v)) + } +} + +pub struct OutboundDesiredSlotsSubspaceKey {} + +impl OutboundDesiredSlotsSubspaceKey { + pub fn new() -> Self { + OutboundDesiredSlotsSubspaceKey {} + } +} + +impl TuplePack for OutboundDesiredSlotsSubspaceKey { + fn pack( + &self, + w: &mut W, + tuple_depth: TupleDepth, + ) -> std::io::Result { + let t = (NAMESPACE, OUTBOUND, DESIRED_SLOTS); + t.pack(w, tuple_depth) + } +} diff --git a/packages/services/pegboard/src/keys/runner.rs b/packages/services/pegboard/src/keys/runner.rs index 9e83aac404..528ba2f6b4 100644 --- a/packages/services/pegboard/src/keys/runner.rs +++ b/packages/services/pegboard/src/keys/runner.rs @@ -524,15 +524,15 @@ impl AddressKey { } impl FormalKey for AddressKey { - type Value = ::Latest; + type Value = ::Latest; fn deserialize(&self, raw: &[u8]) -> Result { - rivet_key_data::versioned::AddressKeyData::deserialize_with_embedded_version(raw) + rivet_data::versioned::AddressKeyData::deserialize_with_embedded_version(raw) } fn serialize(&self, value: Self::Value) -> Result> { - rivet_key_data::versioned::AddressKeyData::latest(value) - 
.serialize_with_embedded_version(rivet_key_data::PEGBOARD_RUNNER_ADDRESS_VERSION) + rivet_data::versioned::AddressKeyData::latest(value) + .serialize_with_embedded_version(rivet_data::PEGBOARD_RUNNER_ADDRESS_VERSION) } } @@ -816,7 +816,7 @@ impl MetadataKey { impl FormalChunkedKey for MetadataKey { type ChunkKey = MetadataChunkKey; - type Value = rivet_key_data::converted::MetadataKeyData; + type Value = rivet_data::converted::MetadataKeyData; fn chunk(&self, chunk: usize) -> Self::ChunkKey { MetadataChunkKey { @@ -826,7 +826,7 @@ impl FormalChunkedKey for MetadataKey { } fn combine(&self, chunks: Vec) -> Result { - rivet_key_data::versioned::MetadataKeyData::deserialize_with_embedded_version( + rivet_data::versioned::MetadataKeyData::deserialize_with_embedded_version( &chunks .iter() .map(|x| x.value().iter().map(|x| *x)) @@ -838,8 +838,8 @@ impl FormalChunkedKey for MetadataKey { fn split(&self, value: Self::Value) -> Result>> { Ok( - rivet_key_data::versioned::MetadataKeyData::latest(value.try_into()?) - .serialize_with_embedded_version(rivet_key_data::PEGBOARD_RUNNER_METADATA_VERSION)? + rivet_data::versioned::MetadataKeyData::latest(value.try_into()?) + .serialize_with_embedded_version(rivet_data::PEGBOARD_RUNNER_METADATA_VERSION)? .chunks(udb_util::CHUNK_SIZE) .map(|x| x.to_vec()) .collect(), diff --git a/packages/services/pegboard/src/lib.rs b/packages/services/pegboard/src/lib.rs index 8a08a5b9a9..b5dd33dd0a 100644 --- a/packages/services/pegboard/src/lib.rs +++ b/packages/services/pegboard/src/lib.rs @@ -2,6 +2,7 @@ use gas::prelude::*; pub mod errors; pub mod keys; +pub mod messages; mod metrics; pub mod ops; pub mod pubsub_subjects; diff --git a/packages/services/pegboard/src/messages.rs b/packages/services/pegboard/src/messages.rs new file mode 100644 index 0000000000..e3ad78680d --- /dev/null +++ b/packages/services/pegboard/src/messages.rs @@ -0,0 +1,4 @@ +use gas::prelude::*; + +#[message("pegboard_bump_outbound_autoscaler")] +pub struct BumpOutboundAutoscaler {} diff --git a/packages/services/pegboard/src/ops/actor/create.rs b/packages/services/pegboard/src/ops/actor/create.rs index 5dd72c7c75..66a7a1b42c 100644 --- a/packages/services/pegboard/src/ops/actor/create.rs +++ b/packages/services/pegboard/src/ops/actor/create.rs @@ -125,8 +125,12 @@ async fn forward_to_datacenter( // Get namespace name for the remote call let namespace = ctx - .op(namespace::ops::get_global::Input { namespace_id }) + .op(namespace::ops::get_global::Input { + namespace_ids: vec![namespace_id], + }) .await? + .into_iter() + .next() .ok_or_else(|| namespace::errors::Namespace::NotFound.build())?; // Generate a new actor ID with the correct datacenter label diff --git a/packages/services/pegboard/src/ops/actor/get_for_key.rs b/packages/services/pegboard/src/ops/actor/get_for_key.rs index d6960e5689..f850f88c69 100644 --- a/packages/services/pegboard/src/ops/actor/get_for_key.rs +++ b/packages/services/pegboard/src/ops/actor/get_for_key.rs @@ -60,9 +60,11 @@ pub async fn pegboard_actor_get_for_key(ctx: &OperationCtx, input: &Input) -> Re // Get namespace name for the remote call let namespace = ctx .op(namespace::ops::get_global::Input { - namespace_id: input.namespace_id, + namespace_ids: vec![input.namespace_id], }) .await? 
+ .into_iter() + .next() .ok_or_else(|| namespace::errors::Namespace::NotFound.build())?; // Make request to remote datacenter diff --git a/packages/services/pegboard/src/ops/actor/list_names.rs b/packages/services/pegboard/src/ops/actor/list_names.rs index 727a0bed6e..8fe72d2100 100644 --- a/packages/services/pegboard/src/ops/actor/list_names.rs +++ b/packages/services/pegboard/src/ops/actor/list_names.rs @@ -1,6 +1,6 @@ use futures_util::{StreamExt, TryStreamExt}; use gas::prelude::*; -use rivet_key_data::converted::ActorNameKeyData; +use rivet_data::converted::ActorNameKeyData; use udb_util::{SNAPSHOT, TxnExt}; use universaldb::{self as udb, options::StreamingMode}; diff --git a/packages/services/pegboard/src/ops/runner/get.rs b/packages/services/pegboard/src/ops/runner/get.rs index d967eccee3..22cbbcc6e4 100644 --- a/packages/services/pegboard/src/ops/runner/get.rs +++ b/packages/services/pegboard/src/ops/runner/get.rs @@ -1,7 +1,7 @@ use anyhow::Result; use futures_util::TryStreamExt; use gas::prelude::*; -use rivet_key_data::generated::pegboard_runner_address_v1::Data as AddressKeyData; +use rivet_data::generated::pegboard_runner_address_v1::Data as AddressKeyData; use rivet_types::runners::Runner; use udb_util::{FormalChunkedKey, SERIALIZABLE, SNAPSHOT, TxnExt}; use universaldb::{self as udb, options::StreamingMode}; diff --git a/packages/services/pegboard/src/ops/runner/update_alloc_idx.rs b/packages/services/pegboard/src/ops/runner/update_alloc_idx.rs index 7b94ce7f8f..39519590df 100644 --- a/packages/services/pegboard/src/ops/runner/update_alloc_idx.rs +++ b/packages/services/pegboard/src/ops/runner/update_alloc_idx.rs @@ -121,7 +121,7 @@ pub async fn pegboard_runner_update_alloc_idx(ctx: &OperationCtx, input: &Input) let remaining_millislots = (remaining_slots * 1000) / total_slots; - let old_alloc_key = keys::datacenter::RunnerAllocIdxKey::new( + let old_alloc_key = keys::ns::RunnerAllocIdxKey::new( namespace_id, name.clone(), version, @@ -140,7 +140,7 @@ pub async fn pegboard_runner_update_alloc_idx(ctx: &OperationCtx, input: &Input) Action::AddIdx => { txs.write( &old_alloc_key, - rivet_key_data::converted::RunnerAllocIdxKeyData { + rivet_data::converted::RunnerAllocIdxKeyData { workflow_id, remaining_slots, total_slots, @@ -162,7 +162,7 @@ pub async fn pegboard_runner_update_alloc_idx(ctx: &OperationCtx, input: &Input) txs.delete(&old_alloc_key); txs.write( - &keys::datacenter::RunnerAllocIdxKey::new( + &keys::ns::RunnerAllocIdxKey::new( namespace_id, name.clone(), version, @@ -170,7 +170,7 @@ pub async fn pegboard_runner_update_alloc_idx(ctx: &OperationCtx, input: &Input) last_ping_ts, runner.runner_id, ), - rivet_key_data::converted::RunnerAllocIdxKeyData { + rivet_data::converted::RunnerAllocIdxKeyData { workflow_id, remaining_slots, total_slots, diff --git a/packages/services/pegboard/src/workflows/actor/actor_keys.rs b/packages/services/pegboard/src/workflows/actor/actor_keys.rs index 1ff028160a..e5cc89a10c 100644 --- a/packages/services/pegboard/src/workflows/actor/actor_keys.rs +++ b/packages/services/pegboard/src/workflows/actor/actor_keys.rs @@ -4,7 +4,7 @@ use epoxy::{ }; use futures_util::TryStreamExt; use gas::prelude::*; -use rivet_key_data::converted::ActorByKeyKeyData; +use rivet_data::converted::ActorByKeyKeyData; use udb_util::prelude::*; use universaldb::{self as udb, FdbBindingError, options::StreamingMode}; diff --git a/packages/services/pegboard/src/workflows/actor/destroy.rs b/packages/services/pegboard/src/workflows/actor/destroy.rs index 
c408ef6500..44862d219d 100644 --- a/packages/services/pegboard/src/workflows/actor/destroy.rs +++ b/packages/services/pegboard/src/workflows/actor/destroy.rs @@ -1,8 +1,9 @@ use gas::prelude::*; -use rivet_key_data::converted::ActorByKeyKeyData; +use namespace::types::RunnerKind; +use rivet_data::converted::ActorByKeyKeyData; use rivet_runner_protocol::protocol; use udb_util::{SERIALIZABLE, TxnExt}; -use universaldb as udb; +use universaldb::{self as udb, options::MutationType}; use super::{DestroyComplete, DestroyStarted, State}; @@ -85,6 +86,7 @@ async fn update_state_and_fdb( state.namespace_id, &state.runner_name_selector, runner_id, + &state.ns_runner_kind, &tx, ) .await?; @@ -162,6 +164,7 @@ pub(crate) async fn clear_slot( namespace_id: Id, runner_name_selector: &str, runner_id: Id, + ns_runner_kind: &RunnerKind, tx: &udb::RetryableTransaction, ) -> Result<(), udb::FdbBindingError> { let txs = tx.subspace(keys::subspace()); @@ -198,7 +201,7 @@ pub(crate) async fn clear_slot( // Write new remaining slots txs.write(&runner_remaining_slots_key, new_runner_remaining_slots)?; - let old_runner_alloc_key = keys::datacenter::RunnerAllocIdxKey::new( + let old_runner_alloc_key = keys::ns::RunnerAllocIdxKey::new( namespace_id, runner_name_selector.to_string(), runner_version, @@ -213,7 +216,7 @@ pub(crate) async fn clear_slot( txs.delete(&old_runner_alloc_key); let new_remaining_millislots = (new_runner_remaining_slots * 1000) / runner_total_slots; - let new_runner_alloc_key = keys::datacenter::RunnerAllocIdxKey::new( + let new_runner_alloc_key = keys::ns::RunnerAllocIdxKey::new( namespace_id, runner_name_selector.to_string(), runner_version, @@ -224,7 +227,7 @@ pub(crate) async fn clear_slot( txs.write( &new_runner_alloc_key, - rivet_key_data::converted::RunnerAllocIdxKeyData { + rivet_data::converted::RunnerAllocIdxKeyData { workflow_id: runner_workflow_id, remaining_slots: new_runner_remaining_slots, total_slots: runner_total_slots, @@ -232,6 +235,14 @@ pub(crate) async fn clear_slot( )?; } + if let RunnerKind::Outbound { .. } = ns_runner_kind { + txs.atomic_op( + &keys::ns::OutboundDesiredSlotsKey::new(namespace_id, runner_name_selector.to_string()), + &(-1i32).to_le_bytes(), + MutationType::Add, + ); + } + Ok(()) } diff --git a/packages/services/pegboard/src/workflows/actor/mod.rs b/packages/services/pegboard/src/workflows/actor/mod.rs index a512312465..3bff9c38fb 100644 --- a/packages/services/pegboard/src/workflows/actor/mod.rs +++ b/packages/services/pegboard/src/workflows/actor/mod.rs @@ -1,5 +1,6 @@ use futures_util::FutureExt; use gas::prelude::*; +use namespace::types::RunnerKind; use rivet_runner_protocol::protocol; use rivet_types::actors::CrashPolicy; @@ -45,6 +46,9 @@ pub struct State { pub create_ts: i64, pub create_complete_ts: Option, + + pub ns_runner_kind: RunnerKind, + pub start_ts: Option, // NOTE: This is not the alarm ts, this is when the actor started sleeping. 
See `LifecycleState` for alarm pub sleep_ts: Option, @@ -66,6 +70,7 @@ impl State { runner_name_selector: String, crash_policy: CrashPolicy, create_ts: i64, + ns_runner_kind: RunnerKind, ) -> Self { State { name, @@ -78,6 +83,8 @@ impl State { create_ts, create_complete_ts: None, + ns_runner_kind, + start_ts: None, pending_allocation_ts: None, sleep_ts: None, @@ -115,15 +122,18 @@ pub async fn pegboard_actor(ctx: &mut WorkflowCtx, input: &Input) -> Result<()> }) .await?; - if let Err(error) = validation_res { - ctx.msg(Failed { error }) - .tag("actor_id", input.actor_id) - .send() - .await?; + let metadata = match validation_res { + Ok(metadata) => metadata, + Err(error) => { + ctx.msg(Failed { error }) + .tag("actor_id", input.actor_id) + .send() + .await?; - // TODO(RVT-3928): return Ok(Err); - return Ok(()); - } + // TODO(RVT-3928): return Ok(Err); + return Ok(()); + } + }; ctx.activity(setup::InitStateAndUdbInput { actor_id: input.actor_id, @@ -133,6 +143,7 @@ pub async fn pegboard_actor(ctx: &mut WorkflowCtx, input: &Input) -> Result<()> runner_name_selector: input.runner_name_selector.clone(), crash_policy: input.crash_policy, create_ts: ctx.create_ts(), + ns_runner_kind: metadata.ns_runner_kind, }) .await?; @@ -156,6 +167,19 @@ pub async fn pegboard_actor(ctx: &mut WorkflowCtx, input: &Input) -> Result<()> .tag("actor_id", input.actor_id) .send() .await?; + + // Destroyed early + ctx.workflow(destroy::Input { + namespace_id: input.namespace_id, + actor_id: input.actor_id, + name: input.name.clone(), + key: input.key.clone(), + generation: 0, + kill: false, + }) + .output() + .await?; + return Ok(()); } actor_keys::ReserveKeyOutput::KeyExists { existing_actor_id } => { @@ -170,8 +194,6 @@ pub async fn pegboard_actor(ctx: &mut WorkflowCtx, input: &Input) -> Result<()> .await?; // Destroyed early - // - // This will also deallocate any key that was already allocated to Epoxy ctx.workflow(destroy::Input { namespace_id: input.namespace_id, actor_id: input.actor_id, @@ -335,7 +357,7 @@ pub async fn pegboard_actor(ctx: &mut WorkflowCtx, input: &Input) -> Result<()> state.alarm_ts = None; state.sleeping = false; - if runtime::reschedule_actor(ctx, &input, state, true).await? { + if runtime::reschedule_actor(ctx, &input, state).await? { // Destroyed early return Ok(Loop::Break(runtime::LifecycleRes { generation: state.generation, @@ -466,7 +488,7 @@ async fn handle_stopped( .await?; } - if runtime::reschedule_actor(ctx, &input, state, false).await? { + if runtime::reschedule_actor(ctx, &input, state).await? 
{ // Destroyed early return Ok(Some(runtime::LifecycleRes { generation: state.generation, diff --git a/packages/services/pegboard/src/workflows/actor/runtime.rs b/packages/services/pegboard/src/workflows/actor/runtime.rs index 6a6a93a14e..b1cb7887e8 100644 --- a/packages/services/pegboard/src/workflows/actor/runtime.rs +++ b/packages/services/pegboard/src/workflows/actor/runtime.rs @@ -3,18 +3,16 @@ use std::time::Instant; use futures_util::StreamExt; use futures_util::{FutureExt, TryStreamExt}; use gas::prelude::*; +use namespace::types::RunnerKind; use rivet_metrics::KeyValue; use rivet_runner_protocol::protocol; use udb_util::{FormalKey, SERIALIZABLE, SNAPSHOT, TxnExt}; use universaldb::{ self as udb, - options::{ConflictRangeType, StreamingMode}, + options::{ConflictRangeType, MutationType, StreamingMode}, }; -use crate::{ - keys, metrics, - workflows::runner::{AllocatePendingActorsInput, RUNNER_ELIGIBLE_THRESHOLD_MS}, -}; +use crate::{keys, metrics, workflows::runner::RUNNER_ELIGIBLE_THRESHOLD_MS}; use super::{ ACTOR_START_THRESHOLD_MS, Allocate, BASE_RETRY_TIMEOUT_MS, Destroy, Input, PendingAllocation, @@ -105,6 +103,7 @@ async fn allocate_actor( let start_instant = Instant::now(); let mut state = ctx.state::()?; let namespace_id = state.namespace_id; + let ns_runner_kind = &state.ns_runner_kind; // NOTE: This txn should closely resemble the one found in the allocate_pending_actors activity of the // client wf @@ -114,13 +113,24 @@ async fn allocate_actor( let ping_threshold_ts = util::timestamp::now() - RUNNER_ELIGIBLE_THRESHOLD_MS; let txs = tx.subspace(keys::subspace()); + // Increment desired slots if namespace has an outbound runner kind + if let RunnerKind::Outbound { .. } = ns_runner_kind { + txs.atomic_op( + &keys::ns::OutboundDesiredSlotsKey::new( + namespace_id, + input.runner_name_selector.clone(), + ), + &1u32.to_le_bytes(), + MutationType::Add, + ); + } + // Check if a queue exists - let pending_actor_subspace = txs.subspace( - &keys::datacenter::PendingActorByRunnerNameSelectorKey::subspace( + let pending_actor_subspace = + txs.subspace(&keys::ns::PendingActorByRunnerNameSelectorKey::subspace( namespace_id, input.runner_name_selector.clone(), - ), - ); + )); let queue_exists = txs .get_ranges_keyvalues( udb::RangeOption { @@ -137,11 +147,10 @@ async fn allocate_actor( .is_some(); if !queue_exists { - let runner_alloc_subspace = - txs.subspace(&keys::datacenter::RunnerAllocIdxKey::subspace( - namespace_id, - input.runner_name_selector.clone(), - )); + let runner_alloc_subspace = txs.subspace(&keys::ns::RunnerAllocIdxKey::subspace( + namespace_id, + input.runner_name_selector.clone(), + )); let mut stream = txs.get_ranges_keyvalues( udb::RangeOption { @@ -161,7 +170,7 @@ async fn allocate_actor( }; let (old_runner_alloc_key, old_runner_alloc_key_data) = - txs.read_entry::(&entry)?; + txs.read_entry::(&entry)?; if let Some(highest_version) = highest_version { // We have passed all of the runners with the highest version. 
This is reachable if @@ -196,7 +205,7 @@ async fn allocate_actor( // Write new allocation key with 1 less slot txs.write( - &keys::datacenter::RunnerAllocIdxKey::new( + &keys::ns::RunnerAllocIdxKey::new( namespace_id, input.runner_name_selector.clone(), old_runner_alloc_key.version, @@ -204,7 +213,7 @@ async fn allocate_actor( old_runner_alloc_key.last_ping_ts, old_runner_alloc_key.runner_id, ), - rivet_key_data::converted::RunnerAllocIdxKeyData { + rivet_data::converted::RunnerAllocIdxKeyData { workflow_id: old_runner_alloc_key_data.workflow_id, remaining_slots: new_remaining_slots, total_slots: old_runner_alloc_key_data.total_slots, @@ -250,7 +259,7 @@ async fn allocate_actor( // want. If a runner reads from the queue while this is being inserted, one of the two txns will // retry and we ensure the actor does not end up in queue limbo. txs.write( - &keys::datacenter::PendingActorByRunnerNameSelectorKey::new( + &keys::ns::PendingActorByRunnerNameSelectorKey::new( namespace_id, input.runner_name_selector.clone(), pending_ts, @@ -299,7 +308,7 @@ pub async fn set_not_connectable(ctx: &ActivityCtx, input: &SetNotConnectableInp Ok(()) }) - .custom_instrument(tracing::info_span!("actor_deallocate_tx")) + .custom_instrument(tracing::info_span!("actor_set_not_connectable_tx")) .await?; state.connectable_ts = None; @@ -318,11 +327,13 @@ pub async fn deallocate(ctx: &ActivityCtx, input: &DeallocateInput) -> Result<() let runner_name_selector = &state.runner_name_selector; let namespace_id = state.namespace_id; let runner_id = state.runner_id; + let ns_runner_kind = &state.ns_runner_kind; ctx.udb()? .run(|tx, _mc| async move { - let connectable_key = keys::actor::ConnectableKey::new(input.actor_id); - tx.clear(&keys::subspace().pack(&connectable_key)); + let txs = tx.subspace(keys::subspace()); + + txs.delete(&keys::actor::ConnectableKey::new(input.actor_id)); if let Some(runner_id) = runner_id { destroy::clear_slot( @@ -330,9 +341,19 @@ pub async fn deallocate(ctx: &ActivityCtx, input: &DeallocateInput) -> Result<() namespace_id, runner_name_selector, runner_id, + ns_runner_kind, &tx, ) .await?; + } else if let RunnerKind::Outbound { .. } = ns_runner_kind { + txs.atomic_op( + &keys::ns::OutboundDesiredSlotsKey::new( + namespace_id, + runner_name_selector.clone(), + ), + &(-1i32).to_le_bytes(), + MutationType::Add, + ); } Ok(()) @@ -370,6 +391,10 @@ pub async fn spawn_actor( "failed to allocate (no availability), waiting for allocation", ); + ctx.msg(crate::messages::BumpOutboundAutoscaler {}) + .send() + .await?; + // If allocation fails, the allocate txn already inserted this actor into the queue. Now we wait for // an `Allocate` signal match ctx.listen::().await? 
{ @@ -441,35 +466,9 @@ pub async fn reschedule_actor( ctx: &mut WorkflowCtx, input: &Input, state: &mut LifecycleState, - sleeping: bool, ) -> Result { tracing::debug!(actor_id=?input.actor_id, "rescheduling actor"); - // There shouldn't be an allocation if the actor is sleeping - if !sleeping { - ctx.activity(DeallocateInput { - actor_id: input.actor_id, - }) - .await?; - - // Allocate other pending actors from queue - let res = ctx - .activity(AllocatePendingActorsInput { - namespace_id: input.namespace_id, - name: input.runner_name_selector.clone(), - }) - .await?; - - // Dispatch pending allocs - for alloc in res.allocations { - ctx.signal(alloc.signal) - .to_workflow::() - .tag("actor_id", alloc.actor_id) - .send() - .await?; - } - } - let next_generation = state.generation + 1; // Waits for the actor to be ready (or destroyed) and automatically retries if failed to allocate. @@ -563,7 +562,7 @@ pub async fn clear_pending_allocation( .udb()? .run(|tx, _mc| async move { let pending_alloc_key = - keys::subspace().pack(&keys::datacenter::PendingActorByRunnerNameSelectorKey::new( + keys::subspace().pack(&keys::ns::PendingActorByRunnerNameSelectorKey::new( input.namespace_id, input.runner_name_selector.clone(), input.pending_allocation_ts, diff --git a/packages/services/pegboard/src/workflows/actor/setup.rs b/packages/services/pegboard/src/workflows/actor/setup.rs index 9d55ff596f..421b7228fc 100644 --- a/packages/services/pegboard/src/workflows/actor/setup.rs +++ b/packages/services/pegboard/src/workflows/actor/setup.rs @@ -1,5 +1,6 @@ use gas::prelude::*; -use rivet_key_data::converted::ActorNameKeyData; +use namespace::types::RunnerKind; +use rivet_data::converted::ActorNameKeyData; use rivet_types::actors::CrashPolicy; use udb_util::{SERIALIZABLE, TxnExt}; @@ -17,20 +18,25 @@ pub struct ValidateInput { pub input: Option, } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ValidateOutput { + pub ns_runner_kind: RunnerKind, +} + #[activity(Validate)] pub async fn validate( ctx: &ActivityCtx, input: &ValidateInput, -) -> Result> { +) -> Result> { let ns_res = ctx .op(namespace::ops::get_global::Input { - namespace_id: input.namespace_id, + namespace_ids: vec![input.namespace_id], }) .await?; - if ns_res.is_none() { + let Some(ns) = ns_res.into_iter().next() else { return Ok(Err(errors::Actor::NamespaceNotFound)); - } + }; if input .input @@ -55,7 +61,9 @@ pub async fn validate( } } - Ok(Ok(())) + Ok(Ok(ValidateOutput { + ns_runner_kind: ns.runner_kind, + })) } #[derive(Debug, Clone, Serialize, Deserialize, Hash)] @@ -67,6 +75,7 @@ pub struct InitStateAndUdbInput { pub runner_name_selector: String, pub crash_policy: CrashPolicy, pub create_ts: i64, + pub ns_runner_kind: RunnerKind, } #[activity(InitStateAndFdb)] @@ -80,6 +89,7 @@ pub async fn insert_state_and_fdb(ctx: &ActivityCtx, input: &InitStateAndUdbInpu input.runner_name_selector.clone(), input.crash_policy, input.create_ts, + input.ns_runner_kind.clone(), )); ctx.udb()? 
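Note on the key encodings the hunks above rely on: RunnerAllocIdxKey packs `version` and `remaining_millislots` negated so that the tuple layer's ascending byte order yields highest-version, most-free-slots-first iteration, and OutboundDesiredSlotsKey stores its counter little-endian because the ADD mutation treats both the stored value and the operand as little-endian integers. A dependency-free Rust sketch of both conventions (a BTreeMap and plain byte math stand in for the real universaldb/udb_util APIs):

	use std::collections::BTreeMap;

	fn main() {
		// (a) Descending iteration via negation. FDB-style tuple keys sort in
		// ascending byte order, so storing -(version as i32) makes the highest
		// version come back first without a reverse range read.
		let mut idx: BTreeMap<i32, &str> = BTreeMap::new();
		for (version, runner) in [(1u32, "oldest"), (3, "newest"), (2, "middle")] {
			idx.insert(-(version as i32), runner);
		}
		let order: Vec<&str> = idx.values().copied().collect();
		assert_eq!(order, ["newest", "middle", "oldest"]);

		// (b) Little-endian counters for atomic ADD. The mutation adds the
		// operand bytes to the stored bytes as little-endian integers, so a
		// decrement is expressed as (-1i32).to_le_bytes(), wrapping in two's
		// complement.
		let mut stored = 5u32.to_le_bytes();
		let operand = (-1i32).to_le_bytes();
		let sum = u32::from_le_bytes(stored).wrapping_add(u32::from_le_bytes(operand));
		stored = sum.to_le_bytes();
		assert_eq!(u32::from_le_bytes(stored), 4);
	}

This is why destroy.rs and the deallocate activity can decrement OutboundDesiredSlotsKey with `&(-1i32).to_le_bytes()` and `MutationType::Add` instead of doing a read-modify-write inside the transaction.
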
diff --git a/packages/services/pegboard/src/workflows/runner.rs b/packages/services/pegboard/src/workflows/runner.rs index ff7b5a64cc..e5d64f17d6 100644 --- a/packages/services/pegboard/src/workflows/runner.rs +++ b/packages/services/pegboard/src/workflows/runner.rs @@ -1,6 +1,6 @@ use futures_util::{FutureExt, StreamExt, TryStreamExt}; use gas::prelude::*; -use rivet_key_data::{ +use rivet_data::{ converted::{ActorNameKeyData, MetadataKeyData, RunnerByKeyKeyData}, generated::pegboard_runner_address_v1::Data as AddressKeyData, }; @@ -639,7 +639,7 @@ async fn insert_fdb(ctx: &ActivityCtx, input: &InsertFdbInput) -> Result<()> { // Insert into index (same as the `update_alloc_idx` op with `AddIdx`) txs.write( - &keys::datacenter::RunnerAllocIdxKey::new( + &keys::ns::RunnerAllocIdxKey::new( input.namespace_id, input.name.clone(), input.version, @@ -647,7 +647,7 @@ async fn insert_fdb(ctx: &ActivityCtx, input: &InsertFdbInput) -> Result<()> { last_ping_ts, input.runner_id, ), - rivet_key_data::converted::RunnerAllocIdxKeyData { + rivet_data::converted::RunnerAllocIdxKeyData { workflow_id: ctx.workflow_id(), remaining_slots, total_slots: input.total_slots, @@ -998,12 +998,11 @@ pub(crate) async fn allocate_pending_actors( let txs = tx.subspace(keys::subspace()); let mut results = Vec::new(); - let pending_actor_subspace = txs.subspace( - &keys::datacenter::PendingActorByRunnerNameSelectorKey::subspace( + let pending_actor_subspace = + txs.subspace(&keys::ns::PendingActorByRunnerNameSelectorKey::subspace( input.namespace_id, input.name.clone(), - ), - ); + )); let mut queue_stream = txs.get_ranges_keyvalues( udb::RangeOption { mode: StreamingMode::Iterator, @@ -1021,15 +1020,12 @@ pub(crate) async fn allocate_pending_actors( }; let (queue_key, generation) = - txs.read_entry::( - &queue_entry, - )?; + txs.read_entry::(&queue_entry)?; - let runner_alloc_subspace = - txs.subspace(&keys::datacenter::RunnerAllocIdxKey::subspace( - input.namespace_id, - input.name.clone(), - )); + let runner_alloc_subspace = txs.subspace(&keys::ns::RunnerAllocIdxKey::subspace( + input.namespace_id, + input.name.clone(), + )); let mut stream = txs.get_ranges_keyvalues( udb::RangeOption { @@ -1051,7 +1047,7 @@ pub(crate) async fn allocate_pending_actors( }; let (old_runner_alloc_key, old_runner_alloc_key_data) = - txs.read_entry::(&entry)?; + txs.read_entry::(&entry)?; if let Some(highest_version) = highest_version { // We have passed all of the runners with the highest version. 
This is reachable if @@ -1088,7 +1084,7 @@ pub(crate) async fn allocate_pending_actors( // Write new allocation key with 1 less slot txs.write( - &keys::datacenter::RunnerAllocIdxKey::new( + &keys::ns::RunnerAllocIdxKey::new( input.namespace_id, input.name.clone(), old_runner_alloc_key.version, @@ -1096,7 +1092,7 @@ pub(crate) async fn allocate_pending_actors( old_runner_alloc_key.last_ping_ts, old_runner_alloc_key.runner_id, ), - rivet_key_data::converted::RunnerAllocIdxKeyData { + rivet_data::converted::RunnerAllocIdxKeyData { workflow_id: old_runner_alloc_key_data.workflow_id, remaining_slots: new_remaining_slots, total_slots: old_runner_alloc_key_data.total_slots, diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 91185e953a..d2358c0f78 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -741,6 +741,9 @@ importers: '@rivetkit/engine-runner-protocol': specifier: workspace:* version: link:../runner-protocol + hono: + specifier: ^4.0.0 + version: 4.8.12 ws: specifier: ^8.18.3 version: 8.18.3 diff --git a/sdks/rust/key-data/Cargo.toml b/sdks/rust/data/Cargo.toml similarity index 95% rename from sdks/rust/key-data/Cargo.toml rename to sdks/rust/data/Cargo.toml index 339d60c362..3282e73918 100644 --- a/sdks/rust/key-data/Cargo.toml +++ b/sdks/rust/data/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "rivet-key-data" +name = "rivet-data" version.workspace = true authors.workspace = true license.workspace = true diff --git a/sdks/rust/key-data/build.rs b/sdks/rust/data/build.rs similarity index 99% rename from sdks/rust/key-data/build.rs rename to sdks/rust/data/build.rs index 898e7ea645..e6b18ab845 100644 --- a/sdks/rust/key-data/build.rs +++ b/sdks/rust/data/build.rs @@ -61,7 +61,7 @@ fn main() { .and_then(|p| p.parent()) .expect("Failed to find workspace root"); - let schema_dir = workspace_root.join("sdks").join("schemas").join("key-data"); + let schema_dir = workspace_root.join("sdks").join("schemas").join("data"); println!("cargo:rerun-if-changed={}", schema_dir.display()); diff --git a/sdks/rust/key-data/src/converted.rs b/sdks/rust/data/src/converted.rs similarity index 94% rename from sdks/rust/key-data/src/converted.rs rename to sdks/rust/data/src/converted.rs index 44e954bc84..859682e546 100644 --- a/sdks/rust/key-data/src/converted.rs +++ b/sdks/rust/data/src/converted.rs @@ -9,10 +9,10 @@ pub struct RunnerAllocIdxKeyData { pub total_slots: u32, } -impl TryFrom for RunnerAllocIdxKeyData { +impl TryFrom for RunnerAllocIdxKeyData { type Error = anyhow::Error; - fn try_from(value: pegboard_datacenter_runner_alloc_idx_v1::Data) -> Result { + fn try_from(value: pegboard_namespace_runner_alloc_idx_v1::Data) -> Result { Ok(RunnerAllocIdxKeyData { workflow_id: Id::from_slice(&value.workflow_id)?, remaining_slots: value.remaining_slots, @@ -21,11 +21,11 @@ impl TryFrom for RunnerAllocIdxKe } } -impl TryFrom for pegboard_datacenter_runner_alloc_idx_v1::Data { +impl TryFrom for pegboard_namespace_runner_alloc_idx_v1::Data { type Error = anyhow::Error; fn try_from(value: RunnerAllocIdxKeyData) -> Result { - Ok(pegboard_datacenter_runner_alloc_idx_v1::Data { + Ok(pegboard_namespace_runner_alloc_idx_v1::Data { workflow_id: value.workflow_id.as_bytes(), remaining_slots: value.remaining_slots, total_slots: value.total_slots, diff --git a/sdks/rust/key-data/src/generated.rs b/sdks/rust/data/src/generated.rs similarity index 100% rename from sdks/rust/key-data/src/generated.rs rename to sdks/rust/data/src/generated.rs diff --git a/sdks/rust/key-data/src/lib.rs b/sdks/rust/data/src/lib.rs similarity 
index 84% rename from sdks/rust/key-data/src/lib.rs rename to sdks/rust/data/src/lib.rs index 3292e22e83..a93ceaea1d 100644 --- a/sdks/rust/key-data/src/lib.rs +++ b/sdks/rust/data/src/lib.rs @@ -2,9 +2,9 @@ pub mod converted; pub mod generated; pub mod versioned; -pub const PEGBOARD_DATACENTER_RUNNER_ALLOC_IDX_VERSION: u16 = 1; pub const PEGBOARD_RUNNER_ADDRESS_VERSION: u16 = 1; pub const PEGBOARD_RUNNER_METADATA_VERSION: u16 = 1; pub const PEGBOARD_NAMESPACE_ACTOR_BY_KEY_VERSION: u16 = 1; +pub const PEGBOARD_NAMESPACE_RUNNER_ALLOC_IDX_VERSION: u16 = 1; pub const PEGBOARD_NAMESPACE_RUNNER_BY_KEY_VERSION: u16 = 1; pub const PEGBOARD_NAMESPACE_ACTOR_NAME_VERSION: u16 = 1; diff --git a/sdks/rust/key-data/src/versioned.rs b/sdks/rust/data/src/versioned.rs similarity index 83% rename from sdks/rust/key-data/src/versioned.rs rename to sdks/rust/data/src/versioned.rs index af709553f3..002363d361 100644 --- a/sdks/rust/key-data/src/versioned.rs +++ b/sdks/rust/data/src/versioned.rs @@ -4,13 +4,13 @@ use versioned_data_util::OwnedVersionedData; use crate::generated::*; pub enum RunnerAllocIdxKeyData { - V1(pegboard_datacenter_runner_alloc_idx_v1::Data), + V1(pegboard_namespace_runner_alloc_idx_v1::Data), } impl OwnedVersionedData for RunnerAllocIdxKeyData { - type Latest = pegboard_datacenter_runner_alloc_idx_v1::Data; + type Latest = pegboard_namespace_runner_alloc_idx_v1::Data; - fn latest(latest: pegboard_datacenter_runner_alloc_idx_v1::Data) -> Self { + fn latest(latest: pegboard_namespace_runner_alloc_idx_v1::Data) -> Self { RunnerAllocIdxKeyData::V1(latest) } @@ -206,3 +206,37 @@ impl OwnedVersionedData for ActorNameKeyData { } } } + +pub enum NamespaceRunnerKind { + V1(namespace_runner_kind_v1::Data), +} + +impl OwnedVersionedData for NamespaceRunnerKind { + type Latest = namespace_runner_kind_v1::Data; + + fn latest(latest: namespace_runner_kind_v1::Data) -> Self { + NamespaceRunnerKind::V1(latest) + } + + fn into_latest(self) -> Result { + #[allow(irrefutable_let_patterns)] + if let NamespaceRunnerKind::V1(data) = self { + Ok(data) + } else { + bail!("version not latest"); + } + } + + fn deserialize_version(payload: &[u8], version: u16) -> Result { + match version { + 1 => Ok(NamespaceRunnerKind::V1(serde_bare::from_slice(payload)?)), + _ => bail!("invalid version: {version}"), + } + } + + fn serialize_version(self, _version: u16) -> Result> { + match self { + NamespaceRunnerKind::V1(data) => serde_bare::to_vec(&data).map_err(Into::into), + } + } +} diff --git a/sdks/schemas/data/namespace.runner_kind.v1.bare b/sdks/schemas/data/namespace.runner_kind.v1.bare new file mode 100644 index 0000000000..1cc263b54e --- /dev/null +++ b/sdks/schemas/data/namespace.runner_kind.v1.bare @@ -0,0 +1,14 @@ +type Outbound struct { + url: str + slots_per_runner: u32 + min_runners: u32 + max_runners: u32 + runners_margin: u32 +} + +type Custom void + +type Data union { + Outbound | + Custom +} diff --git a/sdks/schemas/key-data/pegboard.namespace.actor_by_key.v1.bare b/sdks/schemas/data/pegboard.namespace.actor_by_key.v1.bare similarity index 100% rename from sdks/schemas/key-data/pegboard.namespace.actor_by_key.v1.bare rename to sdks/schemas/data/pegboard.namespace.actor_by_key.v1.bare diff --git a/sdks/schemas/key-data/pegboard.namespace.actor_name.v1.bare b/sdks/schemas/data/pegboard.namespace.actor_name.v1.bare similarity index 100% rename from sdks/schemas/key-data/pegboard.namespace.actor_name.v1.bare rename to sdks/schemas/data/pegboard.namespace.actor_name.v1.bare diff --git 
a/sdks/schemas/key-data/pegboard.datacenter.runner_alloc_idx.v1.bare b/sdks/schemas/data/pegboard.namespace.runner_alloc_idx.v1.bare similarity index 100% rename from sdks/schemas/key-data/pegboard.datacenter.runner_alloc_idx.v1.bare rename to sdks/schemas/data/pegboard.namespace.runner_alloc_idx.v1.bare diff --git a/sdks/schemas/key-data/pegboard.namespace.runner_by_key.v1.bare b/sdks/schemas/data/pegboard.namespace.runner_by_key.v1.bare similarity index 100% rename from sdks/schemas/key-data/pegboard.namespace.runner_by_key.v1.bare rename to sdks/schemas/data/pegboard.namespace.runner_by_key.v1.bare diff --git a/sdks/schemas/key-data/pegboard.runner.address.v1.bare b/sdks/schemas/data/pegboard.runner.address.v1.bare similarity index 100% rename from sdks/schemas/key-data/pegboard.runner.address.v1.bare rename to sdks/schemas/data/pegboard.runner.address.v1.bare diff --git a/sdks/schemas/key-data/pegboard.runner.metadata.v1.bare b/sdks/schemas/data/pegboard.runner.metadata.v1.bare similarity index 100% rename from sdks/schemas/key-data/pegboard.runner.metadata.v1.bare rename to sdks/schemas/data/pegboard.runner.metadata.v1.bare diff --git a/sdks/typescript/runner/src/mod.ts b/sdks/typescript/runner/src/mod.ts index 2623d63848..611f2bbcf0 100644 --- a/sdks/typescript/runner/src/mod.ts +++ b/sdks/typescript/runner/src/mod.ts @@ -35,6 +35,7 @@ export interface RunnerConfig { metadata?: Record; onConnected: () => void; onDisconnected: () => void; + onShutdown: () => void; fetch: (actorId: string, request: Request) => Promise; websocket?: (actorId: string, ws: any, request: Request) => Promise; onActorStart: ( @@ -362,9 +363,9 @@ export class Runner { //console.log("Tunnel shutdown completed"); } - if (exit) { - process.exit(0); - } + if (exit) process.exit(0); + + this.#config.onShutdown(); } // MARK: Networking diff --git a/sdks/typescript/test-runner/package.json b/sdks/typescript/test-runner/package.json index 7092ba4acf..ed441804ed 100644 --- a/sdks/typescript/test-runner/package.json +++ b/sdks/typescript/test-runner/package.json @@ -11,6 +11,7 @@ "@rivetkit/engine-runner": "workspace:*", "@hono/node-server": "^1.18.2", "@rivetkit/engine-runner-protocol": "workspace:*", + "hono": "^4.0.0", "ws": "^8.18.3" }, "devDependencies": { @@ -22,4 +23,4 @@ "typescript": "^5.3.3", "vitest": "^1.6.0" } -} +} \ No newline at end of file diff --git a/sdks/typescript/test-runner/src/main.ts b/sdks/typescript/test-runner/src/main.ts index 596dda5009..fbe681326d 100644 --- a/sdks/typescript/test-runner/src/main.ts +++ b/sdks/typescript/test-runner/src/main.ts @@ -2,6 +2,8 @@ import { Runner } from "@rivetkit/engine-runner"; import type { RunnerConfig, ActorConfig } from "@rivetkit/engine-runner"; import WebSocket from "ws"; import { serve } from "@hono/node-server"; +import { streamSSE } from "hono/streaming"; +import { Hono } from 'hono' const INTERNAL_SERVER_PORT = process.env.INTERNAL_SERVER_PORT ? Number(process.env.INTERNAL_SERVER_PORT) @@ -16,120 +18,150 @@ const RIVET_RUNNER_TOTAL_SLOTS = process.env.RIVET_RUNNER_TOTAL_SLOTS ? Number(process.env.RIVET_RUNNER_TOTAL_SLOTS) : 100; const RIVET_ENDPOINT = process.env.RIVET_ENDPOINT ?? 
"http://localhost:6420"; +const AUTOSTART = process.env.NO_AUTOSTART == undefined; let runnerStarted = Promise.withResolvers(); +let runnerStopped = Promise.withResolvers(); let websocketOpen = Promise.withResolvers(); let websocketClosed = Promise.withResolvers(); let runner: Runner | null = null; const actorWebSockets = new Map(); -// Start internal server -serve({ - fetch: async (request: Request) => { - const url = new URL(request.url); - if (url.pathname == "/wait-ready") { - await runnerStarted.promise; - return new Response(JSON.stringify(runner?.runnerId), { - status: 200, - }); - } else if (url.pathname == "/has-actor") { - let actorIdQuery = url.searchParams.get("actor"); - let generationQuery = url.searchParams.get("generation"); - let generation = generationQuery - ? Number(generationQuery) - : undefined; - - if (!actorIdQuery || !runner?.hasActor(actorIdQuery, generation)) { - return new Response(undefined, { status: 404 }); - } - } else if (url.pathname == "/shutdown") { - await runner?.shutdown(true); - } +// Create internal server +const app = new Hono(); + +app.get('/wait-ready', async (c) => { + await runnerStarted.promise; + return c.json(runner?.runnerId); +}); + +app.get('/has-actor', async (c) => { + let actorIdQuery = c.req.query('actor'); + let generationQuery = c.req.query('generation'); + let generation = generationQuery ? Number(generationQuery) : undefined; + + if (!actorIdQuery || !runner?.hasActor(actorIdQuery, generation)) { + return c.text('', 404); + } + return c.text('ok'); +}); + +app.get('/shutdown', async (c) => { + await runner?.shutdown(true); + return c.text('ok'); +}); + +app.get('/start', async (c) => { + return streamSSE(c, async (stream) => { + if (runner == null) runner = await startRunner(); + + stream.writeSSE({ data: runner.runnerId! 
+
+		await runnerStopped.promise;
+	});
+});
 
-		return new Response("ok", { status: 200 });
-	},
+app.get("*", (c) => c.text("ok"));
+
+serve({
+	fetch: app.fetch,
 	port: INTERNAL_SERVER_PORT,
 });
 
 console.log(`Internal HTTP server listening on port ${INTERNAL_SERVER_PORT}`);
 
-// Use objects to hold the current promise resolvers so callbacks always get the latest
-const startedRef = { current: Promise.withResolvers() };
-const stoppedRef = { current: Promise.withResolvers() };
-
-const config: RunnerConfig = {
-	version: RIVET_RUNNER_VERSION,
-	endpoint: RIVET_ENDPOINT,
-	namespace: RIVET_NAMESPACE,
-	runnerName: "test-runner",
-	runnerKey: RIVET_RUNNER_KEY,
-	totalSlots: RIVET_RUNNER_TOTAL_SLOTS,
-	prepopulateActorNames: {},
-	onConnected: () => {
-		runnerStarted.resolve(undefined);
-	},
-	onDisconnected: () => {},
-	fetch: async (actorId: string, request: Request) => {
-		console.log(
-			`[TEST-RUNNER] Fetch called for actor ${actorId}, URL: ${request.url}`,
-		);
-		const url = new URL(request.url);
-		if (url.pathname === "/ping") {
-			// Return the actor ID in response
-			const responseData = {
-				actorId,
-				status: "ok",
-				timestamp: Date.now(),
-			};
-			console.log(`[TEST-RUNNER] Returning ping response:`, responseData);
-			return new Response(JSON.stringify(responseData), {
-				status: 200,
-				headers: { "Content-Type": "application/json" },
+if (AUTOSTART) runner = await startRunner();
+
+async function startRunner(): Promise<Runner> {
+	const config: RunnerConfig = {
+		version: RIVET_RUNNER_VERSION,
+		endpoint: RIVET_ENDPOINT,
+		namespace: RIVET_NAMESPACE,
+		runnerName: "test-runner",
+		runnerKey: RIVET_RUNNER_KEY,
+		totalSlots: RIVET_RUNNER_TOTAL_SLOTS,
+		prepopulateActorNames: {},
+		onConnected: () => {
+			runnerStarted.resolve(undefined);
+		},
+		onDisconnected: () => {},
+		onShutdown: () => {
+			runnerStopped.resolve(undefined);
+		},
+		fetch: async (actorId: string, request: Request) => {
+			console.log(`[TEST-RUNNER] Fetch called for actor ${actorId}, URL: ${request.url}`);
+			const url = new URL(request.url);
+			if (url.pathname === "/ping") {
+				// Return the actor ID in response
+				const responseData = {
+					actorId,
+					status: "ok",
+					timestamp: Date.now(),
+				};
+				console.log(`[TEST-RUNNER] Returning ping response:`, responseData);
+				return new Response(
+					JSON.stringify(responseData),
+					{
+						status: 200,
+						headers: { "Content-Type": "application/json" },
+					},
+				);
+			}
+
+			return new Response("ok", { status: 200 });
+		},
+		onActorStart: async (
+			_actorId: string,
+			_generation: number,
+			_config: ActorConfig,
+		) => {
+			console.log(
+				`Actor ${_actorId} started (generation ${_generation})`,
+			);
+		},
+		onActorStop: async (_actorId: string, _generation: number) => {
+			console.log(
+				`Actor ${_actorId} stopped (generation ${_generation})`,
+			);
+		},
+		websocket: async (
+			actorId: string,
+			ws: WebSocket,
+			request: Request,
+		) => {
+			console.log(`WebSocket connected for actor ${actorId}`);
+			websocketOpen.resolve(undefined);
+			actorWebSockets.set(actorId, ws);
+
+			// Echo server - send back any messages received
+			ws.addEventListener("message", (event) => {
+				const data = event.data;
+				console.log(
+					`WebSocket message from actor ${actorId}:`,
+					data,
+				);
+				ws.send(`Echo: ${data}`);
+			});
+
+			ws.addEventListener("close", () => {
+				console.log(`WebSocket closed for actor ${actorId}`);
+				actorWebSockets.delete(actorId);
+				websocketClosed.resolve(undefined);
+			});
+
+			ws.addEventListener("error", (error) => {
+				console.error(`WebSocket error for actor ${actorId}:`, error);
 			});
-		}
-
-		return new Response("ok", { status: 200 });
-	},
-	onActorStart: async (
-		_actorId: string,
-		_generation: number,
-		_config: ActorConfig,
-	) => {
-		console.log(`Actor ${_actorId} started (generation ${_generation})`);
-		startedRef.current.resolve(undefined);
-	},
-	onActorStop: async (_actorId: string, _generation: number) => {
-		console.log(`Actor ${_actorId} stopped (generation ${_generation})`);
-		stoppedRef.current.resolve(undefined);
-	},
-	websocket: async (actorId: string, ws: WebSocket, request: Request) => {
-		console.log(`WebSocket connected for actor ${actorId}`);
-		websocketOpen.resolve(undefined);
-		actorWebSockets.set(actorId, ws);
-
-		// Echo server - send back any messages received
-		ws.addEventListener("message", (event) => {
-			const data = event.data;
-			console.log(`WebSocket message from actor ${actorId}:`, data);
-			ws.send(`Echo: ${data}`);
-		});
-
-		ws.addEventListener("close", () => {
-			console.log(`WebSocket closed for actor ${actorId}`);
-			actorWebSockets.delete(actorId);
-			websocketClosed.resolve(undefined);
-		});
-
-		ws.addEventListener("error", (error) => {
-			console.error(`WebSocket error for actor ${actorId}:`, error);
-		});
-	},
-};
-
-runner = new Runner(config);
-
-// Start runner
-await runner.start();
-
-// Wait for runner to be ready
-console.log("Waiting runner start...");
-await runnerStarted.promise;
+		},
+	};
+
+	runner = new Runner(config);
+
+	// Start runner
+	await runner.start();
+
+	// Wait for runner to be ready
+	console.log("Waiting for runner to start...");
+	await runnerStarted.promise;
+
+	return runner;
+}
\ No newline at end of file
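
Note on the new test-runner control flow: `/start` is a Server-Sent Events endpoint that lazily boots the runner, emits the runner ID as the first `data:` event, and then holds the stream open until `onShutdown` resolves `runnerStopped`. A minimal consumer sketch follows (assumptions: Node 18+ global `fetch`, the internal server reachable on localhost, and a single `read()` returning the whole first SSE frame; `startRunnerViaSse` is a hypothetical helper, not part of this patch):

	// Sketch: drive the test-runner's /start SSE endpoint from a test harness.
	async function startRunnerViaSse(port: number): Promise<string> {
		const res = await fetch(`http://127.0.0.1:${port}/start`);
		if (!res.ok || !res.body) throw new Error(`unexpected status ${res.status}`);
		const reader = res.body.pipeThrough(new TextDecoderStream()).getReader();
		// The first SSE frame carries the runner ID ("data: ...\n\n").
		const { value } = await reader.read();
		const match = value?.match(/^data: (.+)$/m);
		if (!match) throw new Error("no SSE data event received");
		// Drain in the background; the server closes the stream once
		// onShutdown fires and runnerStopped resolves.
		void (async () => {
			while (!(await reader.read()).done) {}
			console.log("runner stopped");
		})();
		return match[1];
	}

This mirrors how the stream is produced server-side: the SSE handler awaits `runnerStopped.promise`, so the connection closing is the signal that the runner shut down.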