diff --git a/Cargo.lock b/Cargo.lock index 250392e348..21f444cce8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3259,6 +3259,7 @@ name = "pegboard" version = "25.7.1" dependencies = [ "anyhow", + "base64 0.22.1", "epoxy", "gasoline", "lazy_static", @@ -3277,6 +3278,7 @@ dependencies = [ "strum", "tracing", "universaldb", + "universalpubsub", "utoipa", "versioned-data-util", ] @@ -3315,7 +3317,6 @@ dependencies = [ "rand 0.8.5", "rivet-error", "rivet-guard-core", - "rivet-tunnel-protocol", "rivet-util", "thiserror 1.0.69", "tokio", @@ -3349,8 +3350,10 @@ dependencies = [ "rivet-runtime", "serde", "serde_json", + "tokio", "tokio-tungstenite", "tracing", + "universalpubsub", "url", "versioned-data-util", ] @@ -3394,7 +3397,6 @@ dependencies = [ "rivet-metrics", "rivet-pools", "rivet-runtime", - "rivet-tunnel-protocol", "rivet-util", "serde", "serde_json", @@ -4669,23 +4671,6 @@ dependencies = [ "uuid", ] -[[package]] -name = "rivet-tunnel-protocol" -version = "25.7.1" -dependencies = [ - "anyhow", - "bare_gen", - "base64 0.22.1", - "indoc", - "prettyplease", - "rivet-util", - "serde", - "serde_bare", - "serde_json", - "syn 2.0.104", - "versioned-data-util", -] - [[package]] name = "rivet-types" version = "25.7.1" diff --git a/Cargo.toml b/Cargo.toml index e82e03a547..84d870d51e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [workspace] resolver = "2" -members = ["packages/common/api-builder","packages/common/api-types","packages/common/api-util","packages/common/cache/build","packages/common/cache/result","packages/common/clickhouse-inserter","packages/common/clickhouse-user-query","packages/common/config","packages/common/env","packages/common/error/core","packages/common/error/macros","packages/common/gasoline/core","packages/common/gasoline/macros","packages/common/logs","packages/common/metrics","packages/common/pools","packages/common/runtime","packages/common/service-manager","packages/common/telemetry","packages/common/test-deps","packages/common/test-deps-docker","packages/common/types","packages/common/universaldb","packages/common/universalpubsub","packages/common/util/core","packages/common/util/id","packages/common/versioned-data-util","packages/core/actor-kv","packages/core/api-peer","packages/core/api-public","packages/core/bootstrap","packages/core/dump-openapi","packages/core/guard/core","packages/core/guard/server","packages/core/pegboard-gateway","packages/core/pegboard-runner","packages/core/pegboard-serverless","packages/core/pegboard-tunnel","packages/core/workflow-worker","packages/infra/engine","packages/services/epoxy","packages/services/internal","packages/services/namespace","packages/services/pegboard","sdks/rust/api-full","sdks/rust/bare_gen","sdks/rust/data","sdks/rust/epoxy-protocol","sdks/rust/runner-protocol","sdks/rust/tunnel-protocol","sdks/rust/ups-protocol"] +members = 
["packages/common/api-builder","packages/common/api-types","packages/common/api-util","packages/common/cache/build","packages/common/cache/result","packages/common/clickhouse-inserter","packages/common/clickhouse-user-query","packages/common/config","packages/common/env","packages/common/error/core","packages/common/error/macros","packages/common/gasoline/core","packages/common/gasoline/macros","packages/common/logs","packages/common/metrics","packages/common/pools","packages/common/runtime","packages/common/service-manager","packages/common/telemetry","packages/common/test-deps","packages/common/test-deps-docker","packages/common/types","packages/common/universaldb","packages/common/universalpubsub","packages/common/util/core","packages/common/util/id","packages/common/versioned-data-util","packages/core/actor-kv","packages/core/api-peer","packages/core/api-public","packages/core/bootstrap","packages/core/dump-openapi","packages/core/guard/core","packages/core/guard/server","packages/core/pegboard-gateway","packages/core/pegboard-runner","packages/core/pegboard-serverless","packages/core/pegboard-tunnel","packages/core/workflow-worker","packages/infra/engine","packages/services/epoxy","packages/services/internal","packages/services/namespace","packages/services/pegboard","sdks/rust/api-full","sdks/rust/bare_gen","sdks/rust/data","sdks/rust/epoxy-protocol","sdks/rust/runner-protocol","sdks/rust/ups-protocol"] [workspace.package] version = "25.7.1" @@ -401,9 +401,6 @@ path = "sdks/rust/epoxy-protocol" [workspace.dependencies.rivet-runner-protocol] path = "sdks/rust/runner-protocol" -[workspace.dependencies.rivet-tunnel-protocol] -path = "sdks/rust/tunnel-protocol" - [workspace.dependencies.rivet-ups-protocol] path = "sdks/rust/ups-protocol" diff --git a/packages/core/guard/server/src/routing/pegboard_gateway.rs b/packages/core/guard/server/src/routing/pegboard_gateway.rs index a7b944fa5e..519debc481 100644 --- a/packages/core/guard/server/src/routing/pegboard_gateway.rs +++ b/packages/core/guard/server/src/routing/pegboard_gateway.rs @@ -173,38 +173,11 @@ async fn find_actor( tracing::debug!(?actor_id, ?runner_id, "actor ready"); - // TODO: Remove round trip, return key from get_runner op above - // Get runner key, namespace_id, and runner_name from runner_id - let (runner_key, namespace_id, runner_name) = ctx - .udb()? 
- .run(|tx| async move { - let tx = tx.with_subspace(pegboard::keys::subspace()); - - let runner_key_key = pegboard::keys::runner::KeyKey::new(runner_id); - let namespace_id_key = pegboard::keys::runner::NamespaceIdKey::new(runner_id); - let runner_name_key = pegboard::keys::runner::NameKey::new(runner_id); - - let (runner_key, namespace_id, runner_name) = tokio::try_join!( - tx.read_opt(&runner_key_key, Serializable), - tx.read_opt(&namespace_id_key, Serializable), - tx.read_opt(&runner_name_key, Serializable), - )?; - - let runner_key = runner_key.context("runner key not found")?; - let namespace_id = namespace_id.context("runner namespace_id not found")?; - let runner_name = runner_name.context("runner name not found")?; - - Ok((runner_key, namespace_id, runner_name)) - }) - .await?; - // Return pegboard-gateway instance let gateway = pegboard_gateway::PegboardGateway::new( ctx.clone(), shared_state.pegboard_gateway.clone(), - namespace_id, - runner_name, - runner_key, + runner_id, actor_id, ); Ok(Some(RoutingOutput::CustomServe(std::sync::Arc::new( diff --git a/packages/core/pegboard-gateway/Cargo.toml b/packages/core/pegboard-gateway/Cargo.toml index bf4eba0740..b3d552739d 100644 --- a/packages/core/pegboard-gateway/Cargo.toml +++ b/packages/core/pegboard-gateway/Cargo.toml @@ -18,7 +18,6 @@ pegboard.workspace = true rand.workspace = true rivet-error.workspace = true rivet-guard-core.workspace = true -rivet-tunnel-protocol.workspace = true rivet-util.workspace = true thiserror.workspace = true tokio-tungstenite.workspace = true diff --git a/packages/core/pegboard-gateway/src/lib.rs b/packages/core/pegboard-gateway/src/lib.rs index da35fba436..f253d03a7b 100644 --- a/packages/core/pegboard-gateway/src/lib.rs +++ b/packages/core/pegboard-gateway/src/lib.rs @@ -37,27 +37,16 @@ const UPS_REQ_TIMEOUT: Duration = Duration::from_secs(2); pub struct PegboardGateway { ctx: StandaloneCtx, shared_state: SharedState, - namespace_id: Id, - runner_name: String, - runner_key: String, + runner_id: Id, actor_id: Id, } impl PegboardGateway { - pub fn new( - ctx: StandaloneCtx, - shared_state: SharedState, - namespace_id: Id, - runner_name: String, - runner_key: String, - actor_id: Id, - ) -> Self { + pub fn new(ctx: StandaloneCtx, shared_state: SharedState, runner_id: Id, actor_id: Id) -> Self { Self { ctx, shared_state, - namespace_id, - runner_name, - runner_key, + runner_id, actor_id, } } @@ -151,12 +140,8 @@ impl PegboardGateway { .to_bytes(); // Build subject to publish to - let tunnel_subject = pegboard::pubsub_subjects::TunnelRunnerReceiverSubject::new( - self.namespace_id, - &self.runner_name, - &self.runner_key, - ) - .to_string(); + let tunnel_subject = + pegboard::pubsub_subjects::TunnelRunnerReceiverSubject::new(self.runner_id).to_string(); // Start listening for request responses let (request_id, mut msg_rx) = self @@ -246,12 +231,8 @@ impl PegboardGateway { } // Build subject to publish to - let tunnel_subject = pegboard::pubsub_subjects::TunnelRunnerReceiverSubject::new( - self.namespace_id, - &self.runner_name, - &self.runner_key, - ) - .to_string(); + let tunnel_subject = + pegboard::pubsub_subjects::TunnelRunnerReceiverSubject::new(self.runner_id).to_string(); // Start listening for WebSocket messages let (request_id, mut msg_rx) = self diff --git a/packages/core/pegboard-runner/Cargo.toml b/packages/core/pegboard-runner/Cargo.toml index fad664e3d8..bcca3f7e2c 100644 --- a/packages/core/pegboard-runner/Cargo.toml +++ b/packages/core/pegboard-runner/Cargo.toml @@ -25,10 +25,12 @@ 
rivet-runner-protocol.workspace = true rivet-runtime.workspace = true serde.workspace = true serde_json.workspace = true +tokio.workspace = true tokio-tungstenite.workspace = true tracing.workspace = true url.workspace = true versioned-data-util.workspace = true +universalpubsub.workspace = true pegboard.workspace = true pegboard-actor-kv.workspace = true diff --git a/packages/core/pegboard-runner/src/client_to_pubsub_task.rs b/packages/core/pegboard-runner/src/client_to_pubsub_task.rs new file mode 100644 index 0000000000..d1d1d97ef1 --- /dev/null +++ b/packages/core/pegboard-runner/src/client_to_pubsub_task.rs @@ -0,0 +1,377 @@ +use async_trait::async_trait; +use bytes::Bytes; +use futures_util::{ + SinkExt, StreamExt, + stream::{SplitSink, SplitStream}, +}; +use gas::prelude::Id; +use gas::prelude::*; +use http_body_util::Full; +use hyper::upgrade::Upgraded; +use hyper::{Response, StatusCode}; +use hyper_tungstenite::tungstenite::Message as WsMessage; +use hyper_tungstenite::{HyperWebsocket, tungstenite::Message}; +use hyper_util::rt::TokioIo; +use pegboard::ops::runner::update_alloc_idx::{Action, RunnerEligibility}; +use pegboard_actor_kv as kv; +use rivet_error::*; +use rivet_guard_core::{ + custom_serve::CustomServeTrait, proxy_service::ResponseBody, request_context::RequestContext, +}; +use rivet_runner_protocol::{self as protocol, PROTOCOL_VERSION, versioned}; +use serde_json::json; +use std::{ + collections::HashMap, + sync::{ + Arc, + atomic::{AtomicU32, Ordering}, + }, + time::Duration, +}; +use tokio::sync::{Mutex, RwLock}; +use tokio_tungstenite::{ + WebSocketStream, + tungstenite::protocol::frame::{CloseFrame, coding::CloseCode}, +}; +use universalpubsub::{NextOutput, PublishOpts}; +use versioned_data_util::OwnedVersionedData as _; + +use crate::{ + conn::Conn, + utils::{self, WebSocketReceiver}, +}; + +pub async fn task(ctx: StandaloneCtx, conn: Arc, ws_rx: WebSocketReceiver) { + match task_inner(ctx, conn, ws_rx).await { + Ok(_) => {} + Err(err) => { + tracing::error!(?err, "client to pubsub errored"); + } + } +} + +async fn task_inner( + ctx: StandaloneCtx, + conn: Arc, + mut ws_rx: WebSocketReceiver, +) -> Result<()> { + tracing::info!("starting WebSocket to pubsub forwarding task"); + while let Some(msg) = ws_rx.next().await { + match msg { + Result::Ok(WsMessage::Binary(data)) => { + tracing::info!( + data_len = data.len(), + "received binary message from WebSocket" + ); + + // Parse message + let msg = + match versioned::ToServer::deserialize_version(&data, conn.protocol_version) + .and_then(|x| x.into_latest()) + { + Result::Ok(x) => x, + Err(err) => { + tracing::error!(?err, "failed to deserialize message"); + continue; + } + }; + + handle_message(&ctx, &conn, msg).await?; + } + Result::Ok(WsMessage::Close(_)) => { + tracing::info!(?conn.runner_id, "WebSocket closed"); + break; + } + Result::Ok(_) => { + // Ignore other message types + } + Err(e) => { + tracing::error!(?e, "WebSocket error"); + break; + } + } + } + tracing::info!("WebSocket to pubsub forwarding task ended"); + + Ok(()) +} + +async fn handle_message( + ctx: &StandaloneCtx, + conn: &Arc, + msg: protocol::ToServer, +) -> Result<()> { + match msg { + protocol::ToServer::ToServerPing(ping) => { + let rtt = util::timestamp::now().saturating_sub(ping.ts).try_into()?; + + conn.last_rtt.store(rtt, Ordering::Relaxed); + } + // Process KV request + protocol::ToServer::ToServerKvRequest(req) => { + let actor_id = match Id::parse(&req.actor_id) { + Ok(actor_id) => actor_id, + Err(err) => { + let res_msg = 
versioned::ToClient::latest( + protocol::ToClient::ToClientKvResponse(protocol::ToClientKvResponse { + request_id: req.request_id, + data: protocol::KvResponseData::KvErrorResponse( + protocol::KvErrorResponse { + message: err.to_string(), + }, + ), + }), + ); + + let res_msg_serialized = res_msg.serialize(conn.protocol_version)?; + conn.ws_tx + .lock() + .await + .send(Message::Binary(res_msg_serialized.into())) + .await?; + + return Ok(()); + } + }; + + let actors_res = ctx + .op(pegboard::ops::actor::get_runner::Input { + actor_ids: vec![actor_id], + }) + .await?; + let actor_belongs = actors_res + .actors + .first() + .map(|x| x.runner_id == conn.runner_id) + .unwrap_or_default(); + + // Verify actor belongs to this runner + if !actor_belongs { + let res_msg = versioned::ToClient::latest(protocol::ToClient::ToClientKvResponse( + protocol::ToClientKvResponse { + request_id: req.request_id, + data: protocol::KvResponseData::KvErrorResponse( + protocol::KvErrorResponse { + message: "given actor does not belong to runner".to_string(), + }, + ), + }, + )); + + let res_msg_serialized = res_msg.serialize(conn.protocol_version)?; + conn.ws_tx + .lock() + .await + .send(Message::Binary(res_msg_serialized.into())) + .await?; + + return Ok(()); + } + + // TODO: Add queue and bg thread for processing kv ops + // Run kv operation + match req.data { + protocol::KvRequestData::KvGetRequest(body) => { + let res = kv::get(&*ctx.udb()?, actor_id, body.keys).await; + + let res_msg = versioned::ToClient::latest( + protocol::ToClient::ToClientKvResponse(protocol::ToClientKvResponse { + request_id: req.request_id, + data: match res { + Ok((keys, values, metadata)) => { + protocol::KvResponseData::KvGetResponse( + protocol::KvGetResponse { + keys, + values, + metadata, + }, + ) + } + Err(err) => protocol::KvResponseData::KvErrorResponse( + protocol::KvErrorResponse { + // TODO: Don't return actual error? + message: err.to_string(), + }, + ), + }, + }), + ); + + let res_msg_serialized = res_msg.serialize(conn.protocol_version)?; + conn.ws_tx + .lock() + .await + .send(Message::Binary(res_msg_serialized.into())) + .await?; + } + protocol::KvRequestData::KvListRequest(body) => { + let res = kv::list( + &*ctx.udb()?, + actor_id, + body.query, + body.reverse.unwrap_or_default(), + body.limit.map(TryInto::try_into).transpose()?, + ) + .await; + + let res_msg = versioned::ToClient::latest( + protocol::ToClient::ToClientKvResponse(protocol::ToClientKvResponse { + request_id: req.request_id, + data: match res { + Ok((keys, values, metadata)) => { + protocol::KvResponseData::KvListResponse( + protocol::KvListResponse { + keys, + values, + metadata, + }, + ) + } + Err(err) => protocol::KvResponseData::KvErrorResponse( + protocol::KvErrorResponse { + // TODO: Don't return actual error? + message: err.to_string(), + }, + ), + }, + }), + ); + + let res_msg_serialized = res_msg.serialize(conn.protocol_version)?; + conn.ws_tx + .lock() + .await + .send(Message::Binary(res_msg_serialized.into())) + .await?; + } + protocol::KvRequestData::KvPutRequest(body) => { + let res = kv::put(&*ctx.udb()?, actor_id, body.keys, body.values).await; + + let res_msg = versioned::ToClient::latest( + protocol::ToClient::ToClientKvResponse(protocol::ToClientKvResponse { + request_id: req.request_id, + data: match res { + Ok(()) => protocol::KvResponseData::KvPutResponse, + Err(err) => { + protocol::KvResponseData::KvErrorResponse( + protocol::KvErrorResponse { + // TODO: Don't return actual error? 
+ message: err.to_string(), + }, + ) + } + }, + }), + ); + + let res_msg_serialized = res_msg.serialize(conn.protocol_version)?; + conn.ws_tx + .lock() + .await + .send(Message::Binary(res_msg_serialized.into())) + .await?; + } + protocol::KvRequestData::KvDeleteRequest(body) => { + let res = kv::delete(&*ctx.udb()?, actor_id, body.keys).await; + + let res_msg = versioned::ToClient::latest( + protocol::ToClient::ToClientKvResponse(protocol::ToClientKvResponse { + request_id: req.request_id, + data: match res { + Ok(()) => protocol::KvResponseData::KvDeleteResponse, + Err(err) => protocol::KvResponseData::KvErrorResponse( + protocol::KvErrorResponse { + // TODO: Don't return actual error? + message: err.to_string(), + }, + ), + }, + }), + ); + + let res_msg_serialized = res_msg.serialize(conn.protocol_version)?; + conn.ws_tx + .lock() + .await + .send(Message::Binary(res_msg_serialized.into())) + .await?; + } + protocol::KvRequestData::KvDropRequest => { + let res = kv::delete_all(&*ctx.udb()?, actor_id).await; + + let res_msg = versioned::ToClient::latest( + protocol::ToClient::ToClientKvResponse(protocol::ToClientKvResponse { + request_id: req.request_id, + data: match res { + Ok(()) => protocol::KvResponseData::KvDropResponse, + Err(err) => protocol::KvResponseData::KvErrorResponse( + protocol::KvErrorResponse { + // TODO: Don't return actual error? + message: err.to_string(), + }, + ), + }, + }), + ); + + let res_msg_serialized = res_msg.serialize(conn.protocol_version)?; + conn.ws_tx + .lock() + .await + .send(Message::Binary(res_msg_serialized.into())) + .await?; + } + } + } + protocol::ToServer::ToServerTunnelMessage(tunnel_msg) => { + handle_tunnel_message(&ctx, &conn, tunnel_msg).await?; + } + // Forward to runner wf + protocol::ToServer::ToServerInit(_) + | protocol::ToServer::ToServerEvents(_) + | protocol::ToServer::ToServerAckCommands(_) + | protocol::ToServer::ToServerStopping => { + ctx.signal(pegboard::workflows::runner::Forward { + inner: protocol::ToServer::try_from(msg)?, + }) + .to_workflow_id(conn.workflow_id) + .send() + .await?; + } + } + + Ok(()) +} + +async fn handle_tunnel_message( + ctx: &StandaloneCtx, + conn: &Arc, + msg: protocol::ToServerTunnelMessage, +) -> Result<()> { + // Determine reply to subject + let request_id = msg.request_id; + let gateway_reply_to = { + let active_requests = conn.tunnel_active_requests.lock().await; + if let Some(req) = active_requests.get(&request_id) { + req.gateway_reply_to.clone() + } else { + tracing::warn!("no active request for tunnel message, may have timed out"); + return Ok(()); + } + }; + + // Remove active request entries when terminal + if utils::is_to_server_tunnel_message_kind_request_close(&msg.message_kind) { + let mut active_requests = conn.tunnel_active_requests.lock().await; + active_requests.remove(&request_id); + } + + // Publish message to UPS + let msg_serialized = versioned::ToGateway::latest(protocol::ToGateway { message: msg }) + .serialize_with_embedded_version(PROTOCOL_VERSION)?; + ctx.ups()? 
+ .publish(&gateway_reply_to, &msg_serialized, PublishOpts::one()) + .await?; + + Ok(()) +} diff --git a/packages/core/pegboard-runner/src/conn.rs b/packages/core/pegboard-runner/src/conn.rs new file mode 100644 index 0000000000..2497f4391d --- /dev/null +++ b/packages/core/pegboard-runner/src/conn.rs @@ -0,0 +1,206 @@ +use async_trait::async_trait; +use bytes::Bytes; +use futures_util::{ + SinkExt, StreamExt, + stream::{SplitSink, SplitStream}, +}; +use gas::prelude::Id; +use gas::prelude::*; +use http_body_util::Full; +use hyper::upgrade::Upgraded; +use hyper::{Response, StatusCode}; +use hyper_tungstenite::{HyperWebsocket, tungstenite::Message}; +use hyper_util::rt::TokioIo; +use pegboard::ops::runner::update_alloc_idx::{Action, RunnerEligibility}; +use pegboard_actor_kv as kv; +use rivet_error::*; +use rivet_guard_core::{ + custom_serve::CustomServeTrait, proxy_service::ResponseBody, request_context::RequestContext, +}; +use rivet_runner_protocol as protocol; +use rivet_runner_protocol::*; +use serde_json::json; +use std::{ + collections::HashMap, + sync::{ + Arc, + atomic::{AtomicU32, Ordering}, + }, + time::Duration, +}; +use tokio::sync::{Mutex, RwLock}; +use tokio_tungstenite::{ + WebSocketStream, + tungstenite::protocol::frame::{CloseFrame, coding::CloseCode}, +}; +use universalpubsub::NextOutput; + +use crate::utils::WebSocketSender; + +pub struct TunnelActiveRequest { + /// Subject to send replies to. + pub gateway_reply_to: String, +} + +pub struct Conn { + pub runner_id: Id, + pub workflow_id: Id, + pub protocol_version: u16, + pub ws_tx: Mutex, + + // tx: Arc< + // Mutex< + // Box< + // dyn futures_util::Sink + // + Send + // + Unpin, + // >, + // >, + // >, + pub last_rtt: AtomicU32, + + /// Active HTTP & WebSocket requests. They are separate but use the same mechanism to + /// maintain state. + pub tunnel_active_requests: Mutex>, +} + +impl Conn { + pub fn new() -> Self { + todo!() + } +} + +// #[tracing::instrument(skip_all)] +// async fn build_connection( +// ctx: &StandaloneCtx, +// tx: &mut Option>, +// rx: &mut futures_util::stream::SplitStream, +// UrlData { +// protocol_version, +// namespace, +// runner_key, +// }: UrlData, +// ) -> Result<(Id, Arc)> { +// let namespace = ctx +// .op(namespace::ops::resolve_for_name_global::Input { name: namespace }) +// .await? +// .ok_or_else(|| namespace::errors::Namespace::NotFound.build())?; +// +// tracing::debug!("new runner connection"); +// +// // Receive init packet +// let (runner_id, workflow_id) = if let Some(msg) = +// tokio::time::timeout(Duration::from_secs(5), rx.next()) +// .await +// .map_err(|_| WsError::TimedOutWaitingForInit.build())? +// { +// let buf = match msg? { +// Message::Binary(buf) => buf, +// Message::Close(_) => return Err(WsError::ConnectionClosed.build()), +// msg => { +// tracing::debug!(?msg, "invalid initial message"); +// return Err(WsError::InvalidInitialPacket("must be a binary blob").build()); +// } +// }; +// +// let packet = versioned::ToServer::deserialize(&buf, protocol_version) +// .map_err(|err| WsError::InvalidPacket(err.to_string()).build())?; +// +// let (runner_id, workflow_id) = +// if let protocol::ToServer::ToServerInit(protocol::ToServerInit { +// name, +// version, +// total_slots, +// .. 
+// }) = &packet +// { +// // Look up existing runner by key +// let existing_runner = ctx +// .op(pegboard::ops::runner::get_by_key::Input { +// namespace_id: namespace.namespace_id, +// name: name.clone(), +// key: runner_key.clone(), +// }) +// .await?; +// +// let runner_id = if let Some(runner) = existing_runner.runner { +// // IMPORTANT: Before we spawn/get the workflow, we try to update the runner's last ping ts. +// // This ensures if the workflow is currently checking for expiry that it will not expire +// // (because we are about to send signals to it) and if it is already expired (but not +// // completed) we can choose a new runner id. +// let update_ping_res = ctx +// .op(pegboard::ops::runner::update_alloc_idx::Input { +// runners: vec![pegboard::ops::runner::update_alloc_idx::Runner { +// runner_id: runner.runner_id, +// action: Action::UpdatePing { rtt: 0 }, +// }], +// }) +// .await?; +// +// if update_ping_res +// .notifications +// .into_iter() +// .next() +// .map(|notif| matches!(notif.eligibility, RunnerEligibility::Expired)) +// .unwrap_or_default() +// { +// // Runner expired, create a new one +// Id::new_v1(ctx.config().dc_label()) +// } else { +// // Use existing runner +// runner.runner_id +// } +// } else { +// // No existing runner for this key, create a new one +// Id::new_v1(ctx.config().dc_label()) +// }; +// +// // Spawn a new runner workflow if one doesn't already exist +// let workflow_id = ctx +// .workflow(pegboard::workflows::runner::Input { +// runner_id, +// namespace_id: namespace.namespace_id, +// name: name.clone(), +// key: runner_key.clone(), +// version: version.clone(), +// total_slots: *total_slots, +// }) +// .tag("runner_id", runner_id) +// .unique() +// .dispatch() +// .await?; +// +// (runner_id, workflow_id) +// } else { +// tracing::debug!(?packet, "invalid initial packet"); +// return Err(WsError::InvalidInitialPacket("must be `ToServer::Init`").build()); +// }; +// +// // Forward to runner wf +// ctx.signal(pegboard::workflows::runner::Forward { inner: packet }) +// .to_workflow_id(workflow_id) +// .send() +// .await?; +// +// (runner_id, workflow_id) +// } else { +// return Err(WsError::ConnectionClosed.build()); +// }; +// +// let tx = tx.take().context("should exist")?; +// +// Ok(( +// runner_id, +// Arc::new(Connection { +// workflow_id, +// protocol_version, +// tx: Arc::new(Mutex::new(Box::new(tx) +// as Box< +// dyn futures_util::Sink +// + Send +// + Unpin, +// >)), +// last_rtt: AtomicU32::new(0), +// }), +// )) +// } diff --git a/packages/core/pegboard-runner/src/lib.rs b/packages/core/pegboard-runner/src/lib.rs index 6d8c37bc21..58bf90114d 100644 --- a/packages/core/pegboard-runner/src/lib.rs +++ b/packages/core/pegboard-runner/src/lib.rs @@ -17,6 +17,7 @@ use rivet_error::*; use rivet_guard_core::{ custom_serve::CustomServeTrait, proxy_service::ResponseBody, request_context::RequestContext, }; +use rivet_runner_protocol as protocol; use rivet_runner_protocol::*; use serde_json::json; use std::{ @@ -32,8 +33,15 @@ use tokio_tungstenite::{ WebSocketStream, tungstenite::protocol::frame::{CloseFrame, coding::CloseCode}, }; -type HyperWebSocketStream = WebSocketStream>; -use versioned_data_util::OwnedVersionedData; +use universalpubsub::NextOutput; + +use crate::conn::Conn; + +mod client_to_pubsub_task; +mod conn; +mod ping_task; +mod pubsub_to_client_task; +mod utils; const UPDATE_PING_INTERVAL: Duration = Duration::from_secs(3); @@ -73,48 +81,13 @@ enum WsError { InvalidUrl(String), } -struct Connection { - workflow_id: Id, 
- protocol_version: u16, - tx: Arc< - Mutex< - Box< - dyn futures_util::Sink - + Send - + Unpin, - >, - >, - >, - last_rtt: AtomicU32, -} - -type Connections = HashMap>; - pub struct PegboardRunnerWsCustomServe { ctx: StandaloneCtx, - conns: Arc>, } impl PegboardRunnerWsCustomServe { pub fn new(ctx: StandaloneCtx) -> Self { - let conns = Arc::new(RwLock::new(HashMap::new())); - let service = Self { - ctx: ctx.clone(), - conns: conns.clone(), - }; - - // Start background threads - let msg_ctx = ctx.clone(); - let msg_conns = conns.clone(); - tokio::spawn(async move { - msg_thread(&msg_ctx, msg_conns).await; - }); - - let ping_ctx = ctx.clone(); - let ping_conns = conns.clone(); - tokio::spawn(async move { - update_ping_thread(&ping_ctx, ping_conns).await; - }); + let service = Self { ctx: ctx.clone() }; service } @@ -145,14 +118,27 @@ impl CustomServeTrait for PegboardRunnerWsCustomServe { _headers: &hyper::HeaderMap, path: &str, _request_context: &mut RequestContext, - ) -> std::result::Result<(), (HyperWebsocket, anyhow::Error)> { + ) -> Result<(), (HyperWebsocket, anyhow::Error)> { + // TODO: Spawn ping thread + // TODO: Spawn message thread + // TODO: Create conn + + // Get UPS + let ups = match self.ctx.ups() { + Ok(x) => x, + Err(err) => { + tracing::warn!(?err, "could not get ups"); + return Err((client_ws, err)); + } + }; + // Parse URL to extract parameters - let url = match url::Url::parse(&format!("ws://placeholder{path}")) { + let url = match url::Url::parse(&format!("ws://placeholder/{path}")) { Result::Ok(u) => u, Result::Err(e) => return Err((client_ws, e.into())), }; - let url_data = match parse_url_from_url(url) { + let url_data = match utils::parse_url_from_url(url) { Result::Ok(x) => x, Result::Err(err) => { tracing::warn!(?err, "could not parse runner connection url"); @@ -160,693 +146,83 @@ impl CustomServeTrait for PegboardRunnerWsCustomServe { } }; + tracing::info!(?path, "tunnel ws connection established"); + + // TODO: Generate runner ID + let runner_id: Id = todo!(); + + // Subscribe to pubsub topic for this runner before accepting the client websocket so + // that failures can be retried by the proxy. 
+ let topic = pegboard::pubsub_subjects::RunnerReceiverSubject::new(runner_id).to_string(); + tracing::info!(%topic, "subscribing to runner receiver topic"); + let mut sub = match ups.subscribe(&topic).await { + Result::Ok(s) => s, + Err(e) => return Err((client_ws, e.into())), + }; + // Accept WS let ws_stream = match client_ws.await { Result::Ok(ws) => ws, Err(e) => { // Handshake already in progress; cannot retry safely here tracing::error!(error=?e, "client websocket await failed"); - return std::result::Result::<(), (HyperWebsocket, anyhow::Error)>::Ok(()); - } - }; - - self.handle_connection(ws_stream, url_data).await; - - std::result::Result::<(), (HyperWebsocket, anyhow::Error)>::Ok(()) - } -} - -impl PegboardRunnerWsCustomServe { - #[tracing::instrument(skip_all)] - async fn handle_connection(&self, ws_stream: HyperWebSocketStream, url_data: UrlData) { - let ctx = self.ctx.clone(); - let conns = self.conns.clone(); - - tokio::spawn(async move { - let (tx, mut rx) = ws_stream.split(); - let mut tx = Some(tx); - - let (runner_id, conn) = match build_connection(&ctx, &mut tx, &mut rx, url_data).await { - Ok(res) => res, - Err(err) => { - tracing::warn!(?err, "failed to build connection"); - - if let Some(mut tx) = tx { - let close_frame = err_to_close_frame(err); - - if let Err(err) = tx.send(Message::Close(Some(close_frame))).await { - tracing::error!(?err, "failed closing socket"); - } - } - - return; - } - }; - - // Store connection - { - let mut conns = conns.write().await; - if let Some(old_conn) = conns.insert(runner_id, conn.clone()) { - tracing::warn!( - ?runner_id, - "runner already connected, closing old connection" - ); - - let close_frame = err_to_close_frame(WsError::NewRunnerConnected.build()); - let mut tx = old_conn.tx.lock().await; - - if let Err(err) = tx.send(Message::Close(Some(close_frame))).await { - tracing::error!(?runner_id, ?err, "failed closing old connection"); - } - } - } - - let err = if let Err(err) = handle_messages(&ctx, &mut rx, runner_id, &conn).await { - tracing::warn!(?runner_id, ?err, "failed processing runner messages"); - - err - } else { - tracing::info!(?runner_id, "runner connection closed"); - - WsError::ConnectionClosed.build() - }; - - // Clean up - { - conns.write().await.remove(&runner_id); - } - - // Make runner immediately ineligible when it disconnects - if let Err(err) = ctx - .op(pegboard::ops::runner::update_alloc_idx::Input { - runners: vec![pegboard::ops::runner::update_alloc_idx::Runner { - runner_id, - action: Action::ClearIdx, - }], - }) - .await - { - tracing::error!(?runner_id, ?err, "failed evicting runner from alloc idx"); - } - - let close_frame = err_to_close_frame(err); - let mut tx = conn.tx.lock().await; - if let Err(err) = tx.send(Message::Close(Some(close_frame))).await { - tracing::error!(?runner_id, ?err, "failed closing socket"); - } - }); - } -} - -#[tracing::instrument(skip_all)] -async fn build_connection( - ctx: &StandaloneCtx, - tx: &mut Option>, - rx: &mut futures_util::stream::SplitStream, - UrlData { - protocol_version, - namespace, - runner_key, - }: UrlData, -) -> Result<(Id, Arc)> { - let namespace = ctx - .op(namespace::ops::resolve_for_name_global::Input { name: namespace }) - .await? - .ok_or_else(|| namespace::errors::Namespace::NotFound.build())?; - - tracing::debug!("new runner connection"); - - // Receive init packet - let (runner_id, workflow_id) = if let Some(msg) = - tokio::time::timeout(Duration::from_secs(5), rx.next()) - .await - .map_err(|_| WsError::TimedOutWaitingForInit.build())? 
- { - let buf = match msg? { - Message::Binary(buf) => buf, - Message::Close(_) => return Err(WsError::ConnectionClosed.build()), - msg => { - tracing::debug!(?msg, "invalid initial message"); - return Err(WsError::InvalidInitialPacket("must be a binary blob").build()); + return Result::<(), (HyperWebsocket, anyhow::Error)>::Ok(()); } }; - - let packet = versioned::ToServer::deserialize(&buf, protocol_version) - .map_err(|err| WsError::InvalidPacket(err.to_string()).build())? - .try_into() - .map_err(|err: anyhow::Error| WsError::InvalidPacket(err.to_string()).build())?; - - let (runner_id, workflow_id) = if let protocol::ToServer::Init { - name, - version, - total_slots, - .. - } = &packet - { - // Look up existing runner by key - let existing_runner = ctx - .op(pegboard::ops::runner::get_by_key::Input { - namespace_id: namespace.namespace_id, - name: name.clone(), - key: runner_key.clone(), - }) - .await?; - - let runner_id = if let Some(runner) = existing_runner.runner { - // IMPORTANT: Before we spawn/get the workflow, we try to update the runner's last ping ts. - // This ensures if the workflow is currently checking for expiry that it will not expire - // (because we are about to send signals to it) and if it is already expired (but not - // completed) we can choose a new runner id. - let update_ping_res = ctx - .op(pegboard::ops::runner::update_alloc_idx::Input { - runners: vec![pegboard::ops::runner::update_alloc_idx::Runner { - runner_id: runner.runner_id, - action: Action::UpdatePing { rtt: 0 }, - }], - }) - .await?; - - if update_ping_res - .notifications - .into_iter() - .next() - .map(|notif| matches!(notif.eligibility, RunnerEligibility::Expired)) - .unwrap_or_default() - { - // Runner expired, create a new one - Id::new_v1(ctx.config().dc_label()) - } else { - // Use existing runner - runner.runner_id - } - } else { - // No existing runner for this key, create a new one - Id::new_v1(ctx.config().dc_label()) - }; - - // Spawn a new runner workflow if one doesn't already exist - let workflow_id = ctx - .workflow(pegboard::workflows::runner::Input { - runner_id, - namespace_id: namespace.namespace_id, - name: name.clone(), - key: runner_key.clone(), - version: version.clone(), - total_slots: *total_slots, - }) - .tag("runner_id", runner_id) - .unique() - .dispatch() - .await?; - - (runner_id, workflow_id) - } else { - tracing::debug!(?packet, "invalid initial packet"); - return Err(WsError::InvalidInitialPacket("must be `ToServer::Init`").build()); - }; - - // Forward to runner wf - ctx.signal(packet) - .to_workflow_id(workflow_id) - .send() - .await?; - - (runner_id, workflow_id) - } else { - return Err(WsError::ConnectionClosed.build()); - }; - - let tx = tx.take().context("should exist")?; - - Ok(( - runner_id, - Arc::new(Connection { - workflow_id, - protocol_version, - tx: Arc::new(Mutex::new(Box::new(tx) - as Box< - dyn futures_util::Sink - + Send - + Unpin, - >)), - last_rtt: AtomicU32::new(0), - }), - )) -} - -async fn handle_messages( - ctx: &StandaloneCtx, - rx: &mut futures_util::stream::SplitStream, - runner_id: Id, - conn: &Connection, -) -> Result<()> { - // Receive messages from socket - while let Some(msg) = rx.next().await { - let buf = match msg? 
{ - Message::Binary(buf) => buf, - Message::Ping(_) => continue, - Message::Close(_) => bail!("socket closed {}", runner_id), - msg => { - tracing::warn!(?runner_id, ?msg, "unexpected message"); - continue; - } - }; - - let packet = versioned::ToServer::deserialize(&buf, conn.protocol_version)?; - - match packet { - ToServer::ToServerPing(ping) => { - let rtt = util::timestamp::now().saturating_sub(ping.ts).try_into()?; - - conn.last_rtt.store(rtt, Ordering::Relaxed); - } - // Process KV request - ToServer::ToServerKvRequest(req) => { - let actor_id = match Id::parse(&req.actor_id) { - Ok(actor_id) => actor_id, - Err(err) => { - let packet = versioned::ToClient::latest(ToClient::ToClientKvResponse( - ToClientKvResponse { - request_id: req.request_id, - data: KvResponseData::KvErrorResponse(KvErrorResponse { - message: err.to_string(), - }), - }, - )); - - let buf = packet.serialize(conn.protocol_version)?; - conn.tx - .lock() - .await - .send(Message::Binary(buf.into())) - .await?; - - continue; - } - }; - - let actors_res = ctx - .op(pegboard::ops::actor::get_runner::Input { - actor_ids: vec![actor_id], - }) - .await?; - let actor_belongs = actors_res - .actors - .first() - .map(|x| x.runner_id == runner_id) - .unwrap_or_default(); - - // Verify actor belongs to this runner - if !actor_belongs { - let packet = versioned::ToClient::latest(ToClient::ToClientKvResponse( - ToClientKvResponse { - request_id: req.request_id, - data: KvResponseData::KvErrorResponse(KvErrorResponse { - message: "given actor does not belong to runner".to_string(), - }), - }, - )); - - let buf = packet.serialize(conn.protocol_version)?; - conn.tx - .lock() - .await - .send(Message::Binary(buf.into())) - .await?; - - continue; - } - - // TODO: Add queue and bg thread for processing kv ops - // Run kv operation - match req.data { - KvRequestData::KvGetRequest(body) => { - let res = kv::get(&*ctx.udb()?, actor_id, body.keys).await; - - let packet = versioned::ToClient::latest(ToClient::ToClientKvResponse( - ToClientKvResponse { - request_id: req.request_id, - data: match res { - Ok((keys, values, metadata)) => { - KvResponseData::KvGetResponse(KvGetResponse { - keys, - values, - metadata, - }) - } - Err(err) => KvResponseData::KvErrorResponse(KvErrorResponse { - // TODO: Don't return actual error? - message: err.to_string(), - }), - }, - }, - )); - - let buf = packet.serialize(conn.protocol_version)?; - conn.tx - .lock() - .await - .send(Message::Binary(buf.into())) - .await?; - } - KvRequestData::KvListRequest(body) => { - let res = kv::list( - &*ctx.udb()?, - actor_id, - body.query, - body.reverse.unwrap_or_default(), - body.limit.map(TryInto::try_into).transpose()?, - ) - .await; - - let packet = versioned::ToClient::latest(ToClient::ToClientKvResponse( - ToClientKvResponse { - request_id: req.request_id, - data: match res { - Ok((keys, values, metadata)) => { - KvResponseData::KvListResponse(KvListResponse { - keys, - values, - metadata, - }) - } - Err(err) => KvResponseData::KvErrorResponse(KvErrorResponse { - // TODO: Don't return actual error? 
- message: err.to_string(), - }), - }, - }, - )); - - let buf = packet.serialize(conn.protocol_version)?; - conn.tx - .lock() - .await - .send(Message::Binary(buf.into())) - .await?; - } - KvRequestData::KvPutRequest(body) => { - let res = kv::put(&*ctx.udb()?, actor_id, body.keys, body.values).await; - - let packet = versioned::ToClient::latest(ToClient::ToClientKvResponse( - ToClientKvResponse { - request_id: req.request_id, - data: match res { - Ok(()) => KvResponseData::KvPutResponse, - Err(err) => KvResponseData::KvErrorResponse(KvErrorResponse { - // TODO: Don't return actual error? - message: err.to_string(), - }), - }, - }, - )); - - let buf = packet.serialize(conn.protocol_version)?; - conn.tx - .lock() - .await - .send(Message::Binary(buf.into())) - .await?; - } - KvRequestData::KvDeleteRequest(body) => { - let res = kv::delete(&*ctx.udb()?, actor_id, body.keys).await; - - let packet = versioned::ToClient::latest(ToClient::ToClientKvResponse( - ToClientKvResponse { - request_id: req.request_id, - data: match res { - Ok(()) => KvResponseData::KvDeleteResponse, - Err(err) => KvResponseData::KvErrorResponse(KvErrorResponse { - // TODO: Don't return actual error? - message: err.to_string(), - }), - }, - }, - )); - - let buf = packet.serialize(conn.protocol_version)?; - conn.tx - .lock() - .await - .send(Message::Binary(buf.into())) - .await?; - } - KvRequestData::KvDropRequest => { - let res = kv::delete_all(&*ctx.udb()?, actor_id).await; - - let packet = versioned::ToClient::latest(ToClient::ToClientKvResponse( - ToClientKvResponse { - request_id: req.request_id, - data: match res { - Ok(()) => KvResponseData::KvDropResponse, - Err(err) => KvResponseData::KvErrorResponse(KvErrorResponse { - // TODO: Don't return actual error? - message: err.to_string(), - }), - }, - }, - )); - - let buf = packet.serialize(conn.protocol_version)?; - conn.tx - .lock() - .await - .send(Message::Binary(buf.into())) - .await?; - } - } - } - // Forward to runner wf - _ => { - ctx.signal(protocol::ToServer::try_from(packet)?) - .to_workflow_id(conn.workflow_id) - .send() - .await?; - } - } - } - - bail!("stream closed {runner_id}"); -} - -#[tracing::instrument(skip_all)] -async fn update_ping_thread(ctx: &StandaloneCtx, conns: Arc>) { - loop { - match update_ping_thread_inner(ctx, conns.clone()).await { - Ok(_) => { - tracing::warn!("update ping thread thread exited early"); + let (ws_tx, ws_rx) = ws_stream.split(); + + // Create connection + let conn = Arc::new(Conn::new()); + + // Forward pubsub -> WebSocket + let pubsub_to_client = tokio::spawn(pubsub_to_client_task::task( + self.ctx.clone(), + conn.clone(), + sub, + )); + + // Forward WebSocket -> pubsub + let client_to_pubsub = tokio::spawn(client_to_pubsub_task::task( + self.ctx.clone(), + conn.clone(), + ws_rx, + )); + + // Wait for either task to complete + tokio::select! { + _ = pubsub_to_client => { + tracing::info!("pubsub to WebSocket task completed"); } - Err(err) => { - tracing::error!(?err, "update ping thread error"); + _ = client_to_pubsub => { + tracing::info!("WebSocket to pubsub task completed"); } } - tokio::time::sleep(std::time::Duration::from_secs(2)).await; - } -} - -/// Updates the ping of all runners requesting a ping update at once. 
-#[tracing::instrument(skip_all)] -async fn update_ping_thread_inner( - ctx: &StandaloneCtx, - conns: Arc>, -) -> Result<()> { - loop { - tokio::time::sleep(UPDATE_PING_INTERVAL).await; - - let runners = { - let mut conns = conns.write().await; - - // Select all runners that required a ping update - conns - .iter_mut() - .map(|(runner_id, conn)| { - ( - *runner_id, - conn.workflow_id, - conn.last_rtt.load(Ordering::Relaxed), - ) - }) - .collect::>() - }; - - if runners.is_empty() { - continue; - } - - let mut runners2 = Vec::new(); - - // TODO: Parallelize - // Filter out dead wfs - for (runner_id, workflow_id, rtt) in runners { - let Some(wf) = ctx - .workflow::(workflow_id) - .get() - .await? - else { - tracing::error!(?runner_id, "workflow does not exist"); - continue; - }; - - // Only update ping if the workflow is not dead - if wf.has_wake_condition { - runners2.push(pegboard::ops::runner::update_alloc_idx::Runner { + // Make runner immediately ineligible when it disconnects + if let Err(err) = self + .ctx + .op(pegboard::ops::runner::update_alloc_idx::Input { + runners: vec![pegboard::ops::runner::update_alloc_idx::Runner { runner_id, - action: Action::UpdatePing { rtt }, - }); - } - } - - if runners2.is_empty() { - continue; - } - - let res = ctx - .op(pegboard::ops::runner::update_alloc_idx::Input { runners: runners2 }) - .await?; - - for notif in res.notifications { - if let RunnerEligibility::ReEligible = notif.eligibility { - tracing::debug!(runner_id=?notif.runner_id, "runner has become eligible again"); - - ctx.signal(pegboard::workflows::runner::CheckQueue {}) - .to_workflow_id(notif.workflow_id) - .send() - .await?; - } - } - } -} - -#[tracing::instrument(skip_all)] -async fn msg_thread(ctx: &StandaloneCtx, conns: Arc>) { - loop { - match msg_thread_inner(ctx, conns.clone()).await { - Ok(_) => { - tracing::warn!("msg thread exited early"); - } - Err(err) => { - tracing::error!(?err, "msg thread error"); - } + action: Action::ClearIdx, + }], + }) + .await + { + tracing::error!(?runner_id, ?err, "failed evicting runner from alloc idx"); } - tokio::time::sleep(std::time::Duration::from_secs(2)).await; - } -} - -#[tracing::instrument(skip_all)] -async fn msg_thread_inner(ctx: &StandaloneCtx, conns: Arc>) -> Result<()> { - // Listen for commands from runner workflows - let mut sub = ctx - .subscribe::(&json!({})) - .await?; - let mut close_sub = ctx - .subscribe::(&json!({})) - .await?; - - loop { - tokio::select! 
{ - msg = sub.next() => { - let msg = msg?.into_body(); - - { - let conns = conns.read().await; - - // Send command to socket - if let Some(conn) = conns.get(&msg.runner_id) { - let buf = versioned::ToClient::serialize( - protocol::ToClient::from(msg.inner).try_into()?, - conn.protocol_version - )?; - conn.tx.lock().await.send(Message::Binary(buf.into())).await?; - } else { - tracing::debug!( - runner_id=?msg.runner_id, - "received command for runner that isn't connected, ignoring" - ); - } - } - } - msg = close_sub.next() => { - let msg = msg?; + // TODO: Handle errors + // // Close WS + // let close_frame = utils::err_to_close_frame(err); + // let mut tx = conn.ws_tx.lock().await; + // if let Err(err) = tx.send(Message::Close(Some(close_frame))).await { + // tracing::error!(?runner_id, ?err, "failed closing socket"); + // } - { - let conns = conns.read().await; + // Clean up + tracing::info!(?runner_id, "connection closed"); - // Close socket - if let Some(conn) = conns.get(&msg.runner_id) { - tracing::info!(runner_id = ?msg.runner_id, "received close ws event, closing socket"); - - let close_frame = err_to_close_frame(WsError::Eviction.build()); - conn.tx.lock().await.send(Message::Close(Some(close_frame))).await?; - } else { - tracing::debug!( - runner_id=?msg.runner_id, - "received close command for runner that isn't connected, ignoring" - ); - } - } - } - } + Result::<(), (HyperWebsocket, anyhow::Error)>::Ok(()) } } - -#[derive(Clone)] -struct UrlData { - protocol_version: u16, - namespace: String, - runner_key: String, -} - -fn parse_url_from_url(url: url::Url) -> Result { - // Read protocol version from query parameters (required) - let protocol_version = url - .query_pairs() - .find_map(|(n, v)| (n == "protocol_version").then_some(v)) - .context("missing `protocol_version` query parameter")? - .parse::() - .context("invalid `protocol_version` query parameter")?; - - // Read namespace from query parameters - let namespace = url - .query_pairs() - .find_map(|(n, v)| (n == "namespace").then_some(v)) - .context("missing `namespace` query parameter")? - .to_string(); - - // Read runner key from query parameters (required) - let runner_key = url - .query_pairs() - .find_map(|(n, v)| (n == "runner_key").then_some(v)) - .context("missing `runner_key` query parameter")? 
- .to_string(); - - Ok(UrlData { - protocol_version, - namespace, - runner_key, - }) -} - -fn err_to_close_frame(err: anyhow::Error) -> CloseFrame { - let rivet_err = err - .chain() - .find_map(|x| x.downcast_ref::()) - .cloned() - .unwrap_or_else(|| RivetError::from(&INTERNAL_ERROR)); - - let code = match (rivet_err.group(), rivet_err.code()) { - ("ws", "connection_closed") => CloseCode::Normal, - _ => CloseCode::Error, - }; - - // NOTE: reason cannot be more than 123 bytes as per the WS protocol - let reason = util::safe_slice( - &format!("{}.{}", rivet_err.group(), rivet_err.code()), - 0, - 123, - ) - .into(); - - CloseFrame { code, reason } -} diff --git a/packages/core/pegboard-runner/src/message_task_old.rs b/packages/core/pegboard-runner/src/message_task_old.rs new file mode 100644 index 0000000000..56de9a1280 --- /dev/null +++ b/packages/core/pegboard-runner/src/message_task_old.rs @@ -0,0 +1,375 @@ +use async_trait::async_trait; +use bytes::Bytes; +use futures_util::{ + stream::{SplitSink, SplitStream}, + SinkExt, StreamExt, +}; +use gas::prelude::Id; +use gas::prelude::*; +use http_body_util::Full; +use hyper::upgrade::Upgraded; +use hyper::{Response, StatusCode}; +use hyper_tungstenite::{tungstenite::Message, HyperWebsocket}; +use hyper_util::rt::TokioIo; +use pegboard::ops::runner::update_alloc_idx::{Action, RunnerEligibility}; +use pegboard_actor_kv as kv; +use rivet_error::*; +use rivet_guard_core::{ + custom_serve::CustomServeTrait, proxy_service::ResponseBody, request_context::RequestContext, +}; +use rivet_runner_protocol as protocol; +use rivet_runner_protocol::*; +use serde_json::json; +use std::{ + collections::HashMap, + sync::{ + atomic::{AtomicU32, Ordering}, + Arc, + }, + time::Duration, +}; +use tokio::sync::{Mutex, RwLock}; +use tokio_tungstenite::{ + tungstenite::protocol::frame::{coding::CloseCode, CloseFrame}, + WebSocketStream, +}; +use universalpubsub::NextOutput; + +#[tracing::instrument(skip_all)] +async fn msg_thread(ctx: &StandaloneCtx, conns: Arc>) { + loop { + match msg_thread_inner(ctx, conns.clone()).await { + Ok(_) => { + tracing::warn!("msg thread exited early"); + } + Err(err) => { + tracing::error!(?err, "msg thread error"); + } + } + + tokio::time::sleep(std::time::Duration::from_secs(2)).await; + } +} + +#[tracing::instrument(skip_all)] +async fn msg_thread_inner(ctx: &StandaloneCtx, conns: Arc>) -> Result<()> { + // Listen for commands from runner workflows + let runner_id: Id = todo!(); + let topic = pegboard::pubsub_subjects::RunnerReceiverSubject::new(runner_id).to_string(); + let mut sub = ctx.ups()?.subscribe(&topic).await?; + + loop { + tokio::select! { + msg = sub.next() => { + // Parse message + let msg = match msg? 
{ + NextOutput::Message(ups_msg) => { + tracing::info!( + payload_len = ups_msg.payload.len(), + "received message from pubsub, forwarding to WebSocket" + ); + + // Parse message + let msg = match versioned::ToClient::deserialize_with_embedded_version( + &ups_msg.payload, + ) { + Result::Ok(x) => x, + Err(err) => { + tracing::error!(?err, "failed to parse tunnel message"); + continue; + } + }; + + msg + } + NextOutput::Unsubscribed => { + tracing::info!("runner subscription unsubscribed"); + // TODO: Handle close like below + return Ok(()); + } + }; + + { + let conns = conns.read().await; + + // Send command to socket + if let Some(conn) = conns.get(&runner_id) { + let buf = versioned::ToClient::serialize( + versioned::ToClient::latest(msg), + conn.protocol_version + )?; + conn.tx.lock().await.send(Message::Binary(buf.into())).await?; + } else { + tracing::debug!( + ?runner_id, + "received command for runner that isn't connected, ignoring" + ); + } + } + } + // msg = close_sub.next() => { + // let msg = msg?; + // + // { + // let conns = conns.read().await; + // + // // Close socket + // if let Some(conn) = conns.get(&msg.runner_id) { + // tracing::info!(runner_id = ?msg.runner_id, "received close ws event, closing socket"); + // + // let close_frame = err_to_close_frame(WsError::Eviction.build()); + // conn.tx.lock().await.send(Message::Close(Some(close_frame))).await?; + // } else { + // tracing::debug!( + // runner_id=?msg.runner_id, + // "received close command for runner that isn't connected, ignoring" + // ); + // } + // } + // } + } + } +} + +async fn handle_messages( + ctx: &StandaloneCtx, + rx: &mut futures_util::stream::SplitStream, + runner_id: Id, + conn: &Connection, +) -> Result<()> { + // Receive messages from socket + while let Some(msg) = rx.next().await { + let buf = match msg? 
{ + Message::Binary(buf) => buf, + Message::Ping(_) => continue, + Message::Close(_) => bail!("socket closed {}", runner_id), + msg => { + tracing::warn!(?runner_id, ?msg, "unexpected message"); + continue; + } + }; + + let packet = versioned::ToServer::deserialize(&buf, conn.protocol_version)?; + + match packet { + ToServer::ToServerPing(ping) => { + let rtt = util::timestamp::now().saturating_sub(ping.ts).try_into()?; + + conn.last_rtt.store(rtt, Ordering::Relaxed); + } + // Process KV request + ToServer::ToServerKvRequest(req) => { + let actor_id = match Id::parse(&req.actor_id) { + Ok(actor_id) => actor_id, + Err(err) => { + let packet = versioned::ToClient::latest(ToClient::ToClientKvResponse( + ToClientKvResponse { + request_id: req.request_id, + data: KvResponseData::KvErrorResponse(KvErrorResponse { + message: err.to_string(), + }), + }, + )); + + let buf = packet.serialize(conn.protocol_version)?; + conn.tx + .lock() + .await + .send(Message::Binary(buf.into())) + .await?; + + continue; + } + }; + + let actors_res = ctx + .op(pegboard::ops::actor::get_runner::Input { + actor_ids: vec![actor_id], + }) + .await?; + let actor_belongs = actors_res + .actors + .first() + .map(|x| x.runner_id == runner_id) + .unwrap_or_default(); + + // Verify actor belongs to this runner + if !actor_belongs { + let packet = versioned::ToClient::latest(ToClient::ToClientKvResponse( + ToClientKvResponse { + request_id: req.request_id, + data: KvResponseData::KvErrorResponse(KvErrorResponse { + message: "given actor does not belong to runner".to_string(), + }), + }, + )); + + let buf = packet.serialize(conn.protocol_version)?; + conn.tx + .lock() + .await + .send(Message::Binary(buf.into())) + .await?; + + continue; + } + + // TODO: Add queue and bg thread for processing kv ops + // Run kv operation + match req.data { + KvRequestData::KvGetRequest(body) => { + let res = kv::get(&*ctx.udb()?, actor_id, body.keys).await; + + let packet = versioned::ToClient::latest(ToClient::ToClientKvResponse( + ToClientKvResponse { + request_id: req.request_id, + data: match res { + Ok((keys, values, metadata)) => { + KvResponseData::KvGetResponse(KvGetResponse { + keys, + values, + metadata, + }) + } + Err(err) => KvResponseData::KvErrorResponse(KvErrorResponse { + // TODO: Don't return actual error? + message: err.to_string(), + }), + }, + }, + )); + + let buf = packet.serialize(conn.protocol_version)?; + conn.tx + .lock() + .await + .send(Message::Binary(buf.into())) + .await?; + } + KvRequestData::KvListRequest(body) => { + let res = kv::list( + &*ctx.udb()?, + actor_id, + body.query, + body.reverse.unwrap_or_default(), + body.limit.map(TryInto::try_into).transpose()?, + ) + .await; + + let packet = versioned::ToClient::latest(ToClient::ToClientKvResponse( + ToClientKvResponse { + request_id: req.request_id, + data: match res { + Ok((keys, values, metadata)) => { + KvResponseData::KvListResponse(KvListResponse { + keys, + values, + metadata, + }) + } + Err(err) => KvResponseData::KvErrorResponse(KvErrorResponse { + // TODO: Don't return actual error? 
+ message: err.to_string(), + }), + }, + }, + )); + + let buf = packet.serialize(conn.protocol_version)?; + conn.tx + .lock() + .await + .send(Message::Binary(buf.into())) + .await?; + } + KvRequestData::KvPutRequest(body) => { + let res = kv::put(&*ctx.udb()?, actor_id, body.keys, body.values).await; + + let packet = versioned::ToClient::latest(ToClient::ToClientKvResponse( + ToClientKvResponse { + request_id: req.request_id, + data: match res { + Ok(()) => KvResponseData::KvPutResponse, + Err(err) => KvResponseData::KvErrorResponse(KvErrorResponse { + // TODO: Don't return actual error? + message: err.to_string(), + }), + }, + }, + )); + + let buf = packet.serialize(conn.protocol_version)?; + conn.tx + .lock() + .await + .send(Message::Binary(buf.into())) + .await?; + } + KvRequestData::KvDeleteRequest(body) => { + let res = kv::delete(&*ctx.udb()?, actor_id, body.keys).await; + + let packet = versioned::ToClient::latest(ToClient::ToClientKvResponse( + ToClientKvResponse { + request_id: req.request_id, + data: match res { + Ok(()) => KvResponseData::KvDeleteResponse, + Err(err) => KvResponseData::KvErrorResponse(KvErrorResponse { + // TODO: Don't return actual error? + message: err.to_string(), + }), + }, + }, + )); + + let buf = packet.serialize(conn.protocol_version)?; + conn.tx + .lock() + .await + .send(Message::Binary(buf.into())) + .await?; + } + KvRequestData::KvDropRequest => { + let res = kv::delete_all(&*ctx.udb()?, actor_id).await; + + let packet = versioned::ToClient::latest(ToClient::ToClientKvResponse( + ToClientKvResponse { + request_id: req.request_id, + data: match res { + Ok(()) => KvResponseData::KvDropResponse, + Err(err) => KvResponseData::KvErrorResponse(KvErrorResponse { + // TODO: Don't return actual error? + message: err.to_string(), + }), + }, + }, + )); + + let buf = packet.serialize(conn.protocol_version)?; + conn.tx + .lock() + .await + .send(Message::Binary(buf.into())) + .await?; + } + } + } + protocol::ToServer::ToServerTunnelMessage(tunnel_msg) => { + todo!() + } + // Forward to runner wf + protocol::ToServer::ToServerInit(_) + | protocol::ToServer::ToServerEvents(_) + | protocol::ToServer::ToServerAckCommands(_) + | protocol::ToServer::ToServerStopping => { + ctx.signal(pegboard::workflows::runner::Forward { + inner: protocol::ToServer::try_from(packet)?, + }) + .to_workflow_id(conn.workflow_id) + .send() + .await?; + } + } + } + + bail!("stream closed {runner_id}"); +} diff --git a/packages/core/pegboard-runner/src/ping_task.rs b/packages/core/pegboard-runner/src/ping_task.rs new file mode 100644 index 0000000000..601f6a1f69 --- /dev/null +++ b/packages/core/pegboard-runner/src/ping_task.rs @@ -0,0 +1,125 @@ +// use async_trait::async_trait; +// use bytes::Bytes; +// use futures_util::{ +// stream::{SplitSink, SplitStream}, +// SinkExt, StreamExt, +// }; +// use gas::prelude::Id; +// use gas::prelude::*; +// use http_body_util::Full; +// use hyper::upgrade::Upgraded; +// use hyper::{Response, StatusCode}; +// use hyper_tungstenite::{tungstenite::Message, HyperWebsocket}; +// use hyper_util::rt::TokioIo; +// use pegboard::ops::runner::update_alloc_idx::{Action, RunnerEligibility}; +// use pegboard_actor_kv as kv; +// use rivet_error::*; +// use rivet_guard_core::{ +// custom_serve::CustomServeTrait, proxy_service::ResponseBody, request_context::RequestContext, +// }; +// use rivet_runner_protocol as protocol; +// use rivet_runner_protocol::*; +// use serde_json::json; +// use std::{ +// collections::HashMap, +// sync::{ +// atomic::{AtomicU32, Ordering}, 
+// Arc, +// }, +// time::Duration, +// }; +// use tokio::sync::{Mutex, RwLock}; +// use tokio_tungstenite::{ +// tungstenite::protocol::frame::{coding::CloseCode, CloseFrame}, +// WebSocketStream, +// }; +// use universalpubsub::NextOutput; +// +// #[tracing::instrument(skip_all)] +// async fn update_ping_thread(ctx: &StandaloneCtx, conns: Arc>) { +// loop { +// match update_ping_thread_inner(ctx, conns.clone()).await { +// Ok(_) => { +// tracing::warn!("update ping thread thread exited early"); +// } +// Err(err) => { +// tracing::error!(?err, "update ping thread error"); +// } +// } +// +// tokio::time::sleep(std::time::Duration::from_secs(2)).await; +// } +// } +// +// /// Updates the ping of all runners requesting a ping update at once. +// #[tracing::instrument(skip_all)] +// async fn update_ping_thread_inner( +// ctx: &StandaloneCtx, +// conns: Arc>, +// ) -> Result<()> { +// loop { +// tokio::time::sleep(UPDATE_PING_INTERVAL).await; +// +// let runners = { +// let mut conns = conns.write().await; +// +// // Select all runners that required a ping update +// conns +// .iter_mut() +// .map(|(runner_id, conn)| { +// ( +// *runner_id, +// conn.workflow_id, +// conn.last_rtt.load(Ordering::Relaxed), +// ) +// }) +// .collect::>() +// }; +// +// if runners.is_empty() { +// continue; +// } +// +// let mut runners2 = Vec::new(); +// +// // TODO: Parallelize +// // Filter out dead wfs +// for (runner_id, workflow_id, rtt) in runners { +// let Some(wf) = ctx +// .workflow::(workflow_id) +// .get() +// .await? +// else { +// tracing::error!(?runner_id, "workflow does not exist"); +// continue; +// }; +// +// // Only update ping if the workflow is not dead +// if wf.has_wake_condition { +// runners2.push(pegboard::ops::runner::update_alloc_idx::Runner { +// runner_id, +// action: Action::UpdatePing { rtt }, +// }); +// } +// } +// +// if runners2.is_empty() { +// continue; +// } +// +// let res = ctx +// .op(pegboard::ops::runner::update_alloc_idx::Input { runners: runners2 }) +// .await?; +// +// for notif in res.notifications { +// if let RunnerEligibility::ReEligible = notif.eligibility { +// tracing::debug!(runner_id=?notif.runner_id, "runner has become eligible again"); +// +// ctx.signal(pegboard::workflows::runner::CheckQueue {}) +// .to_workflow_id(notif.workflow_id) +// .send() +// .await?; +// } +// } +// } +// } diff --git a/packages/core/pegboard-runner/src/pubsub_to_client_task.rs b/packages/core/pegboard-runner/src/pubsub_to_client_task.rs new file mode 100644 index 0000000000..0e59fb94b0 --- /dev/null +++ b/packages/core/pegboard-runner/src/pubsub_to_client_task.rs @@ -0,0 +1,90 @@ +use anyhow::Result; +use futures_util::SinkExt; +use gas::prelude::*; +use hyper::upgrade::Upgraded; +use hyper_tungstenite::tungstenite::Message as WsMessage; +use hyper_util::rt::TokioIo; +use rivet_runner_protocol::{self as protocol, versioned}; +use std::sync::Arc; +use tokio_tungstenite::WebSocketStream; +use universalpubsub::{NextOutput, Subscriber}; +use versioned_data_util::OwnedVersionedData as _; + +use crate::{ + conn::{Conn, TunnelActiveRequest}, + utils::{self, WebSocketSender}, +}; + +pub async fn task(ctx: StandaloneCtx, conn: Arc, sub: Subscriber) { + match task_inner(ctx, conn, sub).await { + Ok(_) => {} + Err(err) => { + tracing::error!(?err, "pubsub to client error"); + } + } +} + +async fn task_inner(ctx: StandaloneCtx, conn: Arc, mut sub: Subscriber) -> Result<()> { + while let Result::Ok(NextOutput::Message(ups_msg)) = sub.next().await { + tracing::info!( + payload_len = 
ups_msg.payload.len(), + "received message from pubsub, forwarding to WebSocket" + ); + + // Parse message + let mut msg = match versioned::ToClient::deserialize_with_embedded_version(&ups_msg.payload) + { + Result::Ok(x) => x, + Err(err) => { + tracing::error!(?err, "failed to parse tunnel message"); + continue; + } + }; + + // Handle tunnel messages + if let protocol::ToClient::ToClientTunnelMessage(tunnel_msg) = &mut msg { + handle_tunnel_message(&conn, tunnel_msg).await; + } + + // Forward raw message to WebSocket + let serialized_msg = + match versioned::ToClient::latest(msg).serialize_version(conn.protocol_version) { + Result::Ok(x) => x, + Err(err) => { + tracing::error!(?err, "failed to serialize tunnel message"); + continue; + } + }; + let ws_msg = WsMessage::Binary(serialized_msg.into()); + if let Err(e) = conn.ws_tx.lock().await.send(ws_msg).await { + tracing::error!(?e, "failed to send message to WebSocket"); + break; + } + } + + Ok(()) +} + +async fn handle_tunnel_message(conn: &Arc, msg: &mut protocol::ToClientTunnelMessage) { + // Save active request + // + // This will remove gateway_reply_to from the message since it does not need to be sent to the + // client + if let Some(reply_to) = msg.gateway_reply_to.take() { + tracing::debug!(?msg.request_id, ?reply_to, "creating active request"); + let mut active_requests = conn.tunnel_active_requests.lock().await; + active_requests.insert( + msg.request_id, + TunnelActiveRequest { + gateway_reply_to: reply_to, + }, + ); + } + + // If terminal, remove active request tracking + if utils::is_to_client_tunnel_message_kind_request_close(&msg.message_kind) { + tracing::debug!(?msg.request_id, "removing active conn from close message"); + let mut active_requests = conn.tunnel_active_requests.lock().await; + active_requests.remove(&msg.request_id); + } +} diff --git a/packages/core/pegboard-runner/src/utils.rs b/packages/core/pegboard-runner/src/utils.rs new file mode 100644 index 0000000000..d29c9d53b8 --- /dev/null +++ b/packages/core/pegboard-runner/src/utils.rs @@ -0,0 +1,128 @@ +use async_trait::async_trait; +use bytes::Bytes; +use futures_util::{ + SinkExt, StreamExt, + stream::{SplitSink, SplitStream}, +}; +use gas::prelude::Id; +use gas::prelude::*; +use http_body_util::Full; +use hyper::upgrade::Upgraded; +use hyper::{Response, StatusCode}; +use hyper_tungstenite::tungstenite::Message as WsMessage; +use hyper_tungstenite::{HyperWebsocket, tungstenite::Message}; +use hyper_util::rt::TokioIo; +use pegboard::ops::runner::update_alloc_idx::{Action, RunnerEligibility}; +use pegboard_actor_kv as kv; +use rivet_error::*; +use rivet_guard_core::{ + custom_serve::CustomServeTrait, proxy_service::ResponseBody, request_context::RequestContext, +}; +use rivet_runner_protocol as protocol; +use rivet_runner_protocol::*; +use serde_json::json; +use std::{ + collections::HashMap, + sync::{ + Arc, + atomic::{AtomicU32, Ordering}, + }, + time::Duration, +}; +use tokio::sync::{Mutex, RwLock}; +use tokio_tungstenite::{ + WebSocketStream, + tungstenite::protocol::frame::{CloseFrame, coding::CloseCode}, +}; +use universalpubsub::NextOutput; + +pub type WebSocketReceiver = futures_util::stream::SplitStream>>; + +pub type WebSocketSender = + futures_util::stream::SplitSink>, WsMessage>; + +#[derive(Clone)] +pub struct UrlData { + pub protocol_version: u16, + pub namespace: String, + pub runner_key: String, +} + +pub fn parse_url_from_url(url: url::Url) -> Result { + // Read protocol version from query parameters (required) + let protocol_version = url + 
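+ // The protocol version sent by the runner is stored on the connection and later used when serializing messages back to it.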
.query_pairs() + .find_map(|(n, v)| (n == "protocol_version").then_some(v)) + .context("missing `protocol_version` query parameter")? + .parse::() + .context("invalid `protocol_version` query parameter")?; + + // Read namespace from query parameters + let namespace = url + .query_pairs() + .find_map(|(n, v)| (n == "namespace").then_some(v)) + .context("missing `namespace` query parameter")? + .to_string(); + + // Read runner key from query parameters (required) + let runner_key = url + .query_pairs() + .find_map(|(n, v)| (n == "runner_key").then_some(v)) + .context("missing `runner_key` query parameter")? + .to_string(); + + Ok(UrlData { + protocol_version, + namespace, + runner_key, + }) +} + +pub fn err_to_close_frame(err: anyhow::Error) -> CloseFrame { + let rivet_err = err + .chain() + .find_map(|x| x.downcast_ref::()) + .cloned() + .unwrap_or_else(|| RivetError::from(&INTERNAL_ERROR)); + + let code = match (rivet_err.group(), rivet_err.code()) { + ("ws", "connection_closed") => CloseCode::Normal, + _ => CloseCode::Error, + }; + + // NOTE: reason cannot be more than 123 bytes as per the WS protocol + let reason = util::safe_slice( + &format!("{}.{}", rivet_err.group(), rivet_err.code()), + 0, + 123, + ) + .into(); + + CloseFrame { code, reason } +} + +/// Determines if a given message kind will terminate the request. +pub fn is_to_server_tunnel_message_kind_request_close( + kind: &protocol::ToServerTunnelMessageKind, +) -> bool { + match kind { + // HTTP terminal states + protocol::ToServerTunnelMessageKind::ToServerResponseStart(resp) => !resp.stream, + protocol::ToServerTunnelMessageKind::ToServerResponseChunk(chunk) => chunk.finish, + protocol::ToServerTunnelMessageKind::ToServerResponseAbort => true, + // WebSocket terminal states (either side closes) + protocol::ToServerTunnelMessageKind::ToServerWebSocketClose(_) => true, + _ => false, + } +} + +/// Determines if a given message kind will terminate the request. 
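+/// For client-bound messages, only a WebSocket close is treated as terminal.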
+pub fn is_to_client_tunnel_message_kind_request_close( + kind: &protocol::ToClientTunnelMessageKind, +) -> bool { + match kind { + // WebSocket terminal states (either side closes) + protocol::ToClientTunnelMessageKind::ToClientWebSocketClose(_) => true, + _ => false, + } +} diff --git a/packages/core/pegboard-serverless/src/lib.rs b/packages/core/pegboard-serverless/src/lib.rs index 5bbf64f303..824f610e11 100644 --- a/packages/core/pegboard-serverless/src/lib.rs +++ b/packages/core/pegboard-serverless/src/lib.rs @@ -12,7 +12,7 @@ use gas::prelude::*; use namespace::types::RunnerConfig; use pegboard::keys; use reqwest_eventsource as sse; -use rivet_runner_protocol::protocol; +use rivet_runner_protocol as protocol; use tokio::{sync::oneshot, task::JoinHandle, time::Duration}; use universaldb::options::StreamingMode; use universaldb::utils::IsolationLevel::*; diff --git a/packages/core/pegboard-tunnel/Cargo.toml b/packages/core/pegboard-tunnel/Cargo.toml index 22aa22e864..2a5436fbe5 100644 --- a/packages/core/pegboard-tunnel/Cargo.toml +++ b/packages/core/pegboard-tunnel/Cargo.toml @@ -22,7 +22,6 @@ rivet-error.workspace = true rivet-guard-core = { path = "../guard/core" } rivet-metrics.workspace = true rivet-pools.workspace = true -rivet-tunnel-protocol.workspace = true rivet-runtime.workspace = true rivet-util.workspace = true serde.workspace = true diff --git a/packages/services/pegboard/Cargo.toml b/packages/services/pegboard/Cargo.toml index c75e5ff91b..2d686fbf65 100644 --- a/packages/services/pegboard/Cargo.toml +++ b/packages/services/pegboard/Cargo.toml @@ -7,6 +7,7 @@ edition.workspace = true [dependencies] anyhow.workspace = true +base64.workspace = true epoxy.workspace = true gas.workspace = true lazy_static.workspace = true @@ -25,5 +26,6 @@ serde.workspace = true strum.workspace = true tracing.workspace = true universaldb.workspace = true +universalpubsub.workspace = true utoipa.workspace = true versioned-data-util.workspace = true diff --git a/packages/services/pegboard/src/lib.rs b/packages/services/pegboard/src/lib.rs index 8a08a5b9a9..a776a3d227 100644 --- a/packages/services/pegboard/src/lib.rs +++ b/packages/services/pegboard/src/lib.rs @@ -5,6 +5,7 @@ pub mod keys; mod metrics; pub mod ops; pub mod pubsub_subjects; +mod utils; pub mod workflows; pub fn registry() -> WorkflowResult { diff --git a/packages/services/pegboard/src/pubsub_subjects.rs b/packages/services/pegboard/src/pubsub_subjects.rs index 9eab33a888..3db060ad02 100644 --- a/packages/services/pegboard/src/pubsub_subjects.rs +++ b/packages/services/pegboard/src/pubsub_subjects.rs @@ -1,43 +1,33 @@ use gas::prelude::*; -pub struct TunnelRunnerReceiverSubject<'a> { - namespace_id: Id, - runner_name: &'a str, - runner_key: &'a str, +pub struct RunnerReceiverSubject { + runner_id: Id, } -impl<'a> TunnelRunnerReceiverSubject<'a> { - pub fn new(namespace_id: Id, runner_name: &'a str, runner_key: &'a str) -> Self { - Self { - namespace_id, - runner_name, - runner_key, - } +impl RunnerReceiverSubject { + pub fn new(runner_id: Id) -> Self { + Self { runner_id } } } -impl std::fmt::Display for TunnelRunnerReceiverSubject<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "pegboard.tunnel.runner_receiver.{}.{}.{}", - self.namespace_id, self.runner_name, self.runner_key - ) +impl std::fmt::Display for RunnerReceiverSubject { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "pegboard.runner.{}", self.runner_id) } } -pub struct 
TunnelGatewayReceiverSubject { +pub struct GatewayReceiverSubject { gateway_id: Uuid, } -impl<'a> TunnelGatewayReceiverSubject { +impl GatewayReceiverSubject { pub fn new(gateway_id: Uuid) -> Self { Self { gateway_id } } } -impl std::fmt::Display for TunnelGatewayReceiverSubject { +impl std::fmt::Display for GatewayReceiverSubject { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "pegboard.gateway.receiver.{}", self.gateway_id) + write!(f, "pegboard.gateway.{}", self.gateway_id) } } diff --git a/packages/services/pegboard/src/utils.rs b/packages/services/pegboard/src/utils.rs new file mode 100644 index 0000000000..61bba809ae --- /dev/null +++ b/packages/services/pegboard/src/utils.rs @@ -0,0 +1,29 @@ +use rivet_runner_protocol as protocol; + +pub fn event_actor_id(event: &protocol::Event) -> &str { + match event { + protocol::Event::EventActorIntent(protocol::EventActorIntent { actor_id, .. }) => actor_id, + protocol::Event::EventActorStateUpdate(protocol::EventActorStateUpdate { + actor_id, + .. + }) => actor_id, + protocol::Event::EventActorSetAlarm(protocol::EventActorSetAlarm { actor_id, .. }) => { + actor_id + } + } +} + +pub fn event_generation(event: &protocol::Event) -> u32 { + match event { + protocol::Event::EventActorIntent(protocol::EventActorIntent { generation, .. }) => { + *generation + } + protocol::Event::EventActorStateUpdate(protocol::EventActorStateUpdate { + generation, + .. + }) => *generation, + protocol::Event::EventActorSetAlarm(protocol::EventActorSetAlarm { + generation, .. + }) => *generation, + } +} diff --git a/packages/services/pegboard/src/workflows/actor/destroy.rs b/packages/services/pegboard/src/workflows/actor/destroy.rs index f267328831..a410cef164 100644 --- a/packages/services/pegboard/src/workflows/actor/destroy.rs +++ b/packages/services/pegboard/src/workflows/actor/destroy.rs @@ -1,6 +1,6 @@ use gas::prelude::*; use rivet_data::converted::ActorByKeyKeyData; -use rivet_runner_protocol::protocol; +use rivet_runner_protocol as protocol; use universaldb::options::MutationType; use universaldb::utils::IsolationLevel::*; @@ -254,9 +254,11 @@ pub(crate) async fn kill( generation: u32, runner_workflow_id: Id, ) -> Result<()> { - ctx.signal(protocol::Command::StopActor { - actor_id, - generation, + ctx.signal(crate::workflows::runner::Command { + inner: protocol::Command::CommandStopActor(protocol::CommandStopActor { + actor_id: actor_id.to_string(), + generation, + }), }) .to_workflow_id(runner_workflow_id) .send() diff --git a/packages/services/pegboard/src/workflows/actor/mod.rs b/packages/services/pegboard/src/workflows/actor/mod.rs index 043ca2e802..8a0153288e 100644 --- a/packages/services/pegboard/src/workflows/actor/mod.rs +++ b/packages/services/pegboard/src/workflows/actor/mod.rs @@ -1,6 +1,6 @@ use futures_util::FutureExt; use gas::prelude::*; -use rivet_runner_protocol::protocol; +use rivet_runner_protocol as protocol; use rivet_types::actors::CrashPolicy; use crate::{errors, workflows::runner::AllocatePendingActorsInput}; @@ -270,13 +270,16 @@ pub async fn pegboard_actor(ctx: &mut WorkflowCtx, input: &Input) -> Result<()> match sig { Main::Event(sig) => { // Ignore state updates for previous generations - if sig.generation() != state.generation { + if crate::utils::event_generation(&sig.inner) != state.generation { return Ok(Loop::Continue); } - match sig { - protocol::Event::ActorIntent { intent, .. 
} => match intent { - protocol::ActorIntent::Sleep => { + match sig.inner { + protocol::Event::EventActorIntent(protocol::EventActorIntent { + intent, + .. + }) => match intent { + protocol::ActorIntent::ActorIntentSleep => { state.gc_timeout_ts = Some(util::timestamp::now() + ACTOR_STOP_THRESHOLD_MS); state.sleeping = true; @@ -295,7 +298,7 @@ pub async fn pegboard_actor(ctx: &mut WorkflowCtx, input: &Input) -> Result<()> ) .await?; } - protocol::ActorIntent::Stop => { + protocol::ActorIntent::ActorIntentStop => { state.gc_timeout_ts = Some(util::timestamp::now() + ACTOR_STOP_THRESHOLD_MS); @@ -313,10 +316,12 @@ pub async fn pegboard_actor(ctx: &mut WorkflowCtx, input: &Input) -> Result<()> .await?; } }, - protocol::Event::ActorStateUpdate { - state: actor_state, .. - } => match actor_state { - protocol::ActorState::Running => { + protocol::Event::EventActorStateUpdate( + protocol::EventActorStateUpdate { + state: actor_state, .. + }, + ) => match actor_state { + protocol::ActorState::ActorStateRunning => { state.gc_timeout_ts = None; ctx.activity(runtime::SetStartedInput { @@ -331,7 +336,9 @@ pub async fn pegboard_actor(ctx: &mut WorkflowCtx, input: &Input) -> Result<()> .send() .await?; } - protocol::ActorState::Stopped { code, .. } => { + protocol::ActorState::ActorStateStopped( + protocol::ActorStateStopped { code, .. }, + ) => { if let Some(res) = handle_stopped(ctx, &input, state, Some(code), false) .await? @@ -340,7 +347,9 @@ pub async fn pegboard_actor(ctx: &mut WorkflowCtx, input: &Input) -> Result<()> } } }, - protocol::Event::ActorSetAlarm { alarm_ts, .. } => { + protocol::Event::EventActorSetAlarm( + protocol::EventActorSetAlarm { alarm_ts, .. }, + ) => { state.alarm_ts = alarm_ts; } } @@ -546,6 +555,11 @@ pub struct Allocate { pub runner_workflow_id: Id, } +#[signal("pegboard_actor_event")] +pub struct Event { + pub inner: protocol::Event, +} + #[signal("pegboard_actor_wake")] pub struct Wake {} @@ -570,7 +584,7 @@ join_signal!(PendingAllocation { }); join_signal!(Main { - Event(protocol::Event), + Event(Event), Wake, Lost, Destroy, diff --git a/packages/services/pegboard/src/workflows/actor/runtime.rs b/packages/services/pegboard/src/workflows/actor/runtime.rs index ffd1c25fd4..942282cf2f 100644 --- a/packages/services/pegboard/src/workflows/actor/runtime.rs +++ b/packages/services/pegboard/src/workflows/actor/runtime.rs @@ -1,10 +1,11 @@ -use std::time::Instant; - +use base64::Engine as _; +use base64::prelude::BASE64_STANDARD; use futures_util::StreamExt; use futures_util::{FutureExt, TryStreamExt}; use gas::prelude::*; use rivet_metrics::KeyValue; -use rivet_runner_protocol::protocol; +use rivet_runner_protocol as protocol; +use std::time::Instant; use universaldb::options::{ConflictRangeType, MutationType, StreamingMode}; use universaldb::utils::{FormalKey, IsolationLevel::*}; @@ -457,16 +458,22 @@ pub async fn spawn_actor( } }; - ctx.signal(protocol::Command::StartActor { - actor_id: input.actor_id, - generation, - config: Box::new(protocol::ActorConfig { - name: input.name.clone(), - key: input.key.clone(), - // HACK: We should not use dynamic timestamp here, but we don't validate if signal data - // changes (like activity inputs) so this is fine for now. 
- create_ts: util::timestamp::now(), - input: input.input.clone(), + ctx.signal(crate::workflows::runner::Command { + inner: protocol::Command::CommandStartActor(protocol::CommandStartActor { + actor_id: input.actor_id.to_string(), + generation, + config: protocol::ActorConfig { + name: input.name.clone(), + key: input.key.clone(), + // HACK: We should not use dynamic timestamp here, but we don't validate if signal data + // changes (like activity inputs) so this is fine for now. + create_ts: util::timestamp::now(), + input: input + .input + .as_ref() + .map(|x| BASE64_STANDARD.decode(x)) + .transpose()?, + }, }), }) .to_workflow_id(allocate_res.runner_workflow_id) diff --git a/packages/services/pegboard/src/workflows/runner.rs b/packages/services/pegboard/src/workflows/runner.rs index 98aa6102a0..8da7c531aa 100644 --- a/packages/services/pegboard/src/workflows/runner.rs +++ b/packages/services/pegboard/src/workflows/runner.rs @@ -1,9 +1,13 @@ use futures_util::{FutureExt, StreamExt, TryStreamExt}; use gas::prelude::*; use rivet_data::converted::{ActorNameKeyData, MetadataKeyData, RunnerByKeyKeyData}; -use rivet_runner_protocol::protocol; -use universaldb::options::{ConflictRangeType, StreamingMode}; -use universaldb::utils::{FormalChunkedKey, IsolationLevel::*}; +use rivet_runner_protocol::{self as protocol, PROTOCOL_VERSION, versioned}; +use universaldb::{ + options::{ConflictRangeType, StreamingMode}, + utils::{FormalChunkedKey, IsolationLevel::*}, +}; +use universalpubsub::PublishOpts; +use versioned_data_util::OwnedVersionedData as _; use crate::{keys, workflows::actor::Allocate}; @@ -67,10 +71,12 @@ pub async fn pegboard_runner(ctx: &mut WorkflowCtx, input: &Input) -> Result<()> // Evict other workflow if there was a key conflict if let Some(evict_workflow_id) = init_res.evict_workflow_id { - ctx.signal(protocol::ToServer::Stopping) - .to_workflow_id(evict_workflow_id) - .send() - .await?; + ctx.signal(Forward { + inner: protocol::ToServer::ToServerStopping, + }) + .to_workflow_id(evict_workflow_id) + .send() + .await?; } ctx.loope(LifecycleState::new(), |ctx, state| { @@ -82,13 +88,13 @@ pub async fn pegboard_runner(ctx: &mut WorkflowCtx, input: &Input) -> Result<()> .await? { Some(Main::Forward(sig)) => { - match sig { - protocol::ToServer::Init { + match sig.inner { + protocol::ToServer::ToServerInit(protocol::ToServerInit { last_command_idx, prepopulate_actor_names, metadata, .. 
- } => { + }) => { let init_data = ctx .activity(ProcessInitInput { runner_id: input.runner_id, @@ -100,26 +106,26 @@ pub async fn pegboard_runner(ctx: &mut WorkflowCtx, input: &Input) -> Result<()> .await?; // Send init packet - ctx.msg(ToWs { + ctx.activity(SendMessageToRunnerInput { runner_id: input.runner_id, - inner: protocol::ToClient::Init { - runner_id: input.runner_id, + message: protocol::ToClient::ToClientInit(protocol::ToClientInit { + runner_id: input.runner_id.to_string(), last_event_idx: init_data.last_event_idx, metadata: protocol::ProtocolMetadata { runner_lost_threshold: RUNNER_LOST_THRESHOLD_MS, }, - }, + }), }) - .send() .await?; // Send missed commands if !init_data.missed_commands.is_empty() { - ctx.msg(ToWs { + ctx.activity(SendMessageToRunnerInput { runner_id: input.runner_id, - inner: protocol::ToClient::Commands(init_data.missed_commands), + message: protocol::ToClient::ToClientCommands( + init_data.missed_commands, + ), }) - .send() .await?; } @@ -152,17 +158,18 @@ pub async fn pegboard_runner(ctx: &mut WorkflowCtx, input: &Input) -> Result<()> .await?; } } - protocol::ToServer::Events(events) => { + protocol::ToServer::ToServerEvents(events) => { let last_event_idx = events.last().map(|event| event.index); // NOTE: This should not be parallelized because signals should be sent in order // Forward to actor workflows for event in events { - let actor_id = event.inner.actor_id(); + let actor_id = + crate::utils::event_actor_id(&event.inner).to_string(); let res = ctx - .signal(event.inner) + .signal(crate::workflows::actor::Event { inner: event.inner }) .to_workflow::() - .tag("actor_id", actor_id) + .tag("actor_id", &actor_id) .send() .await; @@ -184,21 +191,24 @@ pub async fn pegboard_runner(ctx: &mut WorkflowCtx, input: &Input) -> Result<()> if last_event_idx > state.last_event_ack_idx.saturating_add(500) { state.last_event_ack_idx = last_event_idx; - ctx.msg(ToWs { + ctx.activity(SendMessageToRunnerInput { runner_id: input.runner_id, - inner: protocol::ToClient::AckEvents { - last_event_idx: state.last_event_ack_idx, - }, + message: protocol::ToClient::ToClientAckEvents( + protocol::ToClientAckEvents { + last_event_idx: state.last_event_ack_idx, + }, + ), }) - .send() .await?; } } } - protocol::ToServer::AckCommands { last_command_idx } => { + protocol::ToServer::ToServerAckCommands( + protocol::ToServerAckCommands { last_command_idx }, + ) => { ctx.activity(AckCommandsInput { last_command_idx }).await?; } - protocol::ToServer::Stopping => { + protocol::ToServer::ToServerStopping => { // The workflow will enter a draining state where it can still process signals if // needed. After RUNNER_LOST_THRESHOLD_MS it will exit this loop and stop. 
state.draining = true; @@ -234,9 +244,13 @@ pub async fn pegboard_runner(ctx: &mut WorkflowCtx, input: &Input) -> Result<()> let commands = actors .into_iter() - .map(|(actor_id, generation)| protocol::Command::StopActor { - actor_id, - generation, + .map(|(actor_id, generation)| { + protocol::Command::CommandStopActor( + protocol::CommandStopActor { + actor_id: actor_id.to_string(), + generation, + }, + ) }) .collect::>(); @@ -246,9 +260,9 @@ pub async fn pegboard_runner(ctx: &mut WorkflowCtx, input: &Input) -> Result<()> }) .await?; - ctx.msg(ToWs { + ctx.activity(SendMessageToRunnerInput { runner_id: input.runner_id, - inner: protocol::ToClient::Commands( + message: protocol::ToClient::ToClientCommands( commands .into_iter() .enumerate() @@ -259,22 +273,29 @@ pub async fn pegboard_runner(ctx: &mut WorkflowCtx, input: &Input) -> Result<()> .collect(), ), }) - .send() .await?; } } + protocol::ToServer::ToServerPing(_) + | protocol::ToServer::ToServerKvRequest(_) + | protocol::ToServer::ToServerTunnelMessage(_) => { + bail!( + "received message that should not be sent to runner workflow: {:?}", + sig.inner + ) + } } } Some(Main::Command(command)) => { // If draining, ignore start actor command and inform the actor wf that it is lost if let ( - protocol::Command::StartActor { + protocol::Command::CommandStartActor(protocol::CommandStartActor { actor_id, generation, .. - }, + }), true, - ) = (&command, state.draining) + ) = (&command.inner, state.draining) { tracing::warn!(?actor_id, "attempt to schedule actor to draining runner"); @@ -302,19 +323,20 @@ pub async fn pegboard_runner(ctx: &mut WorkflowCtx, input: &Input) -> Result<()> } else { let index = ctx .activity(InsertCommandsInput { - commands: vec![command.clone()], + commands: vec![command.inner.clone()], }) .await?; // Forward - ctx.msg(ToWs { + ctx.activity(SendMessageToRunnerInput { runner_id: input.runner_id, - inner: protocol::ToClient::Commands(vec![protocol::CommandWrapper { - index, - inner: command, - }]), + message: protocol::ToClient::ToClientCommands(vec![ + protocol::CommandWrapper { + index, + inner: command.inner, + }, + ]), }) - .send() .await?; } } @@ -393,10 +415,10 @@ pub async fn pegboard_runner(ctx: &mut WorkflowCtx, input: &Input) -> Result<()> } // Close websocket connection (its unlikely to be open) - ctx.msg(CloseWs { + ctx.activity(SendMessageToRunnerInput { runner_id: input.runner_id, + message: protocol::ToClient::ToClientClose, }) - .send() .await?; Ok(()) @@ -1093,23 +1115,43 @@ pub(crate) async fn allocate_pending_actors( Ok(AllocatePendingActorsOutput { allocations: res }) } -#[message("pegboard_runner_to_ws")] -pub struct ToWs { - pub runner_id: Id, - pub inner: protocol::ToClient, +#[derive(Debug, Serialize, Deserialize, Hash)] +struct SendMessageToRunnerInput { + runner_id: Id, + message: protocol::ToClient, +} + +#[activity(SendMessageToRunner)] +async fn send_message_to_runner(ctx: &ActivityCtx, input: &SendMessageToRunnerInput) -> Result<()> { + let receiver_subject = + crate::pubsub_subjects::RunnerReceiverSubject::new(input.runner_id).to_string(); + + let message_serialized = versioned::ToClient::latest(input.message.clone()) + .serialize_with_embedded_version(PROTOCOL_VERSION)?; + + ctx.ups()? 
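+ // Publishes to the per-runner UPS subject ("pegboard.runner.{runner_id}"); the pegboard-runner pubsub task forwards it to the runner's WebSocket.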
+ .publish(&receiver_subject, &message_serialized, PublishOpts::one()) + .await?; + + Ok(()) } #[signal("pegboard_runner_check_queue")] pub struct CheckQueue {} -#[message("pegboard_runner_close_ws")] -pub struct CloseWs { - pub runner_id: Id, +#[signal("pegboard_runner_command")] +pub struct Command { + pub inner: protocol::Command, +} + +#[signal("pegboard_runner_forward")] +pub struct Forward { + pub inner: protocol::ToServer, } join_signal!(Main { - Command(protocol::Command), + Command(Command), // Forwarded from the ws to this workflow - Forward(protocol::ToServer), + Forward(Forward), CheckQueue, }); diff --git a/sdks/rust/runner-protocol/src/lib.rs b/sdks/rust/runner-protocol/src/lib.rs index 9c36f980fb..676c99e464 100644 --- a/sdks/rust/runner-protocol/src/lib.rs +++ b/sdks/rust/runner-protocol/src/lib.rs @@ -1,5 +1,4 @@ pub mod generated; -pub mod protocol; pub mod versioned; // Re-export latest diff --git a/sdks/rust/runner-protocol/src/protocol.rs b/sdks/rust/runner-protocol/src/protocol.rs deleted file mode 100644 index 67d7644a18..0000000000 --- a/sdks/rust/runner-protocol/src/protocol.rs +++ /dev/null @@ -1,160 +0,0 @@ -use gas::prelude::*; -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Serialize, Deserialize)] -#[serde(rename_all = "snake_case")] -pub enum ToClient { - Init { - runner_id: Id, - last_event_idx: i64, - metadata: ProtocolMetadata, - }, - Commands(Vec), - AckEvents { - last_event_idx: i64, - }, -} - -#[signal("pegboard_to_server")] -#[derive(Debug)] -#[serde(rename_all = "snake_case")] -pub enum ToServer { - Init { - name: String, - version: u32, - total_slots: u32, - - last_command_idx: Option, - prepopulate_actor_names: Option>, - metadata: Option, - }, - Events(Vec), - AckCommands { - last_command_idx: i64, - }, - Stopping, -} - -#[derive(Debug, Serialize, Deserialize, Hash)] -pub struct ActorName { - /// JSON. - pub metadata: String, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct ProtocolMetadata { - pub runner_lost_threshold: i64, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct CommandWrapper { - pub index: i64, - pub inner: Command, -} - -#[signal("pegboard_command")] -#[derive(Debug, Clone, Hash)] -#[serde(rename_all = "snake_case")] -pub enum Command { - StartActor { - actor_id: Id, - generation: u32, - config: Box, - }, - StopActor { - actor_id: Id, - generation: u32, - }, -} - -#[derive(Debug, Serialize, Deserialize, Clone, Hash)] -pub struct ActorConfig { - pub name: String, - pub key: Option, - pub create_ts: i64, - /// Arbitrary user-defined binary data, base64 encoded. - pub input: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize, Hash)] -pub struct EventWrapper { - pub index: i64, - pub inner: Event, -} - -#[signal("pegboard_event")] -#[derive(Debug, Clone, Hash)] -#[serde(rename_all = "snake_case")] -pub enum Event { - ActorIntent { - actor_id: Id, - generation: u32, - intent: ActorIntent, - }, - ActorStateUpdate { - actor_id: Id, - generation: u32, - state: ActorState, - }, - ActorSetAlarm { - actor_id: Id, - generation: u32, - alarm_ts: Option, - }, -} - -impl Event { - // For now, all events are actor related so they doesn't need to return an `Option` - pub fn actor_id(&self) -> Id { - match self { - Event::ActorIntent { actor_id, .. } => *actor_id, - Event::ActorStateUpdate { actor_id, .. } => *actor_id, - Event::ActorSetAlarm { actor_id, .. 
} => *actor_id, - } - } - - // For now, all events are actor related so they doesn't need to return an `Option` - pub fn generation(&self) -> u32 { - match self { - Event::ActorIntent { generation, .. } => *generation, - Event::ActorStateUpdate { generation, .. } => *generation, - Event::ActorSetAlarm { generation, .. } => *generation, - } - } -} - -#[derive(Debug, Clone, Serialize, Deserialize, Hash)] -#[serde(rename_all = "snake_case")] -pub enum ActorIntent { - /// Actor intends to sleep. This informs rivet that the actor should be stopped and can be woken up later - // either by an alarm or guard. - Sleep, - /// Actor intends to stop. - Stop, -} - -#[derive(Debug, Clone, Serialize, Deserialize, Hash)] -#[serde(rename_all = "snake_case")] -pub enum ActorState { - /// Actor is running. - Running, - /// Actor stopped on runner. - Stopped { - code: StopCode, - message: Option, - }, -} - -#[derive(Debug, Clone, Copy, Serialize, Deserialize, Hash)] -#[serde(rename_all = "snake_case")] -pub enum StopCode { - Ok, - Error, -} - -#[derive(Debug, Clone, Copy, Serialize, Deserialize, Hash)] -#[serde(rename_all = "snake_case")] -pub enum WebsocketCloseReason { - Ok, - Error, -} diff --git a/sdks/rust/runner-protocol/src/versioned.rs b/sdks/rust/runner-protocol/src/versioned.rs index 954338905c..dca8eb7858 100644 --- a/sdks/rust/runner-protocol/src/versioned.rs +++ b/sdks/rust/runner-protocol/src/versioned.rs @@ -1,9 +1,7 @@ use anyhow::{Ok, Result, bail}; -use base64::prelude::*; -use gas::prelude::*; use versioned_data_util::OwnedVersionedData; -use crate::{PROTOCOL_VERSION, generated::v1, protocol}; +use crate::{PROTOCOL_VERSION, generated::v1}; pub enum ToClient { V1(v1::ToClient), @@ -45,94 +43,6 @@ impl ToClient { } } -impl TryFrom for ToClient { - type Error = anyhow::Error; - - fn try_from(value: protocol::ToClient) -> Result { - Ok(ToClient::V1(match value { - protocol::ToClient::Init { - runner_id, - last_event_idx, - metadata, - } => v1::ToClient::ToClientInit(v1::ToClientInit { - runner_id: runner_id.to_string(), - last_event_idx, - metadata: metadata.try_into()?, - }), - protocol::ToClient::Commands(commands) => { - let commands = commands - .into_iter() - .map(|c| c.try_into()) - .collect::>()?; - - v1::ToClient::ToClientCommands(commands) - } - protocol::ToClient::AckEvents { last_event_idx } => { - v1::ToClient::ToClientAckEvents(v1::ToClientAckEvents { last_event_idx }) - } - })) - } -} - -impl TryFrom for v1::ProtocolMetadata { - type Error = anyhow::Error; - - fn try_from(value: protocol::ProtocolMetadata) -> Result { - Ok(v1::ProtocolMetadata { - runner_lost_threshold: value.runner_lost_threshold, - }) - } -} - -impl TryFrom for v1::CommandWrapper { - type Error = anyhow::Error; - - fn try_from(value: protocol::CommandWrapper) -> Result { - Ok(v1::CommandWrapper { - index: value.index, - inner: value.inner.try_into()?, - }) - } -} - -impl TryFrom for v1::Command { - type Error = anyhow::Error; - - fn try_from(value: protocol::Command) -> Result { - match value { - protocol::Command::StartActor { - actor_id, - generation, - config, - } => Ok(v1::Command::CommandStartActor(v1::CommandStartActor { - actor_id: actor_id.to_string(), - generation, - config: (*config).try_into()?, - })), - protocol::Command::StopActor { - actor_id, - generation, - } => Ok(v1::Command::CommandStopActor(v1::CommandStopActor { - actor_id: actor_id.to_string(), - generation, - })), - } - } -} - -impl TryFrom for v1::ActorConfig { - type Error = anyhow::Error; - - fn try_from(value: protocol::ActorConfig) -> 
Result { - Ok(v1::ActorConfig { - name: value.name, - key: value.key, - create_ts: value.create_ts, - input: value.input.map(|x| BASE64_STANDARD.decode(x)).transpose()?, - }) - } -} - pub enum ToServer { V1(v1::ToServer), } @@ -173,118 +83,42 @@ impl ToServer { } } -impl From for protocol::ActorName { - fn from(value: v1::ActorName) -> Self { - protocol::ActorName { - metadata: value.metadata, - } - } +pub enum ToGateway { + V1(v1::ToGateway), } -impl TryFrom for protocol::EventWrapper { - type Error = anyhow::Error; +impl OwnedVersionedData for ToGateway { + type Latest = v1::ToGateway; - fn try_from(value: v1::EventWrapper) -> Result { - Ok(protocol::EventWrapper { - index: value.index, - inner: value.inner.try_into()?, - }) + fn latest(latest: v1::ToGateway) -> Self { + ToGateway::V1(latest) } -} -impl TryFrom for protocol::Event { - type Error = anyhow::Error; - - fn try_from(value: v1::Event) -> Result { - match value { - v1::Event::EventActorIntent(event) => Ok(protocol::Event::ActorIntent { - actor_id: util::Id::parse(&event.actor_id)?, - generation: event.generation, - intent: event.intent.try_into()?, - }), - v1::Event::EventActorStateUpdate(event) => Ok(protocol::Event::ActorStateUpdate { - actor_id: util::Id::parse(&event.actor_id)?, - generation: event.generation, - state: event.state.try_into()?, - }), - v1::Event::EventActorSetAlarm(event) => Ok(protocol::Event::ActorSetAlarm { - actor_id: util::Id::parse(&event.actor_id)?, - generation: event.generation, - alarm_ts: event.alarm_ts, - }), - } - } -} - -impl TryFrom for protocol::ActorIntent { - type Error = anyhow::Error; - - fn try_from(value: v1::ActorIntent) -> Result { - match value { - v1::ActorIntent::ActorIntentSleep => Ok(protocol::ActorIntent::Sleep), - v1::ActorIntent::ActorIntentStop => Ok(protocol::ActorIntent::Stop), + fn into_latest(self) -> Result { + #[allow(irrefutable_let_patterns)] + if let ToGateway::V1(data) = self { + Ok(data) + } else { + bail!("version not latest"); } } -} -impl TryFrom for protocol::ActorState { - type Error = anyhow::Error; - - fn try_from(value: v1::ActorState) -> Result { - match value { - v1::ActorState::ActorStateRunning => Ok(protocol::ActorState::Running), - v1::ActorState::ActorStateStopped(stopped) => Ok(protocol::ActorState::Stopped { - code: stopped.code.try_into()?, - message: stopped.message, - }), + fn deserialize_version(payload: &[u8], version: u16) -> Result { + match version { + 1 => Ok(ToGateway::V1(serde_bare::from_slice(payload)?)), + _ => bail!("invalid version: {version}"), } } -} - -impl TryFrom for protocol::StopCode { - type Error = anyhow::Error; - fn try_from(value: v1::StopCode) -> Result { - match value { - v1::StopCode::Ok => Ok(protocol::StopCode::Ok), - v1::StopCode::Error => Ok(protocol::StopCode::Error), + fn serialize_version(self, _version: u16) -> Result> { + match self { + ToGateway::V1(data) => serde_bare::to_vec(&data).map_err(Into::into), } } } -impl TryFrom for protocol::ToServer { - type Error = anyhow::Error; - - fn try_from(value: v1::ToServer) -> Result { - match value { - v1::ToServer::ToServerInit(init) => Ok(protocol::ToServer::Init { - name: init.name, - version: init.version, - total_slots: init.total_slots, - last_command_idx: init.last_command_idx, - prepopulate_actor_names: init - .prepopulate_actor_names - .map(|x| x.into_iter().map(|(k, v)| (k, v.into())).collect()), - metadata: init.metadata, - }), - v1::ToServer::ToServerEvents(events) => Ok(protocol::ToServer::Events( - events - .into_iter() - .map(|e| e.try_into()) - 
.collect::>()?, - )), - v1::ToServer::ToServerAckCommands(ack) => Ok(protocol::ToServer::AckCommands { - last_command_idx: ack.last_command_idx, - }), - v1::ToServer::ToServerStopping => Ok(protocol::ToServer::Stopping), - v1::ToServer::ToServerPing(_) => { - // NOTE: Ping is handled at the websocket level and never reaches the workflow. - bail!("Ping variant should not be converted") - } - v1::ToServer::ToServerKvRequest(_) => { - // NOTE: KV is handled at the websocket level and never reaches the workflow. - bail!("KV variant should not be converted") - } - } +impl ToGateway { + pub fn serialize(self) -> Result> { + ::serialize(self, PROTOCOL_VERSION) } } diff --git a/sdks/rust/tunnel-protocol/Cargo.toml b/sdks/rust/tunnel-protocol/Cargo.toml deleted file mode 100644 index f7bc19f63d..0000000000 --- a/sdks/rust/tunnel-protocol/Cargo.toml +++ /dev/null @@ -1,21 +0,0 @@ -[package] -name = "rivet-tunnel-protocol" -version.workspace = true -authors.workspace = true -license.workspace = true -edition.workspace = true - -[dependencies] -anyhow.workspace = true -base64.workspace = true -rivet-util.workspace = true -serde_bare.workspace = true -serde.workspace = true -versioned-data-util.workspace = true - -[build-dependencies] -bare_gen.workspace = true -indoc.workspace = true -prettyplease.workspace = true -serde_json.workspace = true -syn.workspace = true \ No newline at end of file diff --git a/sdks/rust/tunnel-protocol/build.rs b/sdks/rust/tunnel-protocol/build.rs deleted file mode 100644 index f43ed92983..0000000000 --- a/sdks/rust/tunnel-protocol/build.rs +++ /dev/null @@ -1,158 +0,0 @@ -use std::{ - env, fs, - path::{Path, PathBuf}, - process::Command, -}; - -use indoc::formatdoc; - -mod rust { - use super::*; - - pub fn generate_sdk(schema_dir: &Path) { - let out_dir = env::var("OUT_DIR").unwrap(); - let out_path = Path::new(&out_dir); - let mut all_names = Vec::new(); - - for entry in fs::read_dir(schema_dir).unwrap().flatten() { - let path = entry.path(); - - if path.is_dir() { - continue; - } - - let bare_name = path - .file_name() - .unwrap() - .to_str() - .unwrap() - .rsplit_once('.') - .unwrap() - .0; - - let content = - prettyplease::unparse(&syn::parse2(bare_gen::bare_schema(&path)).unwrap()); - fs::write(out_path.join(format!("{bare_name}_generated.rs")), content).unwrap(); - - all_names.push(bare_name.to_string()); - } - - let mut mod_content = String::new(); - mod_content.push_str("// Auto-generated module file for BARE schemas\n\n"); - - // Generate module declarations for each version - for name in all_names { - mod_content.push_str(&formatdoc!( - r#" - pub mod {name} {{ - include!(concat!(env!("OUT_DIR"), "/{name}_generated.rs")); - }} - "#, - )); - } - - let mod_file_path = out_path.join("combined_imports.rs"); - fs::write(&mod_file_path, mod_content).expect("Failed to write combined_imports.rs"); - } -} - -mod typescript { - use super::*; - - pub fn generate_sdk(schema_dir: &Path) { - let manifest_dir = env::var("CARGO_MANIFEST_DIR").unwrap(); - let workspace_root = Path::new(&manifest_dir) - .parent() - .and_then(|p| p.parent()) - .and_then(|p| p.parent()) - .expect("Failed to find workspace root"); - - let sdk_dir = workspace_root - .join("sdks") - .join("typescript") - .join("tunnel-protocol"); - let src_dir = sdk_dir.join("src"); - - let highest_version_path = super::find_highest_version(schema_dir); - - let _ = fs::remove_dir_all(&src_dir); - if let Err(e) = fs::create_dir_all(&src_dir) { - panic!("Failed to create SDK directory: {}", e); - } - - let output = - 
Command::new(workspace_root.join("node_modules/@bare-ts/tools/dist/bin/cli.js")) - .arg("compile") - .arg("--generator") - .arg("ts") - .arg(highest_version_path) - .arg("-o") - .arg(src_dir.join("index.ts")) - .output() - .expect("Failed to execute bare compiler for TypeScript"); - - if !output.status.success() { - panic!( - "BARE TypeScript generation failed: {}", - String::from_utf8_lossy(&output.stderr), - ); - } - } -} - -fn find_highest_version(schema_dir: &Path) -> PathBuf { - let mut highest_version = 0; - let mut highest_version_path = PathBuf::new(); - - for entry in fs::read_dir(schema_dir).unwrap().flatten() { - if !entry.path().is_dir() { - let path = entry.path(); - let bare_name = path - .file_name() - .unwrap() - .to_str() - .unwrap() - .split_once('.') - .unwrap() - .0; - - if let Ok(version) = bare_name[1..].parse::() { - if version > highest_version { - highest_version = version; - highest_version_path = path; - } - } - } - } - - highest_version_path -} - -fn main() { - let manifest_dir = env::var("CARGO_MANIFEST_DIR").unwrap(); - let workspace_root = Path::new(&manifest_dir) - .parent() - .and_then(|p| p.parent()) - .and_then(|p| p.parent()) - .expect("Failed to find workspace root"); - - let schema_dir = workspace_root - .join("sdks") - .join("schemas") - .join("tunnel-protocol"); - - println!("cargo:rerun-if-changed={}", schema_dir.display()); - - rust::generate_sdk(&schema_dir); - - // Check if cli.js exists before attempting TypeScript generation - let cli_js_path = workspace_root.join("node_modules/@bare-ts/tools/dist/bin/cli.js"); - if cli_js_path.exists() { - typescript::generate_sdk(&schema_dir); - } else { - println!( - "cargo:warning=TypeScript SDK generation skipped: cli.js not found at {}. Run `pnpm install` to install.", - cli_js_path.display() - ); - } -} diff --git a/sdks/rust/tunnel-protocol/src/generated.rs b/sdks/rust/tunnel-protocol/src/generated.rs deleted file mode 100644 index 84801af8dc..0000000000 --- a/sdks/rust/tunnel-protocol/src/generated.rs +++ /dev/null @@ -1 +0,0 @@ -include!(concat!(env!("OUT_DIR"), "/combined_imports.rs")); diff --git a/sdks/rust/tunnel-protocol/src/lib.rs b/sdks/rust/tunnel-protocol/src/lib.rs deleted file mode 100644 index 676c99e464..0000000000 --- a/sdks/rust/tunnel-protocol/src/lib.rs +++ /dev/null @@ -1,7 +0,0 @@ -pub mod generated; -pub mod versioned; - -// Re-export latest -pub use generated::v1::*; - -pub const PROTOCOL_VERSION: u16 = 1; diff --git a/sdks/rust/tunnel-protocol/src/versioned.rs b/sdks/rust/tunnel-protocol/src/versioned.rs deleted file mode 100644 index f5208f27fb..0000000000 --- a/sdks/rust/tunnel-protocol/src/versioned.rs +++ /dev/null @@ -1,92 +0,0 @@ -use anyhow::{Ok, Result, bail}; -use versioned_data_util::OwnedVersionedData; - -use crate::{PROTOCOL_VERSION, generated::v1}; - -pub enum RunnerMessage { - V1(v1::RunnerMessage), -} - -impl OwnedVersionedData for RunnerMessage { - type Latest = v1::RunnerMessage; - - fn latest(latest: v1::RunnerMessage) -> Self { - RunnerMessage::V1(latest) - } - - fn into_latest(self) -> Result { - #[allow(irrefutable_let_patterns)] - if let RunnerMessage::V1(data) = self { - Ok(data) - } else { - bail!("version not latest"); - } - } - - fn deserialize_version(payload: &[u8], version: u16) -> Result { - match version { - 1 => Ok(RunnerMessage::V1(serde_bare::from_slice(payload)?)), - _ => bail!("invalid version: {version}"), - } - } - - fn serialize_version(self, _version: u16) -> Result> { - match self { - RunnerMessage::V1(data) => 
serde_bare::to_vec(&data).map_err(Into::into), - } - } -} - -impl RunnerMessage { - pub fn deserialize(buf: &[u8]) -> Result { - ::deserialize(buf, PROTOCOL_VERSION) - } - - pub fn serialize(self) -> Result> { - ::serialize(self, PROTOCOL_VERSION) - } -} - -pub enum PubSubMessage { - V1(v1::PubSubMessage), -} - -impl OwnedVersionedData for PubSubMessage { - type Latest = v1::PubSubMessage; - - fn latest(latest: v1::PubSubMessage) -> Self { - PubSubMessage::V1(latest) - } - - fn into_latest(self) -> Result { - #[allow(irrefutable_let_patterns)] - if let PubSubMessage::V1(data) = self { - Ok(data) - } else { - bail!("version not latest"); - } - } - - fn deserialize_version(payload: &[u8], version: u16) -> Result { - match version { - 1 => Ok(PubSubMessage::V1(serde_bare::from_slice(payload)?)), - _ => bail!("invalid version: {version}"), - } - } - - fn serialize_version(self, _version: u16) -> Result> { - match self { - PubSubMessage::V1(data) => serde_bare::to_vec(&data).map_err(Into::into), - } - } -} - -impl PubSubMessage { - pub fn deserialize(buf: &[u8]) -> Result { - ::deserialize(buf, PROTOCOL_VERSION) - } - - pub fn serialize(self) -> Result> { - ::serialize(self, PROTOCOL_VERSION) - } -} diff --git a/sdks/schemas/runner-protocol/v1.bare b/sdks/schemas/runner-protocol/v1.bare index 37a6fa423b..1ee83cf5bb 100644 --- a/sdks/schemas/runner-protocol/v1.bare +++ b/sdks/schemas/runner-protocol/v1.bare @@ -167,6 +167,7 @@ type EventWrapper struct { } # MARK: Commands +# type CommandStartActor struct { actorId: Id generation: u32 @@ -188,8 +189,125 @@ type CommandWrapper struct { inner: Command } -# MARK: To Server +# MARK: Tunnel + +type RequestId data[16] # UUIDv4 +type MessageId data[16] # UUIDv4 + + +# Ack +type TunnelAck void + +# HTTP +type ToClientRequestStart struct { + actorId: Id + method: str + path: str + headers: map + body: optional + stream: bool +} + +type ToClientRequestChunk struct { + body: data + finish: bool +} + +type ToClientRequestAbort void + +type ToServerResponseStart struct { + status: u16 + headers: map + body: optional + stream: bool +} + +type ToServerResponseChunk struct { + body: data + finish: bool +} + +type ToServerResponseAbort void + +# WebSocket +type ToClientWebSocketOpen struct { + actorId: Id + path: str + headers: map +} + +type ToClientWebSocketMessage struct { + data: data + binary: bool +} + +type ToClientWebSocketClose struct { + code: optional + reason: optional +} + +type ToServerWebSocketOpen void + +type ToServerWebSocketMessage struct { + data: data + binary: bool +} + +type ToServerWebSocketClose struct { + code: optional + reason: optional +} + +# To Server +type ToServerTunnelMessageKind union { + TunnelAck | + + # HTTP + ToServerResponseStart | + ToServerResponseChunk | + ToServerResponseAbort | + + # WebSocket + ToServerWebSocketOpen | + ToServerWebSocketMessage | + ToServerWebSocketClose +} + +type ToServerTunnelMessage struct { + requestId: RequestId + messageId: MessageId + messageKind: ToServerTunnelMessageKind +} +# To Client +type ToClientTunnelMessageKind union { + TunnelAck | + + # HTTP + ToClientRequestStart | + ToClientRequestChunk | + ToClientRequestAbort | + + # WebSocket + ToClientWebSocketOpen | + ToClientWebSocketMessage | + ToClientWebSocketClose +} + +type ToClientTunnelMessage struct { + requestId: RequestId + messageId: MessageId + messageKind: ToClientTunnelMessageKind + + # Subject to send replies to. + # + # Only sent when opening a new request from gateway -> pegboard-runner-ws. 
+ # + # Should be stripped before sending to the runner. + gatewayReplyTo: optional +} + +# MARK: To Server type ToServerInit struct { name: str version: u32 @@ -223,7 +341,8 @@ type ToServer union { ToServerAckCommands | ToServerStopping | ToServerPing | - ToServerKvRequest + ToServerKvRequest | + ToServerTunnelMessage } # MARK: To Client @@ -248,10 +367,19 @@ type ToClientKvResponse struct { data: KvResponseData } +type ToClientClose void + type ToClient union { ToClientInit | + ToClientClose | ToClientCommands | ToClientAckEvents | - ToClientKvResponse + ToClientKvResponse | + ToClientTunnelMessage +} + +# MARK: To Gateway +type ToGateway struct { + message: ToServerTunnelMessage } diff --git a/sdks/schemas/tunnel-protocol/v1.bare b/sdks/schemas/tunnel-protocol/v1.bare deleted file mode 100644 index f9e0e9c63f..0000000000 --- a/sdks/schemas/tunnel-protocol/v1.bare +++ /dev/null @@ -1,105 +0,0 @@ -type RequestId data[16] # UUIDv4 -type MessageId data[16] # UUIDv4 -type Id str - - -# MARK: Ack -type Ack void - -# MARK: HTTP -type ToServerRequestStart struct { - actorId: Id - method: str - path: str - headers: map - body: optional - stream: bool -} - -type ToServerRequestChunk struct { - body: data - finish: bool -} - -type ToServerRequestAbort void - -type ToClientResponseStart struct { - status: u16 - headers: map - body: optional - stream: bool -} - -type ToClientResponseChunk struct { - body: data - finish: bool -} - -type ToClientResponseAbort void - -# MARK: WebSocket -type ToServerWebSocketOpen struct { - actorId: Id - path: str - headers: map -} - -type ToServerWebSocketMessage struct { - data: data - binary: bool -} - -type ToServerWebSocketClose struct { - code: optional - reason: optional -} - -type ToClientWebSocketOpen void - -type ToClientWebSocketMessage struct { - data: data - binary: bool -} - -type ToClientWebSocketClose struct { - code: optional - reason: optional -} - -# MARK: Message -type MessageKind union { - Ack | - - # HTTP - ToServerRequestStart | - ToServerRequestChunk | - ToServerRequestAbort | - ToClientResponseStart | - ToClientResponseChunk | - ToClientResponseAbort | - - # WebSocket - ToServerWebSocketOpen | - ToServerWebSocketMessage | - ToServerWebSocketClose | - ToClientWebSocketOpen | - ToClientWebSocketMessage | - ToClientWebSocketClose -} - -# MARK: Message sent over tunnel WebSocket -type RunnerMessage struct { - requestId: RequestId - messageId: MessageId - messageKind: MessageKind -} - -# MARK: Message sent over UPS -type PubSubMessage struct { - requestId: RequestId - messageId: MessageId - # Subject to send replies to. Only sent when opening a new request from gateway -> runner. 
- replyTo: optional - messageKind: MessageKind -} - diff --git a/sdks/typescript/runner-protocol/src/index.ts b/sdks/typescript/runner-protocol/src/index.ts index 8c951c3a1f..9e43b2f095 100644 --- a/sdks/typescript/runner-protocol/src/index.ts +++ b/sdks/typescript/runner-protocol/src/index.ts @@ -1,8 +1,10 @@ +import assert from "node:assert" import * as bare from "@bare-ts/lib" const DEFAULT_CONFIG = /* @__PURE__ */ bare.Config({}) export type i64 = bigint +export type u16 = number export type u32 = number export type u64 = bigint @@ -803,9 +805,6 @@ export function writeEventWrapper(bc: bare.ByteCursor, x: EventWrapper): void { writeEvent(bc, x.inner) } -/** - * MARK: Commands - */ export type CommandStartActor = { readonly actorId: Id readonly generation: u32 @@ -894,7 +893,464 @@ export function writeCommandWrapper(bc: bare.ByteCursor, x: CommandWrapper): voi writeCommand(bc, x.inner) } -function read8(bc: bare.ByteCursor): ReadonlyMap { +export type RequestId = ArrayBuffer + +export function readRequestId(bc: bare.ByteCursor): RequestId { + return bare.readFixedData(bc, 16) +} + +export function writeRequestId(bc: bare.ByteCursor, x: RequestId): void { + assert(x.byteLength === 16) + bare.writeFixedData(bc, x) +} + +/** + * UUIDv4 + */ +export type MessageId = ArrayBuffer + +export function readMessageId(bc: bare.ByteCursor): MessageId { + return bare.readFixedData(bc, 16) +} + +export function writeMessageId(bc: bare.ByteCursor, x: MessageId): void { + assert(x.byteLength === 16) + bare.writeFixedData(bc, x) +} + +/** + * Ack + */ +export type TunnelAck = null + +function read8(bc: bare.ByteCursor): ReadonlyMap { + const len = bare.readUintSafe(bc) + const result = new Map() + for (let i = 0; i < len; i++) { + const offset = bc.offset + const key = bare.readString(bc) + if (result.has(key)) { + bc.offset = offset + throw new bare.BareError(offset, "duplicated key") + } + result.set(key, bare.readString(bc)) + } + return result +} + +function write8(bc: bare.ByteCursor, x: ReadonlyMap): void { + bare.writeUintSafe(bc, x.size) + for (const kv of x) { + bare.writeString(bc, kv[0]) + bare.writeString(bc, kv[1]) + } +} + +/** + * HTTP + */ +export type ToClientRequestStart = { + readonly actorId: Id + readonly method: string + readonly path: string + readonly headers: ReadonlyMap + readonly body: ArrayBuffer | null + readonly stream: boolean +} + +export function readToClientRequestStart(bc: bare.ByteCursor): ToClientRequestStart { + return { + actorId: readId(bc), + method: bare.readString(bc), + path: bare.readString(bc), + headers: read8(bc), + body: read6(bc), + stream: bare.readBool(bc), + } +} + +export function writeToClientRequestStart(bc: bare.ByteCursor, x: ToClientRequestStart): void { + writeId(bc, x.actorId) + bare.writeString(bc, x.method) + bare.writeString(bc, x.path) + write8(bc, x.headers) + write6(bc, x.body) + bare.writeBool(bc, x.stream) +} + +export type ToClientRequestChunk = { + readonly body: ArrayBuffer + readonly finish: boolean +} + +export function readToClientRequestChunk(bc: bare.ByteCursor): ToClientRequestChunk { + return { + body: bare.readData(bc), + finish: bare.readBool(bc), + } +} + +export function writeToClientRequestChunk(bc: bare.ByteCursor, x: ToClientRequestChunk): void { + bare.writeData(bc, x.body) + bare.writeBool(bc, x.finish) +} + +export type ToClientRequestAbort = null + +export type ToServerResponseStart = { + readonly status: u16 + readonly headers: ReadonlyMap + readonly body: ArrayBuffer | null + readonly stream: boolean +} + +export 
function readToServerResponseStart(bc: bare.ByteCursor): ToServerResponseStart { + return { + status: bare.readU16(bc), + headers: read8(bc), + body: read6(bc), + stream: bare.readBool(bc), + } +} + +export function writeToServerResponseStart(bc: bare.ByteCursor, x: ToServerResponseStart): void { + bare.writeU16(bc, x.status) + write8(bc, x.headers) + write6(bc, x.body) + bare.writeBool(bc, x.stream) +} + +export type ToServerResponseChunk = { + readonly body: ArrayBuffer + readonly finish: boolean +} + +export function readToServerResponseChunk(bc: bare.ByteCursor): ToServerResponseChunk { + return { + body: bare.readData(bc), + finish: bare.readBool(bc), + } +} + +export function writeToServerResponseChunk(bc: bare.ByteCursor, x: ToServerResponseChunk): void { + bare.writeData(bc, x.body) + bare.writeBool(bc, x.finish) +} + +export type ToServerResponseAbort = null + +/** + * WebSocket + */ +export type ToClientWebSocketOpen = { + readonly actorId: Id + readonly path: string + readonly headers: ReadonlyMap +} + +export function readToClientWebSocketOpen(bc: bare.ByteCursor): ToClientWebSocketOpen { + return { + actorId: readId(bc), + path: bare.readString(bc), + headers: read8(bc), + } +} + +export function writeToClientWebSocketOpen(bc: bare.ByteCursor, x: ToClientWebSocketOpen): void { + writeId(bc, x.actorId) + bare.writeString(bc, x.path) + write8(bc, x.headers) +} + +export type ToClientWebSocketMessage = { + readonly data: ArrayBuffer + readonly binary: boolean +} + +export function readToClientWebSocketMessage(bc: bare.ByteCursor): ToClientWebSocketMessage { + return { + data: bare.readData(bc), + binary: bare.readBool(bc), + } +} + +export function writeToClientWebSocketMessage(bc: bare.ByteCursor, x: ToClientWebSocketMessage): void { + bare.writeData(bc, x.data) + bare.writeBool(bc, x.binary) +} + +function read9(bc: bare.ByteCursor): u16 | null { + return bare.readBool(bc) ? 
bare.readU16(bc) : null +} + +function write9(bc: bare.ByteCursor, x: u16 | null): void { + bare.writeBool(bc, x != null) + if (x != null) { + bare.writeU16(bc, x) + } +} + +export type ToClientWebSocketClose = { + readonly code: u16 | null + readonly reason: string | null +} + +export function readToClientWebSocketClose(bc: bare.ByteCursor): ToClientWebSocketClose { + return { + code: read9(bc), + reason: read5(bc), + } +} + +export function writeToClientWebSocketClose(bc: bare.ByteCursor, x: ToClientWebSocketClose): void { + write9(bc, x.code) + write5(bc, x.reason) +} + +export type ToServerWebSocketOpen = null + +export type ToServerWebSocketMessage = { + readonly data: ArrayBuffer + readonly binary: boolean +} + +export function readToServerWebSocketMessage(bc: bare.ByteCursor): ToServerWebSocketMessage { + return { + data: bare.readData(bc), + binary: bare.readBool(bc), + } +} + +export function writeToServerWebSocketMessage(bc: bare.ByteCursor, x: ToServerWebSocketMessage): void { + bare.writeData(bc, x.data) + bare.writeBool(bc, x.binary) +} + +export type ToServerWebSocketClose = { + readonly code: u16 | null + readonly reason: string | null +} + +export function readToServerWebSocketClose(bc: bare.ByteCursor): ToServerWebSocketClose { + return { + code: read9(bc), + reason: read5(bc), + } +} + +export function writeToServerWebSocketClose(bc: bare.ByteCursor, x: ToServerWebSocketClose): void { + write9(bc, x.code) + write5(bc, x.reason) +} + +/** + * To Server + */ +export type ToServerTunnelMessageKind = + | { readonly tag: "TunnelAck"; readonly val: TunnelAck } + /** + * HTTP + */ + | { readonly tag: "ToServerResponseStart"; readonly val: ToServerResponseStart } + | { readonly tag: "ToServerResponseChunk"; readonly val: ToServerResponseChunk } + | { readonly tag: "ToServerResponseAbort"; readonly val: ToServerResponseAbort } + /** + * WebSocket + */ + | { readonly tag: "ToServerWebSocketOpen"; readonly val: ToServerWebSocketOpen } + | { readonly tag: "ToServerWebSocketMessage"; readonly val: ToServerWebSocketMessage } + | { readonly tag: "ToServerWebSocketClose"; readonly val: ToServerWebSocketClose } + +export function readToServerTunnelMessageKind(bc: bare.ByteCursor): ToServerTunnelMessageKind { + const offset = bc.offset + const tag = bare.readU8(bc) + switch (tag) { + case 0: + return { tag: "TunnelAck", val: null } + case 1: + return { tag: "ToServerResponseStart", val: readToServerResponseStart(bc) } + case 2: + return { tag: "ToServerResponseChunk", val: readToServerResponseChunk(bc) } + case 3: + return { tag: "ToServerResponseAbort", val: null } + case 4: + return { tag: "ToServerWebSocketOpen", val: null } + case 5: + return { tag: "ToServerWebSocketMessage", val: readToServerWebSocketMessage(bc) } + case 6: + return { tag: "ToServerWebSocketClose", val: readToServerWebSocketClose(bc) } + default: { + bc.offset = offset + throw new bare.BareError(offset, "invalid tag") + } + } +} + +export function writeToServerTunnelMessageKind(bc: bare.ByteCursor, x: ToServerTunnelMessageKind): void { + switch (x.tag) { + case "TunnelAck": { + bare.writeU8(bc, 0) + break + } + case "ToServerResponseStart": { + bare.writeU8(bc, 1) + writeToServerResponseStart(bc, x.val) + break + } + case "ToServerResponseChunk": { + bare.writeU8(bc, 2) + writeToServerResponseChunk(bc, x.val) + break + } + case "ToServerResponseAbort": { + bare.writeU8(bc, 3) + break + } + case "ToServerWebSocketOpen": { + bare.writeU8(bc, 4) + break + } + case "ToServerWebSocketMessage": { + bare.writeU8(bc, 5) + 
writeToServerWebSocketMessage(bc, x.val) + break + } + case "ToServerWebSocketClose": { + bare.writeU8(bc, 6) + writeToServerWebSocketClose(bc, x.val) + break + } + } +} + +export type ToServerTunnelMessage = { + readonly requestId: RequestId + readonly messageId: MessageId + readonly messageKind: ToServerTunnelMessageKind +} + +export function readToServerTunnelMessage(bc: bare.ByteCursor): ToServerTunnelMessage { + return { + requestId: readRequestId(bc), + messageId: readMessageId(bc), + messageKind: readToServerTunnelMessageKind(bc), + } +} + +export function writeToServerTunnelMessage(bc: bare.ByteCursor, x: ToServerTunnelMessage): void { + writeRequestId(bc, x.requestId) + writeMessageId(bc, x.messageId) + writeToServerTunnelMessageKind(bc, x.messageKind) +} + +/** + * To Client + */ +export type ToClientTunnelMessageKind = + | { readonly tag: "TunnelAck"; readonly val: TunnelAck } + /** + * HTTP + */ + | { readonly tag: "ToClientRequestStart"; readonly val: ToClientRequestStart } + | { readonly tag: "ToClientRequestChunk"; readonly val: ToClientRequestChunk } + | { readonly tag: "ToClientRequestAbort"; readonly val: ToClientRequestAbort } + /** + * WebSocket + */ + | { readonly tag: "ToClientWebSocketOpen"; readonly val: ToClientWebSocketOpen } + | { readonly tag: "ToClientWebSocketMessage"; readonly val: ToClientWebSocketMessage } + | { readonly tag: "ToClientWebSocketClose"; readonly val: ToClientWebSocketClose } + +export function readToClientTunnelMessageKind(bc: bare.ByteCursor): ToClientTunnelMessageKind { + const offset = bc.offset + const tag = bare.readU8(bc) + switch (tag) { + case 0: + return { tag: "TunnelAck", val: null } + case 1: + return { tag: "ToClientRequestStart", val: readToClientRequestStart(bc) } + case 2: + return { tag: "ToClientRequestChunk", val: readToClientRequestChunk(bc) } + case 3: + return { tag: "ToClientRequestAbort", val: null } + case 4: + return { tag: "ToClientWebSocketOpen", val: readToClientWebSocketOpen(bc) } + case 5: + return { tag: "ToClientWebSocketMessage", val: readToClientWebSocketMessage(bc) } + case 6: + return { tag: "ToClientWebSocketClose", val: readToClientWebSocketClose(bc) } + default: { + bc.offset = offset + throw new bare.BareError(offset, "invalid tag") + } + } +} + +export function writeToClientTunnelMessageKind(bc: bare.ByteCursor, x: ToClientTunnelMessageKind): void { + switch (x.tag) { + case "TunnelAck": { + bare.writeU8(bc, 0) + break + } + case "ToClientRequestStart": { + bare.writeU8(bc, 1) + writeToClientRequestStart(bc, x.val) + break + } + case "ToClientRequestChunk": { + bare.writeU8(bc, 2) + writeToClientRequestChunk(bc, x.val) + break + } + case "ToClientRequestAbort": { + bare.writeU8(bc, 3) + break + } + case "ToClientWebSocketOpen": { + bare.writeU8(bc, 4) + writeToClientWebSocketOpen(bc, x.val) + break + } + case "ToClientWebSocketMessage": { + bare.writeU8(bc, 5) + writeToClientWebSocketMessage(bc, x.val) + break + } + case "ToClientWebSocketClose": { + bare.writeU8(bc, 6) + writeToClientWebSocketClose(bc, x.val) + break + } + } +} + +export type ToClientTunnelMessage = { + readonly requestId: RequestId + readonly messageId: MessageId + readonly messageKind: ToClientTunnelMessageKind + /** + * Should be stripped before sending to the runner. 
+ */ + readonly gatewayReplyTo: string | null +} + +export function readToClientTunnelMessage(bc: bare.ByteCursor): ToClientTunnelMessage { + return { + requestId: readRequestId(bc), + messageId: readMessageId(bc), + messageKind: readToClientTunnelMessageKind(bc), + gatewayReplyTo: read5(bc), + } +} + +export function writeToClientTunnelMessage(bc: bare.ByteCursor, x: ToClientTunnelMessage): void { + writeRequestId(bc, x.requestId) + writeMessageId(bc, x.messageId) + writeToClientTunnelMessageKind(bc, x.messageKind) + write5(bc, x.gatewayReplyTo) +} + +function read10(bc: bare.ByteCursor): ReadonlyMap { const len = bare.readUintSafe(bc) const result = new Map() for (let i = 0; i < len; i++) { @@ -909,7 +1365,7 @@ function read8(bc: bare.ByteCursor): ReadonlyMap { return result } -function write8(bc: bare.ByteCursor, x: ReadonlyMap): void { +function write10(bc: bare.ByteCursor, x: ReadonlyMap): void { bare.writeUintSafe(bc, x.size) for (const kv of x) { bare.writeString(bc, kv[0]) @@ -917,28 +1373,31 @@ function write8(bc: bare.ByteCursor, x: ReadonlyMap): void { } } -function read9(bc: bare.ByteCursor): ReadonlyMap | null { - return bare.readBool(bc) ? read8(bc) : null +function read11(bc: bare.ByteCursor): ReadonlyMap | null { + return bare.readBool(bc) ? read10(bc) : null } -function write9(bc: bare.ByteCursor, x: ReadonlyMap | null): void { +function write11(bc: bare.ByteCursor, x: ReadonlyMap | null): void { bare.writeBool(bc, x != null) if (x != null) { - write8(bc, x) + write10(bc, x) } } -function read10(bc: bare.ByteCursor): Json | null { +function read12(bc: bare.ByteCursor): Json | null { return bare.readBool(bc) ? readJson(bc) : null } -function write10(bc: bare.ByteCursor, x: Json | null): void { +function write12(bc: bare.ByteCursor, x: Json | null): void { bare.writeBool(bc, x != null) if (x != null) { writeJson(bc, x) } } +/** + * MARK: To Server + */ export type ToServerInit = { readonly name: string readonly version: u32 @@ -954,8 +1413,8 @@ export function readToServerInit(bc: bare.ByteCursor): ToServerInit { version: bare.readU32(bc), totalSlots: bare.readU32(bc), lastCommandIdx: read7(bc), - prepopulateActorNames: read9(bc), - metadata: read10(bc), + prepopulateActorNames: read11(bc), + metadata: read12(bc), } } @@ -964,8 +1423,8 @@ export function writeToServerInit(bc: bare.ByteCursor, x: ToServerInit): void { bare.writeU32(bc, x.version) bare.writeU32(bc, x.totalSlots) write7(bc, x.lastCommandIdx) - write9(bc, x.prepopulateActorNames) - write10(bc, x.metadata) + write11(bc, x.prepopulateActorNames) + write12(bc, x.metadata) } export type ToServerEvents = readonly EventWrapper[] @@ -1046,6 +1505,7 @@ export type ToServer = | { readonly tag: "ToServerStopping"; readonly val: ToServerStopping } | { readonly tag: "ToServerPing"; readonly val: ToServerPing } | { readonly tag: "ToServerKvRequest"; readonly val: ToServerKvRequest } + | { readonly tag: "ToServerTunnelMessage"; readonly val: ToServerTunnelMessage } export function readToServer(bc: bare.ByteCursor): ToServer { const offset = bc.offset @@ -1063,6 +1523,8 @@ export function readToServer(bc: bare.ByteCursor): ToServer { return { tag: "ToServerPing", val: readToServerPing(bc) } case 5: return { tag: "ToServerKvRequest", val: readToServerKvRequest(bc) } + case 6: + return { tag: "ToServerTunnelMessage", val: readToServerTunnelMessage(bc) } default: { bc.offset = offset throw new bare.BareError(offset, "invalid tag") @@ -1101,6 +1563,11 @@ export function writeToServer(bc: bare.ByteCursor, x: ToServer): void { 
writeToServerKvRequest(bc, x.val) break } + case "ToServerTunnelMessage": { + bare.writeU8(bc, 6) + writeToServerTunnelMessage(bc, x.val) + break + } } } @@ -1212,11 +1679,15 @@ export function writeToClientKvResponse(bc: bare.ByteCursor, x: ToClientKvRespon writeKvResponseData(bc, x.data) } +export type ToClientClose = null + export type ToClient = | { readonly tag: "ToClientInit"; readonly val: ToClientInit } + | { readonly tag: "ToClientClose"; readonly val: ToClientClose } | { readonly tag: "ToClientCommands"; readonly val: ToClientCommands } | { readonly tag: "ToClientAckEvents"; readonly val: ToClientAckEvents } | { readonly tag: "ToClientKvResponse"; readonly val: ToClientKvResponse } + | { readonly tag: "ToClientTunnelMessage"; readonly val: ToClientTunnelMessage } export function readToClient(bc: bare.ByteCursor): ToClient { const offset = bc.offset @@ -1225,11 +1696,15 @@ export function readToClient(bc: bare.ByteCursor): ToClient { case 0: return { tag: "ToClientInit", val: readToClientInit(bc) } case 1: - return { tag: "ToClientCommands", val: readToClientCommands(bc) } + return { tag: "ToClientClose", val: null } case 2: - return { tag: "ToClientAckEvents", val: readToClientAckEvents(bc) } + return { tag: "ToClientCommands", val: readToClientCommands(bc) } case 3: + return { tag: "ToClientAckEvents", val: readToClientAckEvents(bc) } + case 4: return { tag: "ToClientKvResponse", val: readToClientKvResponse(bc) } + case 5: + return { tag: "ToClientTunnelMessage", val: readToClientTunnelMessage(bc) } default: { bc.offset = offset throw new bare.BareError(offset, "invalid tag") @@ -1244,21 +1719,30 @@ export function writeToClient(bc: bare.ByteCursor, x: ToClient): void { writeToClientInit(bc, x.val) break } - case "ToClientCommands": { + case "ToClientClose": { bare.writeU8(bc, 1) + break + } + case "ToClientCommands": { + bare.writeU8(bc, 2) writeToClientCommands(bc, x.val) break } case "ToClientAckEvents": { - bare.writeU8(bc, 2) + bare.writeU8(bc, 3) writeToClientAckEvents(bc, x.val) break } case "ToClientKvResponse": { - bare.writeU8(bc, 3) + bare.writeU8(bc, 4) writeToClientKvResponse(bc, x.val) break } + case "ToClientTunnelMessage": { + bare.writeU8(bc, 5) + writeToClientTunnelMessage(bc, x.val) + break + } } } @@ -1280,3 +1764,39 @@ export function decodeToClient(bytes: Uint8Array): ToClient { } return result } + +/** + * MARK: To Gateway + */ +export type ToGateway = { + readonly message: ToServerTunnelMessage +} + +export function readToGateway(bc: bare.ByteCursor): ToGateway { + return { + message: readToServerTunnelMessage(bc), + } +} + +export function writeToGateway(bc: bare.ByteCursor, x: ToGateway): void { + writeToServerTunnelMessage(bc, x.message) +} + +export function encodeToGateway(x: ToGateway, config?: Partial): Uint8Array { + const fullConfig = config != null ? bare.Config(config) : DEFAULT_CONFIG + const bc = new bare.ByteCursor( + new Uint8Array(fullConfig.initialBufferLength), + fullConfig, + ) + writeToGateway(bc, x) + return new Uint8Array(bc.view.buffer, bc.view.byteOffset, bc.offset) +} + +export function decodeToGateway(bytes: Uint8Array): ToGateway { + const bc = new bare.ByteCursor(bytes, DEFAULT_CONFIG) + const result = readToGateway(bc) + if (bc.offset < bc.view.byteLength) { + throw new bare.BareError(bc.offset, "remaining bytes") + } + return result +}
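
The hunks above wire a request/response tunnel into the runner protocol: fixed 16-byte RequestId/MessageId values, ToServerTunnelMessageKind/ToClientTunnelMessageKind tagged unions for HTTP and WebSocket traffic, new ToServerTunnelMessage (tag 6) and ToClientClose/ToClientTunnelMessage variants in the renumbered ToServer/ToClient unions, and a standalone ToGateway envelope with encodeToGateway/decodeToGateway. As a rough usage sketch of the generated API (not part of this diff; the relative import path and the newId helper are illustrative assumptions):

    // Usage sketch only, based on the exports added in this diff.
    // The import path and newId() are illustrative, not defined by the patch.
    import * as protocol from "./index"

    // RequestId and MessageId are 16-byte ArrayBuffers (UUIDv4 in practice);
    // writeRequestId/writeMessageId assert byteLength === 16.
    function newId(): ArrayBuffer {
        return crypto.getRandomValues(new Uint8Array(16)).buffer
    }

    // A runner answering an HTTP request over the tunnel: wrap the response
    // start in the tagged union, then in the per-request envelope.
    const reply: protocol.ToServerTunnelMessage = {
        requestId: newId(), // in practice, echo the RequestId from ToClientRequestStart
        messageId: newId(),
        messageKind: {
            tag: "ToServerResponseStart",
            val: {
                status: 200,
                headers: new Map([["content-type", "application/json"]]),
                body: new TextEncoder().encode(`{"ok":true}`).buffer,
                stream: false,
            },
        },
    }

    // ToGateway is the payload addressed back to the gateway.
    const bytes = protocol.encodeToGateway({ message: reply })
    const decoded = protocol.decodeToGateway(bytes)
    console.log(decoded.message.messageKind.tag) // "ToServerResponseStart"

On the receiving side, ToClientTunnelMessage mirrors this shape and additionally carries gatewayReplyTo, which the schema comment says should be stripped before the message reaches the runner; the zero-payload TunnelAck variant (tag 0 in both unions) presumably lets either side acknowledge delivery of a given messageId.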