diff --git a/packages/core/guard/core/src/proxy_service.rs b/packages/core/guard/core/src/proxy_service.rs index 0cd6948709..41cf6a4f29 100644 --- a/packages/core/guard/core/src/proxy_service.rs +++ b/packages/core/guard/core/src/proxy_service.rs @@ -1169,7 +1169,6 @@ impl ProxyService { } // Handle WebSocket upgrade properly with hyper_tungstenite - // First, upgrade the client connection tracing::debug!("Upgrading client connection to WebSocket"); let (client_response, client_websocket) = match hyper_tungstenite::upgrade(req, None) { Result::Ok(x) => { @@ -1928,7 +1927,15 @@ impl ProxyService { // structure but convert it to our expected return type without modifying its content tracing::debug!("Returning WebSocket upgrade response to client"); // Extract the parts from the response but preserve all headers and status - let (parts, _) = client_response.into_parts(); + let (mut parts, _) = client_response.into_parts(); + + // Add Sec-WebSocket-Protocol header to the response + // Many WebSocket clients (e.g. node-ws & Cloudflare) require a protocol in the response + parts.headers.insert( + "sec-websocket-protocol", + hyper::header::HeaderValue::from_static("rivet"), + ); + // Create a new response with an empty body - WebSocket upgrades don't need a body Ok(Response::from_parts( parts, diff --git a/packages/core/guard/server/src/routing/mod.rs b/packages/core/guard/server/src/routing/mod.rs index cc43c75dbe..02675ae8c9 100644 --- a/packages/core/guard/server/src/routing/mod.rs +++ b/packages/core/guard/server/src/routing/mod.rs @@ -12,6 +12,9 @@ pub mod pegboard_gateway; mod runner; pub(crate) const X_RIVET_TARGET: HeaderName = HeaderName::from_static("x-rivet-target"); +pub(crate) const SEC_WEBSOCKET_PROTOCOL: HeaderName = + HeaderName::from_static("sec-websocket-protocol"); +pub(crate) const WS_PROTOCOL_TARGET: &str = "rivet_target."; /// Creates the main routing function that handles all incoming requests #[tracing::instrument(skip_all)] @@ -31,9 +34,33 @@ pub fn create_routing_function(ctx: StandaloneCtx, shared_state: SharedState) -> tracing::debug!("Routing request for hostname: {host}, path: {path}"); + // Check if this is a WebSocket upgrade request + let is_websocket = headers + .get("upgrade") + .and_then(|v| v.to_str().ok()) + .map(|v| v.eq_ignore_ascii_case("websocket")) + .unwrap_or(false); + + // Extract target from WebSocket protocol or HTTP header + let target = if is_websocket { + // For WebSocket, parse the sec-websocket-protocol header + headers + .get(SEC_WEBSOCKET_PROTOCOL) + .and_then(|protocols| protocols.to_str().ok()) + .and_then(|protocols| { + // Parse protocols to find target.{value} + protocols + .split(',') + .map(|p| p.trim()) + .find_map(|p| p.strip_prefix(WS_PROTOCOL_TARGET)) + }) + } else { + // For HTTP, use the x-rivet-target header + headers.get(X_RIVET_TARGET).and_then(|x| x.to_str().ok()) + }; + // Read target - if let Some(target) = headers.get(X_RIVET_TARGET).and_then(|x| x.to_str().ok()) - { + if let Some(target) = target { if let Some(routing_output) = runner::route_request(&ctx, target, host, path).await? { @@ -47,6 +74,7 @@ pub fn create_routing_function(ctx: StandaloneCtx, shared_state: SharedState) -> host, path, headers, + is_websocket, ) .await? { diff --git a/packages/core/guard/server/src/routing/pegboard_gateway.rs b/packages/core/guard/server/src/routing/pegboard_gateway.rs index 519debc481..02bf19ca17 100644 --- a/packages/core/guard/server/src/routing/pegboard_gateway.rs +++ b/packages/core/guard/server/src/routing/pegboard_gateway.rs @@ -10,6 +10,8 @@ use crate::{errors, shared_state::SharedState}; const ACTOR_READY_TIMEOUT: Duration = Duration::from_secs(10); pub const X_RIVET_ACTOR: HeaderName = HeaderName::from_static("x-rivet-actor"); +const SEC_WEBSOCKET_PROTOCOL: HeaderName = HeaderName::from_static("sec-websocket-protocol"); +const WS_PROTOCOL_ACTOR: &str = "rivet_actor."; /// Route requests to actor services based on hostname and path #[tracing::instrument(skip_all)] @@ -20,20 +22,47 @@ pub async fn route_request( _host: &str, path: &str, headers: &hyper::HeaderMap, + is_websocket: bool, ) -> Result> { // Check target if target != "actor" { return Ok(None); } + // Extract actor ID from WebSocket protocol or HTTP header + let actor_id_str = if is_websocket { + // For WebSocket, parse the sec-websocket-protocol header + headers + .get(SEC_WEBSOCKET_PROTOCOL) + .and_then(|protocols| protocols.to_str().ok()) + .and_then(|protocols| { + // Parse protocols to find actor.{id} + protocols + .split(',') + .map(|p| p.trim()) + .find_map(|p| p.strip_prefix(WS_PROTOCOL_ACTOR)) + }) + .ok_or_else(|| { + crate::errors::MissingHeader { + header: "actor protocol in sec-websocket-protocol".to_string(), + } + .build() + })? + } else { + // For HTTP, use the x-rivet-actor header + headers + .get(X_RIVET_ACTOR) + .and_then(|x| x.to_str().ok()) + .ok_or_else(|| { + crate::errors::MissingHeader { + header: X_RIVET_ACTOR.to_string(), + } + .build() + })? + }; + // Find actor to route to - let actor_id_str = headers.get(X_RIVET_ACTOR).ok_or_else(|| { - crate::errors::MissingHeader { - header: X_RIVET_ACTOR.to_string(), - } - .build() - })?; - let actor_id = Id::parse(actor_id_str.to_str()?)?; + let actor_id = Id::parse(actor_id_str)?; // Route to peer dc where the actor lives if actor_id.label() != ctx.config().dc_label() { diff --git a/packages/core/pegboard-gateway/src/lib.rs b/packages/core/pegboard-gateway/src/lib.rs index eef00f0a6c..274674a2fc 100644 --- a/packages/core/pegboard-gateway/src/lib.rs +++ b/packages/core/pegboard-gateway/src/lib.rs @@ -4,7 +4,7 @@ use bytes::Bytes; use futures_util::StreamExt; use gas::prelude::*; use http_body_util::{BodyExt, Full}; -use hyper::{Request, Response, StatusCode}; +use hyper::{Request, Response, StatusCode, header::HeaderName}; use rivet_guard_core::{ WebSocketHandle, custom_serve::CustomServeTrait, @@ -22,6 +22,8 @@ use crate::shared_state::{SharedState, TunnelMessageData}; pub mod shared_state; const UPS_REQ_TIMEOUT: Duration = Duration::from_secs(2); +const SEC_WEBSOCKET_PROTOCOL: HeaderName = HeaderName::from_static("sec-websocket-protocol"); +const WS_PROTOCOL_ACTOR: &str = "rivet_actor."; pub struct PegboardGateway { ctx: StandaloneCtx, @@ -94,7 +96,7 @@ impl PegboardGateway { req: Request>, _request_context: &mut RequestContext, ) -> Result> { - // Extract actor ID for the message + // Extract actor ID for the message (HTTP requests use x-rivet-actor header) let actor_id = req .headers() .get("x-rivet-actor") @@ -200,11 +202,18 @@ impl PegboardGateway { path: &str, _request_context: &mut RequestContext, ) -> Result<()> { - // Extract actor ID for the message + // Extract actor ID from WebSocket protocol let actor_id = headers - .get("x-rivet-actor") - .context("missing x-rivet-actor")? - .to_str()? + .get(SEC_WEBSOCKET_PROTOCOL) + .and_then(|protocols| protocols.to_str().ok()) + .and_then(|protocols| { + // Parse protocols to find actor.{id} + protocols + .split(',') + .map(|p| p.trim()) + .find_map(|p| p.strip_prefix(WS_PROTOCOL_ACTOR)) + }) + .context("missing actor protocol in sec-websocket-protocol")? .to_string(); // Extract headers diff --git a/packages/infra/engine/tests/common/actors.rs b/packages/infra/engine/tests/common/actors.rs index d3495c6dc5..9c27f74a87 100644 --- a/packages/infra/engine/tests/common/actors.rs +++ b/packages/infra/engine/tests/common/actors.rs @@ -427,20 +427,20 @@ pub async fn ping_actor_websocket_via_guard(guard_port: u16, actor_id: &str) -> "testing websocket connection to actor via guard" ); - // Build WebSocket URL and request + // Build WebSocket URL and request with protocols for routing let ws_url = format!("ws://127.0.0.1:{}/ws", guard_port); let mut request = ws_url .clone() .into_client_request() .expect("Failed to create WebSocket request"); - // Add headers for routing through guard to actor - request - .headers_mut() - .insert("X-Rivet-Target", "actor".parse().unwrap()); - request - .headers_mut() - .insert("X-Rivet-Actor", actor_id.parse().unwrap()); + // Add protocols for routing through guard to actor + request.headers_mut().insert( + "Sec-WebSocket-Protocol", + format!("rivet, rivet_target.actor, rivet_actor.{}", actor_id) + .parse() + .unwrap(), + ); // Connect to WebSocket let (ws_stream, response) = connect_async(request) diff --git a/scripts/tests/actor_e2e.ts b/scripts/tests/actor_e2e.ts index 0290460e2f..d4ed1c6cf7 100755 --- a/scripts/tests/actor_e2e.ts +++ b/scripts/tests/actor_e2e.ts @@ -60,12 +60,8 @@ function testWebSocket(actorId: string): Promise { console.log(`Connecting WebSocket to: ${wsUrl}`); - const ws = new WebSocket(wsUrl, { - headers: { - "X-Rivet-Target": "actor", - "X-Rivet-Actor": actorId, - }, - }); + const protocols = ["rivet", "rivet_target.actor", `rivet_actor.${actorId}`]; + const ws = new WebSocket(wsUrl, protocols); let pingReceived = false; let echoReceived = false; diff --git a/sdks/typescript/runner/src/mod.ts b/sdks/typescript/runner/src/mod.ts index 16c3700098..5081588f13 100644 --- a/sdks/typescript/runner/src/mod.ts +++ b/sdks/typescript/runner/src/mod.ts @@ -410,11 +410,8 @@ export class Runner { // MARK: Runner protocol async #openPegboardWebSocket() { const WS = await importWebSocket(); - const ws = new WS(this.pegboardUrl, { - headers: { - "x-rivet-target": "runner", - }, - }) as any as WebSocket; + const protocols = ["rivet", `rivet_target.runner`]; + const ws = new WS(this.pegboardUrl, protocols) as any as WebSocket; this.#pegboardWebSocket = ws; ws.addEventListener("open", () => { @@ -545,7 +542,7 @@ export class Runner { }); ws.addEventListener("error", (ev) => { - logger()?.error("WebSocket error:", ev.error); + logger()?.error(`WebSocket error: ${ev.error}`); }); ws.addEventListener("close", (ev) => {