diff --git a/packages/core/guard/core/src/custom_serve.rs b/packages/core/guard/core/src/custom_serve.rs index 99149e273e..0108545bb4 100644 --- a/packages/core/guard/core/src/custom_serve.rs +++ b/packages/core/guard/core/src/custom_serve.rs @@ -2,10 +2,10 @@ use anyhow::*; use async_trait::async_trait; use bytes::Bytes; use http_body_util::Full; -use hyper::body::Incoming as BodyIncoming; use hyper::{Request, Response}; use hyper_tungstenite::HyperWebsocket; +use crate::WebSocketHandle; use crate::proxy_service::ResponseBody; use crate::request_context::RequestContext; @@ -19,19 +19,12 @@ pub trait CustomServeTrait: Send + Sync { request_context: &mut RequestContext, ) -> Result>; - /// Handle a WebSocket connection after upgrade. - /// - /// Contract for retries: - /// - Return `Ok(())` after you have accepted (`await`ed) the client websocket and - /// completed the streaming lifecycle. No further retries are possible. - /// - Return `Err((client_ws, err))` if you have NOT accepted the websocket yet and - /// want the proxy to optionally re-resolve and retry with a different handler. - /// You must not `await` the websocket before returning this error. + /// Handle a WebSocket connection after upgrade. Supports connection retries. async fn handle_websocket( &self, - client_ws: HyperWebsocket, + websocket: WebSocketHandle, headers: &hyper::HeaderMap, path: &str, request_context: &mut RequestContext, - ) -> std::result::Result<(), (HyperWebsocket, anyhow::Error)>; + ) -> Result<()>; } diff --git a/packages/core/guard/core/src/lib.rs b/packages/core/guard/core/src/lib.rs index 6f090d7489..35207a8496 100644 --- a/packages/core/guard/core/src/lib.rs +++ b/packages/core/guard/core/src/lib.rs @@ -7,12 +7,14 @@ pub mod proxy_service; pub mod request_context; mod server; pub mod types; +pub mod websocket_handle; pub use cert_resolver::CertResolverFn; pub use custom_serve::CustomServeTrait; pub use proxy_service::{ CacheKeyFn, MiddlewareFn, ProxyService, ProxyState, RouteTarget, RoutingFn, RoutingOutput, }; +pub use websocket_handle::WebSocketHandle; // Re-export hyper StatusCode for use in other crates pub mod status { diff --git a/packages/core/guard/core/src/proxy_service.rs b/packages/core/guard/core/src/proxy_service.rs index 80cfc7cb80..0cd6948709 100644 --- a/packages/core/guard/core/src/proxy_service.rs +++ b/packages/core/guard/core/src/proxy_service.rs @@ -1,11 +1,3 @@ -use std::{ - borrow::Cow, - collections::HashMap as StdHashMap, - net::SocketAddr, - sync::Arc, - time::{Duration, Instant}, -}; - use anyhow::*; use bytes::Bytes; use futures_util::{SinkExt, StreamExt}; @@ -16,17 +8,30 @@ use hyper_util::{client::legacy::Client, rt::TokioExecutor}; use moka::future::Cache; use rand; use rivet_api_builder::{ErrorResponse, RawErrorResponse}; -use rivet_error::RivetError; +use rivet_error::{INTERNAL_ERROR, RivetError}; use rivet_metrics::KeyValue; use rivet_util::Id; use serde_json; +use std::{ + borrow::Cow, + collections::HashMap as StdHashMap, + net::SocketAddr, + sync::Arc, + time::{Duration, Instant}, +}; use tokio::sync::Mutex; use tokio::time::timeout; -use tokio_tungstenite::tungstenite::client::IntoClientRequest; +use tokio_tungstenite::tungstenite::{ + client::IntoClientRequest, + protocol::{CloseFrame, frame::coding::CloseCode}, +}; use tracing::Instrument; use url::Url; -use crate::{custom_serve::CustomServeTrait, errors, metrics, request_context::RequestContext}; +use crate::{ + WebSocketHandle, custom_serve::CustomServeTrait, errors, metrics, + request_context::RequestContext, +}; pub const X_FORWARDED_FOR: HeaderName = HeaderName::from_static("x-forwarded-for"); pub const X_RIVET_ERROR: HeaderName = HeaderName::from_static("x-rivet-error"); @@ -1432,9 +1437,9 @@ impl ProxyService { // Close the WebSocket connection with the response message let _ = client_ws.close(Some(tokio_tungstenite::tungstenite::protocol::CloseFrame { - code: tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode::Error, - reason: response.message.as_ref().into(), - })).await; + code: tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode::Error, + reason: response.message.as_ref().into(), + })).await; return; } Result::Ok(ResolveRouteOutput::CustomServe(_)) => { @@ -1813,31 +1818,42 @@ impl ProxyService { let mut attempts = 0u32; let mut client_ws = client_websocket; + let ws_handle = WebSocketHandle::new(client_ws); + loop { match handlers .handle_websocket( - client_ws, + ws_handle.clone(), &req_headers, &req_path, &mut request_context, ) .await { - Result::Ok(()) => break, - Result::Err((returned_client_ws, err)) => { + Result::Ok(()) => { + tracing::debug!("websocket closed"); + + // Send graceful close + ws_handle.send(hyper_tungstenite::tungstenite::Message::Close(Some( + hyper_tungstenite::tungstenite::protocol::CloseFrame { + code: hyper_tungstenite::tungstenite::protocol::frame::coding::CloseCode::Normal, + reason: format!("Closed").into(), + }, + ))); + + break; + } + Result::Err(err) => { attempts += 1; if attempts > max_attempts || !is_retryable_ws_error(&err) { - // Accept and close the client websocket with an error reason - if let Result::Ok(mut ws) = returned_client_ws.await { - let _ = ws - .send(hyper_tungstenite::tungstenite::Message::Close(Some( - hyper_tungstenite::tungstenite::protocol::CloseFrame { - code: hyper_tungstenite::tungstenite::protocol::frame::coding::CloseCode::Error, - reason: format!("{}", err).into(), - }, - ))) - .await; - } + // Close WebSocket with error + ws_handle + .accept_and_send( + hyper_tungstenite::tungstenite::Message::Close( + Some(err_to_close_frame(err)), + ), + ) + .await?; break; } else { @@ -1861,49 +1877,38 @@ impl ProxyService { new_handlers, )) => { handlers = new_handlers; - client_ws = returned_client_ws; continue; } Result::Ok(ResolveRouteOutput::Response(response)) => { - if let Result::Ok(mut ws) = returned_client_ws.await - { - let _ = ws - .send(hyper_tungstenite::tungstenite::Message::Close(Some( - hyper_tungstenite::tungstenite::protocol::CloseFrame { - code: hyper_tungstenite::tungstenite::protocol::frame::coding::CloseCode::Error, - reason: response.message.as_ref().into(), - }, - ))) - .await; - } - break; + ws_handle + .accept_and_send(hyper_tungstenite::tungstenite::Message::Close(Some( + hyper_tungstenite::tungstenite::protocol::CloseFrame { + code: hyper_tungstenite::tungstenite::protocol::frame::coding::CloseCode::Error, + reason: response.message.as_ref().into(), + }, + ))) + .await; } Result::Ok(ResolveRouteOutput::Target(_)) => { - if let Result::Ok(mut ws) = returned_client_ws.await - { - let _ = ws - .send(hyper_tungstenite::tungstenite::Message::Close(Some( - hyper_tungstenite::tungstenite::protocol::CloseFrame { - code: hyper_tungstenite::tungstenite::protocol::frame::coding::CloseCode::Error, - reason: "Cannot retry WebSocket with non-custom serve route".into(), - }, - ))) - .await; - } + ws_handle + .accept_and_send(hyper_tungstenite::tungstenite::Message::Close(Some( + hyper_tungstenite::tungstenite::protocol::CloseFrame { + code: hyper_tungstenite::tungstenite::protocol::frame::coding::CloseCode::Error, + reason: "Cannot retry WebSocket with non-custom serve route".into(), + }, + ))) + .await; break; } Err(res_err) => { - if let Result::Ok(mut ws) = returned_client_ws.await - { - let _ = ws - .send(hyper_tungstenite::tungstenite::Message::Close(Some( - hyper_tungstenite::tungstenite::protocol::CloseFrame { - code: hyper_tungstenite::tungstenite::protocol::frame::coding::CloseCode::Error, - reason: format!("Routing error: {}", res_err).into(), - }, - ))) - .await; - } + ws_handle + .accept_and_send(hyper_tungstenite::tungstenite::Message::Close(Some( + hyper_tungstenite::tungstenite::protocol::CloseFrame { + code: hyper_tungstenite::tungstenite::protocol::frame::coding::CloseCode::Error, + reason: format!("Routing error: {}", res_err).into(), + }, + ))) + .await; break; } } @@ -2242,3 +2247,26 @@ fn is_retryable_ws_error(err: &anyhow::Error) -> bool { false } } + +pub fn err_to_close_frame(err: anyhow::Error) -> CloseFrame { + let rivet_err = err + .chain() + .find_map(|x| x.downcast_ref::()) + .cloned() + .unwrap_or_else(|| RivetError::from(&INTERNAL_ERROR)); + + let code = match (rivet_err.group(), rivet_err.code()) { + ("ws", "connection_closed") => CloseCode::Normal, + _ => CloseCode::Error, + }; + + // NOTE: reason cannot be more than 123 bytes as per the WS protocol + let reason = rivet_util::safe_slice( + &format!("{}.{}", rivet_err.group(), rivet_err.code()), + 0, + 123, + ) + .into(); + + CloseFrame { code, reason } +} diff --git a/packages/core/guard/core/src/websocket_handle.rs b/packages/core/guard/core/src/websocket_handle.rs new file mode 100644 index 0000000000..98c3a756cb --- /dev/null +++ b/packages/core/guard/core/src/websocket_handle.rs @@ -0,0 +1,104 @@ +use anyhow::*; +use futures_util::{SinkExt, StreamExt}; +use hyper::upgrade::Upgraded; +use hyper_tungstenite::HyperWebsocket; +use hyper_tungstenite::tungstenite::Message as WsMessage; +use hyper_util::rt::TokioIo; +use std::ops::Deref; +use std::sync::Arc; +use tokio::sync::Mutex; +use tokio_tungstenite::WebSocketStream; + +pub type WebSocketReceiver = futures_util::stream::SplitStream>>; + +pub type WebSocketSender = + futures_util::stream::SplitSink>, WsMessage>; + +enum WebSocketState { + Unaccepted { websocket: HyperWebsocket }, + Accepting, + Split { ws_tx: WebSocketSender }, +} + +#[derive(Clone)] +pub struct WebSocketHandle(Arc); + +impl WebSocketHandle { + pub fn new(websocket: HyperWebsocket) -> Self { + Self(Arc::new(WebSocketHandleInner { + state: Mutex::new(WebSocketState::Unaccepted { websocket }), + })) + } +} + +impl Deref for WebSocketHandle { + type Target = WebSocketHandleInner; + + fn deref(&self) -> &Self::Target { + &*self.0 + } +} + +pub struct WebSocketHandleInner { + state: Mutex, +} + +impl WebSocketHandleInner { + pub async fn accept(&self) -> Result { + let mut state = self.state.lock().await; + Self::accept_inner(&mut *state).await + } + + pub async fn send(&self, message: WsMessage) -> Result<()> { + let mut state = self.state.lock().await; + match &mut *state { + WebSocketState::Unaccepted { .. } | WebSocketState::Accepting => { + bail!("websocket has not been accepted") + } + WebSocketState::Split { ws_tx } => { + ws_tx.send(message).await?; + Ok(()) + } + } + } + + pub async fn accept_and_send(&self, message: WsMessage) -> Result<()> { + let mut state = self.state.lock().await; + match &mut *state { + WebSocketState::Unaccepted { .. } => { + let _ = Self::accept_inner(&mut *state).await?; + let WebSocketState::Split { ws_tx } = &mut *state else { + bail!("websocket should be accepted"); + }; + ws_tx.send(message).await?; + Ok(()) + } + WebSocketState::Accepting => { + bail!("in accepting state") + } + WebSocketState::Split { ws_tx } => { + ws_tx.send(message).await?; + Ok(()) + } + } + } + + async fn accept_inner(state: &mut WebSocketState) -> Result { + if !matches!(*state, WebSocketState::Unaccepted { .. }) { + bail!("websocket already accepted") + } + + // Accept websocket + let old_state = std::mem::replace(&mut *state, WebSocketState::Accepting); + let WebSocketState::Unaccepted { websocket } = old_state else { + bail!("should be in unaccepted state"); + }; + + // Accept WS + let ws_stream = websocket.await?; + let (ws_tx, ws_rx) = ws_stream.split(); + *state = WebSocketState::Split { ws_tx }; + + Ok(ws_rx) + } +} diff --git a/packages/core/guard/server/src/routing/api_public.rs b/packages/core/guard/server/src/routing/api_public.rs index 0f5a2ab41b..a762aca276 100644 --- a/packages/core/guard/server/src/routing/api_public.rs +++ b/packages/core/guard/server/src/routing/api_public.rs @@ -5,9 +5,9 @@ use async_trait::async_trait; use bytes::Bytes; use gas::prelude::*; use http_body_util::{BodyExt, Full}; -use hyper::body::Incoming as BodyIncoming; use hyper::{Request, Response}; use hyper_tungstenite::HyperWebsocket; +use rivet_guard_core::WebSocketHandle; use rivet_guard_core::proxy_service::{ResponseBody, RoutingOutput}; use rivet_guard_core::{CustomServeTrait, request_context::RequestContext}; use tower::Service; @@ -47,15 +47,12 @@ impl CustomServeTrait for ApiPublicService { async fn handle_websocket( &self, - client_ws: HyperWebsocket, + _client_ws: WebSocketHandle, _headers: &hyper::HeaderMap, _path: &str, _request_context: &mut RequestContext, - ) -> std::result::Result<(), (HyperWebsocket, anyhow::Error)> { - Err(( - client_ws, - anyhow::anyhow!("api-public does not support WebSocket connections"), - )) + ) -> Result<()> { + bail!("api-public does not support WebSocket connections") } } diff --git a/packages/core/pegboard-gateway/src/lib.rs b/packages/core/pegboard-gateway/src/lib.rs index f8972c71a6..eef00f0a6c 100644 --- a/packages/core/pegboard-gateway/src/lib.rs +++ b/packages/core/pegboard-gateway/src/lib.rs @@ -1,13 +1,14 @@ use anyhow::Result; use async_trait::async_trait; use bytes::Bytes; -use futures_util::{SinkExt, StreamExt}; +use futures_util::StreamExt; use gas::prelude::*; use http_body_util::{BodyExt, Full}; use hyper::{Request, Response, StatusCode}; -use hyper_tungstenite::HyperWebsocket; use rivet_guard_core::{ + WebSocketHandle, custom_serve::CustomServeTrait, + errors::WebSocketServiceUnavailable, proxy_service::{ResponseBody, X_RIVET_ERROR}, request_context::RequestContext, }; @@ -66,24 +67,21 @@ impl CustomServeTrait for PegboardGateway { async fn handle_websocket( &self, - client_ws: HyperWebsocket, + client_ws: WebSocketHandle, headers: &hyper::HeaderMap, path: &str, _request_context: &mut RequestContext, - ) -> std::result::Result<(), (HyperWebsocket, anyhow::Error)> { - match self + ) -> Result<()> { + let res = self .handle_websocket_inner(client_ws, headers, path, _request_context) - .await - { - Result::Ok(()) => std::result::Result::<(), (HyperWebsocket, anyhow::Error)>::Ok(()), - Result::Err((client_ws, err)) => { + .await; + match res { + Result::Ok(x) => Ok(x), + Err(err) => { if is_tunnel_service_unavailable(&err) { - Err(( - client_ws, - rivet_guard_core::errors::WebSocketServiceUnavailable.build(), - )) + Err(WebSocketServiceUnavailable.build()) } else { - Err((client_ws, err)) + Err(err) } } } @@ -197,20 +195,17 @@ impl PegboardGateway { async fn handle_websocket_inner( &self, - client_ws: HyperWebsocket, + client_ws: WebSocketHandle, headers: &hyper::HeaderMap, path: &str, _request_context: &mut RequestContext, - ) -> std::result::Result<(), (HyperWebsocket, anyhow::Error)> { + ) -> Result<()> { // Extract actor ID for the message - let actor_id = match headers + let actor_id = headers .get("x-rivet-actor") - .context("missing x-rivet-actor header") - .and_then(|v| v.to_str().context("invalid x-rivet-actor header")) - { - Result::Ok(v) => v.to_string(), - Err(err) => return Err((client_ws, err)), - }; + .context("missing x-rivet-actor")? + .to_str()? + .to_string(); // Extract headers let mut request_headers = HashableMap::new(); @@ -239,19 +234,14 @@ impl PegboardGateway { }, ); - if let Err(err) = self - .shared_state + self.shared_state .send_message(request_id, open_message) - .await - { - return Err((client_ws, err)); - } + .await?; // Wait for WebSocket open acknowledgment let open_ack_received = loop { let Some(msg) = msg_rx.recv().await else { - tracing::warn!("received no websocket open response"); - return Err((client_ws, RequestError::ServiceUnavailable.into())); + bail!("received no websocket open response"); }; match msg { @@ -263,12 +253,10 @@ impl PegboardGateway { TunnelMessageData::Message( protocol::ToServerTunnelMessageKind::ToServerWebSocketClose(close), ) => { - tracing::info!(?close, "websocket closed before opening"); - return Err((client_ws, RequestError::ServiceUnavailable.into())); + bail!("websocket closed before opening: {close:?}"); } TunnelMessageData::Timeout => { - tracing::warn!("websocket open timeout"); - return Err((client_ws, RequestError::ServiceUnavailable.into())); + bail!("websocket open timeout"); } _ => { tracing::warn!("received unexpected message while waiting for websocket open"); @@ -277,19 +265,11 @@ impl PegboardGateway { }; if !open_ack_received { - return Err((client_ws, anyhow!("failed to open websocket"))); + bail!("failed to open websocket"); } // Accept the WebSocket - let ws_stream = match client_ws.await { - Result::Ok(ws) => ws, - Err(e) => { - // Handshake already in progress; cannot retry safely here - tracing::debug!(error = ?e, "client websocket await failed"); - return std::result::Result::<(), (HyperWebsocket, anyhow::Error)>::Ok(()); - } - }; - let (mut ws_sink, mut ws_stream) = ws_stream.split(); + let mut ws_rx = client_ws.accept().await?; // Spawn task to forward messages from server to client let mut msg_rx_for_task = msg_rx; @@ -304,7 +284,7 @@ impl PegboardGateway { } else { Message::Text(String::from_utf8_lossy(&ws_msg.data).into_owned().into()) }; - if let Err(e) = ws_sink.send(msg).await { + if let Err(e) = client_ws.send(msg).await { tracing::warn!(?e, "failed to send websocket message to client"); break; } @@ -326,7 +306,7 @@ impl PegboardGateway { // Forward messages from client to server let mut close_reason = None; - while let Some(msg) = ws_stream.next().await { + while let Some(msg) = ws_rx.next().await { match msg { Result::Ok(Message::Binary(data)) => { let ws_message = protocol::ToClientTunnelMessageKind::ToClientWebSocketMessage( @@ -387,7 +367,7 @@ impl PegboardGateway { } } - std::result::Result::<(), (HyperWebsocket, anyhow::Error)>::Ok(()) + Ok(()) } } diff --git a/packages/core/pegboard-runner/src/client_to_pubsub_task.rs b/packages/core/pegboard-runner/src/client_to_pubsub_task.rs index 8d62c1f5e9..643fc1f495 100644 --- a/packages/core/pegboard-runner/src/client_to_pubsub_task.rs +++ b/packages/core/pegboard-runner/src/client_to_pubsub_task.rs @@ -1,9 +1,11 @@ +use anyhow::Context; use futures_util::{SinkExt, StreamExt}; use gas::prelude::Id; use gas::prelude::*; use hyper_tungstenite::tungstenite::Message as WsMessage; use hyper_tungstenite::tungstenite::Message; use pegboard_actor_kv as kv; +use rivet_guard_core::websocket_handle::WebSocketReceiver; use rivet_runner_protocol::{self as protocol, PROTOCOL_VERSION, versioned}; use std::sync::{Arc, atomic::Ordering}; use universalpubsub::PublishOpts; @@ -11,7 +13,7 @@ use versioned_data_util::OwnedVersionedData as _; use crate::{ conn::Conn, - utils::{self, WebSocketReceiver}, + utils::{self}, }; pub async fn task(ctx: StandaloneCtx, conn: Arc, ws_rx: WebSocketReceiver) { @@ -44,12 +46,18 @@ async fn task_inner( { Result::Ok(x) => x, Err(err) => { - tracing::error!(?err, "failed to deserialize message"); + tracing::warn!( + ?err, + data_len = data.len(), + "failed to deserialize message" + ); continue; } }; - handle_message(&ctx, &conn, msg).await?; + handle_message(&ctx, &conn, msg) + .await + .context("failed to handle WebSocket message")?; } Result::Ok(WsMessage::Close(_)) => { tracing::info!(?conn.runner_id, "WebSocket closed"); @@ -76,7 +84,10 @@ async fn handle_message( ) -> Result<()> { match msg { protocol::ToServer::ToServerPing(ping) => { - let rtt = util::timestamp::now().saturating_sub(ping.ts).try_into()?; + let rtt = util::timestamp::now() + .saturating_sub(ping.ts) + .try_into() + .context("failed to calculate RTT from ping timestamp")?; conn.last_rtt.store(rtt, Ordering::Relaxed); } @@ -96,12 +107,13 @@ async fn handle_message( }), ); - let res_msg_serialized = res_msg.serialize(conn.protocol_version)?; - conn.ws_tx - .lock() - .await + let res_msg_serialized = res_msg + .serialize(conn.protocol_version) + .context("failed to serialize KV error response")?; + conn.ws_handle .send(Message::Binary(res_msg_serialized.into())) - .await?; + .await + .context("failed to send KV error response to client")?; return Ok(()); } @@ -111,7 +123,8 @@ async fn handle_message( .op(pegboard::ops::actor::get_runner::Input { actor_ids: vec![actor_id], }) - .await?; + .await + .with_context(|| format!("failed to get runner for actor: {}", actor_id))?; let actor_belongs = actors_res .actors .first() @@ -131,12 +144,13 @@ async fn handle_message( }, )); - let res_msg_serialized = res_msg.serialize(conn.protocol_version)?; - conn.ws_tx - .lock() - .await + let res_msg_serialized = res_msg + .serialize(conn.protocol_version) + .context("failed to serialize KV actor validation error")?; + conn.ws_handle .send(Message::Binary(res_msg_serialized.into())) - .await?; + .await + .context("failed to send KV actor validation error to client")?; return Ok(()); } @@ -170,12 +184,13 @@ async fn handle_message( }), ); - let res_msg_serialized = res_msg.serialize(conn.protocol_version)?; - conn.ws_tx - .lock() - .await + let res_msg_serialized = res_msg + .serialize(conn.protocol_version) + .context("failed to serialize KV get response")?; + conn.ws_handle .send(Message::Binary(res_msg_serialized.into())) - .await?; + .await + .context("failed to send KV get response to client")?; } protocol::KvRequestData::KvListRequest(body) => { let res = kv::list( @@ -183,7 +198,10 @@ async fn handle_message( actor_id, body.query, body.reverse.unwrap_or_default(), - body.limit.map(TryInto::try_into).transpose()?, + body.limit + .map(TryInto::try_into) + .transpose() + .context("KV list limit value overflow")?, ) .await; @@ -210,12 +228,13 @@ async fn handle_message( }), ); - let res_msg_serialized = res_msg.serialize(conn.protocol_version)?; - conn.ws_tx - .lock() - .await + let res_msg_serialized = res_msg + .serialize(conn.protocol_version) + .context("failed to serialize KV list response")?; + conn.ws_handle .send(Message::Binary(res_msg_serialized.into())) - .await?; + .await + .context("failed to send KV list response to client")?; } protocol::KvRequestData::KvPutRequest(body) => { let res = kv::put(&*ctx.udb()?, actor_id, body.keys, body.values).await; @@ -237,12 +256,13 @@ async fn handle_message( }), ); - let res_msg_serialized = res_msg.serialize(conn.protocol_version)?; - conn.ws_tx - .lock() - .await + let res_msg_serialized = res_msg + .serialize(conn.protocol_version) + .context("failed to serialize KV put response")?; + conn.ws_handle .send(Message::Binary(res_msg_serialized.into())) - .await?; + .await + .context("failed to send KV put response to client")?; } protocol::KvRequestData::KvDeleteRequest(body) => { let res = kv::delete(&*ctx.udb()?, actor_id, body.keys).await; @@ -262,12 +282,13 @@ async fn handle_message( }), ); - let res_msg_serialized = res_msg.serialize(conn.protocol_version)?; - conn.ws_tx - .lock() - .await + let res_msg_serialized = res_msg + .serialize(conn.protocol_version) + .context("failed to serialize KV delete response")?; + conn.ws_handle .send(Message::Binary(res_msg_serialized.into())) - .await?; + .await + .context("failed to send KV delete response to client")?; } protocol::KvRequestData::KvDropRequest => { let res = kv::delete_all(&*ctx.udb()?, actor_id).await; @@ -287,17 +308,20 @@ async fn handle_message( }), ); - let res_msg_serialized = res_msg.serialize(conn.protocol_version)?; - conn.ws_tx - .lock() - .await + let res_msg_serialized = res_msg + .serialize(conn.protocol_version) + .context("failed to serialize KV drop response")?; + conn.ws_handle .send(Message::Binary(res_msg_serialized.into())) - .await?; + .await + .context("failed to send KV drop response to client")?; } } } protocol::ToServer::ToServerTunnelMessage(tunnel_msg) => { - handle_tunnel_message(&ctx, &conn, tunnel_msg).await?; + handle_tunnel_message(&ctx, &conn, tunnel_msg) + .await + .context("failed to handle tunnel message")?; } // Forward to runner wf protocol::ToServer::ToServerInit(_) @@ -305,11 +329,15 @@ async fn handle_message( | protocol::ToServer::ToServerAckCommands(_) | protocol::ToServer::ToServerStopping => { ctx.signal(pegboard::workflows::runner::Forward { - inner: protocol::ToServer::try_from(msg)?, + inner: protocol::ToServer::try_from(msg) + .context("failed to convert message for workflow forwarding")?, }) .to_workflow_id(conn.workflow_id) .send() - .await?; + .await + .with_context(|| { + format!("failed to forward signal to workflow: {}", conn.workflow_id) + })?; } } @@ -341,10 +369,18 @@ async fn handle_tunnel_message( // Publish message to UPS let msg_serialized = versioned::ToGateway::latest(protocol::ToGateway { message: msg }) - .serialize_with_embedded_version(PROTOCOL_VERSION)?; - ctx.ups()? + .serialize_with_embedded_version(PROTOCOL_VERSION) + .context("failed to serialize tunnel message for gateway")?; + ctx.ups() + .context("failed to get UPS instance for tunnel message")? .publish(&gateway_reply_to, &msg_serialized, PublishOpts::one()) - .await?; + .await + .with_context(|| { + format!( + "failed to publish tunnel message to gateway reply topic: {}", + gateway_reply_to + ) + })?; Ok(()) } diff --git a/packages/core/pegboard-runner/src/conn.rs b/packages/core/pegboard-runner/src/conn.rs index 08840cccf6..aa11e3ddac 100644 --- a/packages/core/pegboard-runner/src/conn.rs +++ b/packages/core/pegboard-runner/src/conn.rs @@ -1,8 +1,10 @@ +use anyhow::Context; use futures_util::StreamExt; use gas::prelude::Id; use gas::prelude::*; use hyper_tungstenite::tungstenite::Message; use pegboard::ops::runner::update_alloc_idx::{Action, RunnerEligibility}; +use rivet_guard_core::{WebSocketHandle, websocket_handle::WebSocketReceiver}; use rivet_runner_protocol as protocol; use rivet_runner_protocol::*; use std::{ @@ -13,10 +15,7 @@ use std::{ use tokio::sync::Mutex; use versioned_data_util::OwnedVersionedData as _; -use crate::{ - errors::WsError, - utils::{UrlData, WebSocketReceiver, WebSocketSender}, -}; +use crate::{errors::WsError, utils::UrlData}; pub struct TunnelActiveRequest { /// Subject to send replies to. @@ -30,15 +29,7 @@ pub struct Conn { pub protocol_version: u16, - pub ws_tx: Arc< - Mutex< - Box< - dyn futures_util::Sink - + Send - + Unpin, - >, - >, - >, + pub ws_handle: WebSocketHandle, pub last_rtt: AtomicU32, @@ -50,7 +41,7 @@ pub struct Conn { #[tracing::instrument(skip_all)] pub async fn init_conn( ctx: &StandaloneCtx, - ws_tx: &mut Option, + ws_handle: WebSocketHandle, ws_rx: &mut WebSocketReceiver, UrlData { protocol_version, @@ -58,10 +49,13 @@ pub async fn init_conn( runner_key, }: UrlData, ) -> Result> { + let namespace_name = namespace.clone(); let namespace = ctx .op(namespace::ops::resolve_for_name_global::Input { name: namespace }) - .await? - .ok_or_else(|| namespace::errors::Namespace::NotFound.build())?; + .await + .with_context(|| format!("failed to resolve namespace: {}", namespace_name))? + .ok_or_else(|| namespace::errors::Namespace::NotFound.build()) + .with_context(|| format!("namespace not found: {}", namespace_name))?; tracing::debug!("new runner connection"); @@ -81,7 +75,8 @@ pub async fn init_conn( }; let packet = versioned::ToServer::deserialize(&buf, protocol_version) - .map_err(|err| WsError::InvalidPacket(err.to_string()).build())?; + .map_err(|err| WsError::InvalidPacket(err.to_string()).build()) + .context("failed to deserialize initial packet from client")?; let (runner_id, workflow_id) = if let protocol::ToServer::ToServerInit(protocol::ToServerInit { @@ -98,7 +93,13 @@ pub async fn init_conn( name: name.clone(), key: runner_key.clone(), }) - .await?; + .await + .with_context(|| { + format!( + "failed to get existing runner by key: {}:{}", + name, runner_key + ) + })?; let runner_id = if let Some(runner) = existing_runner.runner { // IMPORTANT: Before we spawn/get the workflow, we try to update the runner's last ping ts. @@ -112,7 +113,10 @@ pub async fn init_conn( action: Action::UpdatePing { rtt: 0 }, }], }) - .await?; + .await + .with_context(|| { + format!("failed to update ping for runner: {}", runner.runner_id) + })?; if update_ping_res .notifications @@ -145,7 +149,13 @@ pub async fn init_conn( .tag("runner_id", runner_id) .unique() .dispatch() - .await?; + .await + .with_context(|| { + format!( + "failed to dispatch runner workflow for runner: {}", + runner_id + ) + })?; (runner_id, workflow_id) } else { @@ -157,25 +167,24 @@ pub async fn init_conn( ctx.signal(pegboard::workflows::runner::Forward { inner: packet }) .to_workflow_id(workflow_id) .send() - .await?; + .await + .with_context(|| { + format!( + "failed to forward initial packet to workflow: {}", + workflow_id + ) + })?; (runner_id, workflow_id) } else { return Err(WsError::ConnectionClosed.build()); }; - let tx = ws_tx.take().context("should exist")?; - Ok(Arc::new(Conn { runner_id, workflow_id, protocol_version, - ws_tx: Arc::new(Mutex::new(Box::new(tx) - as Box< - dyn futures_util::Sink - + Send - + Unpin, - >)), + ws_handle, last_rtt: AtomicU32::new(0), tunnel_active_requests: Mutex::new(HashMap::new()), })) diff --git a/packages/core/pegboard-runner/src/lib.rs b/packages/core/pegboard-runner/src/lib.rs index e5e256a7d9..43703028b5 100644 --- a/packages/core/pegboard-runner/src/lib.rs +++ b/packages/core/pegboard-runner/src/lib.rs @@ -1,13 +1,15 @@ +use anyhow::Context; use async_trait::async_trait; use bytes::Bytes; use futures_util::{SinkExt, StreamExt}; use gas::prelude::*; use http_body_util::Full; use hyper::{Response, StatusCode}; -use hyper_tungstenite::{HyperWebsocket, tungstenite::Message}; +use hyper_tungstenite::tungstenite::Message; use pegboard::ops::runner::update_alloc_idx::Action; use rivet_guard_core::{ - custom_serve::CustomServeTrait, proxy_service::ResponseBody, request_context::RequestContext, + WebSocketHandle, custom_serve::CustomServeTrait, proxy_service::ResponseBody, + request_context::RequestContext, }; use std::time::Duration; @@ -53,118 +55,80 @@ impl CustomServeTrait for PegboardRunnerWsCustomServe { async fn handle_websocket( &self, - client_ws: HyperWebsocket, + ws_handle: WebSocketHandle, _headers: &hyper::HeaderMap, path: &str, _request_context: &mut RequestContext, - ) -> Result<(), (HyperWebsocket, anyhow::Error)> { - // TODO: Spawn ping thread - // TODO: Spawn message thread - // TODO: Create conn - + ) -> Result<()> { // Get UPS - let ups = match self.ctx.ups() { - Ok(x) => x, - Err(err) => { - tracing::warn!(?err, "could not get ups"); - return Err((client_ws, err)); - } - }; + let ups = self.ctx.ups().context("failed to get UPS instance")?; // Parse URL to extract parameters - let url = match url::Url::parse(&format!("ws://placeholder/{path}")) { - Result::Ok(u) => u, - Result::Err(e) => return Err((client_ws, e.into())), - }; - - let url_data = match utils::UrlData::parse_url(url) { - Result::Ok(x) => x, - Result::Err(err) => { - tracing::warn!(?err, "could not parse runner connection url"); - return Err((client_ws, err)); - } - }; + let url = url::Url::parse(&format!("ws://placeholder/{path}")) + .context("failed to parse WebSocket URL")?; + let url_data = + utils::UrlData::parse_url(url).context("failed to extract URL parameters")?; tracing::info!(?path, "tunnel ws connection established"); // Accept WS - let ws_stream = match client_ws.await { - Result::Ok(ws) => ws, - Err(e) => { - // Handshake already in progress; cannot retry safely here - tracing::error!(error=?e, "client websocket await failed"); - return Result::<(), (HyperWebsocket, anyhow::Error)>::Ok(()); - } - }; - let (ws_tx, mut ws_rx) = ws_stream.split(); + let mut ws_rx = ws_handle + .accept() + .await + .context("failed to accept WebSocket connection")?; // Create connection - let mut ws_tx = Some(ws_tx); - let conn = match conn::init_conn(&self.ctx, &mut ws_tx, &mut ws_rx, url_data).await { - Ok(x) => x, - - Err(err) => { - tracing::warn!(?err, "failed to build connection"); - - if let Some(mut tx) = ws_tx { - let close_frame = utils::err_to_close_frame(err); - - if let Err(err) = tx.send(Message::Close(Some(close_frame))).await { - tracing::error!(?err, "failed closing socket"); - } - } - - return Ok(()); - } - }; + let conn = conn::init_conn(&self.ctx, ws_handle.clone(), &mut ws_rx, url_data) + .await + .context("failed to initialize runner connection")?; // Subscribe to pubsub topic for this runner before accepting the client websocket so // that failures can be retried by the proxy. let topic = pegboard::pubsub_subjects::RunnerReceiverSubject::new(conn.runner_id).to_string(); tracing::info!(%topic, "subscribing to runner receiver topic"); - let mut sub = match ups.subscribe(&topic).await { - Result::Ok(s) => s, - Err(err) => { - // TODO: Handle this error correctly - tracing::error!(?err, "failed to subscribe to runner receiver"); - return Ok(()); - } - }; + let sub = ups + .subscribe(&topic) + .await + .with_context(|| format!("failed to subscribe to runner receiver topic: {}", topic))?; // Forward pubsub -> WebSocket - let pubsub_to_client = tokio::spawn(pubsub_to_client_task::task( + let mut pubsub_to_client = tokio::spawn(pubsub_to_client_task::task( self.ctx.clone(), conn.clone(), sub, )); // Forward WebSocket -> pubsub - let client_to_pubsub = tokio::spawn(client_to_pubsub_task::task( + let mut client_to_pubsub = tokio::spawn(client_to_pubsub_task::task( self.ctx.clone(), conn.clone(), ws_rx, )); // Update pings - let ping = tokio::spawn(ping_task::task(self.ctx.clone(), conn.clone())); + let mut ping = tokio::spawn(ping_task::task(self.ctx.clone(), conn.clone())); // Wait for either task to complete tokio::select! { - _ = pubsub_to_client => { + _ = &mut pubsub_to_client => { tracing::info!("pubsub to WebSocket task completed"); } - _ = client_to_pubsub => { + _ = &mut client_to_pubsub => { tracing::info!("WebSocket to pubsub task completed"); } - _ = ping => { + _ = &mut ping => { tracing::info!("ping task completed"); } } + // Abort remaining tasks + pubsub_to_client.abort(); + client_to_pubsub.abort(); + ping.abort(); + // Make runner immediately ineligible when it disconnects - if let Err(err) = self - .ctx + self.ctx .op(pegboard::ops::runner::update_alloc_idx::Input { runners: vec![pegboard::ops::runner::update_alloc_idx::Runner { runner_id: conn.runner_id, @@ -172,21 +136,20 @@ impl CustomServeTrait for PegboardRunnerWsCustomServe { }], }) .await - { - tracing::error!(?conn.runner_id, ?err, "failed evicting runner from alloc idx"); - } - - // TODO: Handle errors - // // Close WS - // let close_frame = utils::err_to_close_frame(err); - // let mut tx = conn.ws_tx.lock().await; - // if let Err(err) = tx.send(Message::Close(Some(close_frame))).await { - // tracing::error!(?runner_id, ?err, "failed closing socket"); - // } + .map_err(|err| { + // Log the error with full context but continue cleanup + tracing::error!( + ?conn.runner_id, + ?err, + "critical: failed to evict runner from allocation index during disconnect" + ); + err + }) + .ok(); // Clean up tracing::info!(?conn.runner_id, "connection closed"); - Result::<(), (HyperWebsocket, anyhow::Error)>::Ok(()) + Ok(()) } } diff --git a/packages/core/pegboard-runner/src/pubsub_to_client_task.rs b/packages/core/pegboard-runner/src/pubsub_to_client_task.rs index 2c198c98fb..98e1dcd460 100644 --- a/packages/core/pegboard-runner/src/pubsub_to_client_task.rs +++ b/packages/core/pegboard-runner/src/pubsub_to_client_task.rs @@ -53,7 +53,7 @@ async fn task_inner(ctx: StandaloneCtx, conn: Arc, mut sub: Subscriber) -> } }; let ws_msg = WsMessage::Binary(serialized_msg.into()); - if let Err(e) = conn.ws_tx.lock().await.send(ws_msg).await { + if let Err(e) = conn.ws_handle.send(ws_msg).await { tracing::error!(?e, "failed to send message to WebSocket"); break; } diff --git a/packages/core/pegboard-runner/src/utils.rs b/packages/core/pegboard-runner/src/utils.rs index 4e7243d8f2..aab6f7623c 100644 --- a/packages/core/pegboard-runner/src/utils.rs +++ b/packages/core/pegboard-runner/src/utils.rs @@ -9,11 +9,6 @@ use tokio_tungstenite::{ tungstenite::protocol::frame::{CloseFrame, coding::CloseCode}, }; -pub type WebSocketReceiver = futures_util::stream::SplitStream>>; - -pub type WebSocketSender = - futures_util::stream::SplitSink>, WsMessage>; - #[derive(Clone)] pub struct UrlData { pub protocol_version: u16, @@ -53,29 +48,6 @@ impl UrlData { } } -pub fn err_to_close_frame(err: anyhow::Error) -> CloseFrame { - let rivet_err = err - .chain() - .find_map(|x| x.downcast_ref::()) - .cloned() - .unwrap_or_else(|| RivetError::from(&INTERNAL_ERROR)); - - let code = match (rivet_err.group(), rivet_err.code()) { - ("ws", "connection_closed") => CloseCode::Normal, - _ => CloseCode::Error, - }; - - // NOTE: reason cannot be more than 123 bytes as per the WS protocol - let reason = util::safe_slice( - &format!("{}.{}", rivet_err.group(), rivet_err.code()), - 0, - 123, - ) - .into(); - - CloseFrame { code, reason } -} - /// Determines if a given message kind will terminate the request. pub fn is_to_server_tunnel_message_kind_request_close( kind: &protocol::ToServerTunnelMessageKind,