diff --git a/packages/core/guard/server/src/routing/pegboard_gateway.rs b/packages/core/guard/server/src/routing/pegboard_gateway.rs index 02bf19ca17..2a5f4cf3d7 100644 --- a/packages/core/guard/server/src/routing/pegboard_gateway.rs +++ b/packages/core/guard/server/src/routing/pegboard_gateway.rs @@ -6,11 +6,11 @@ use hyper::header::HeaderName; use rivet_guard_core::proxy_service::{RouteConfig, RouteTarget, RoutingOutput, RoutingTimeout}; use universaldb::utils::IsolationLevel::*; +use super::SEC_WEBSOCKET_PROTOCOL; use crate::{errors, shared_state::SharedState}; const ACTOR_READY_TIMEOUT: Duration = Duration::from_secs(10); pub const X_RIVET_ACTOR: HeaderName = HeaderName::from_static("x-rivet-actor"); -const SEC_WEBSOCKET_PROTOCOL: HeaderName = HeaderName::from_static("sec-websocket-protocol"); const WS_PROTOCOL_ACTOR: &str = "rivet_actor."; /// Route requests to actor services based on hostname and path diff --git a/packages/core/guard/server/src/routing/runner.rs b/packages/core/guard/server/src/routing/runner.rs index be1a5a628d..e04af72e4f 100644 --- a/packages/core/guard/server/src/routing/runner.rs +++ b/packages/core/guard/server/src/routing/runner.rs @@ -3,7 +3,8 @@ use gas::prelude::*; use rivet_guard_core::proxy_service::RoutingOutput; use std::sync::Arc; -use super::X_RIVET_TOKEN; +use super::{SEC_WEBSOCKET_PROTOCOL, X_RIVET_TOKEN}; +pub(crate) const WS_PROTOCOL_TOKEN: &str = "rivet_target."; /// Route requests to the API service #[tracing::instrument(skip_all)] @@ -18,11 +19,33 @@ pub async fn route_request( return Ok(None); } + let is_websocket = headers + .get("upgrade") + .and_then(|v| v.to_str().ok()) + .map(|v| v.eq_ignore_ascii_case("websocket")) + .unwrap_or(false); + // Check auth (if enabled) if let Some(auth) = &ctx.config().auth { let token = headers .get(X_RIVET_TOKEN) .and_then(|x| x.to_str().ok()) + // Fallback to checking websocket protocol if rivet token is not set + .or_else(|| { + if is_websocket { + headers + .get(SEC_WEBSOCKET_PROTOCOL) + .and_then(|protocols| protocols.to_str().ok()) + .and_then(|protocols| { + protocols + .split(',') + .map(|p| p.trim()) + .find_map(|p| p.strip_prefix(WS_PROTOCOL_TOKEN)) + }) + } else { + None + } + }) .ok_or_else(|| { crate::errors::MissingHeader { header: X_RIVET_TOKEN.to_string(), diff --git a/sdks/typescript/runner/src/mod.ts b/sdks/typescript/runner/src/mod.ts index 5081588f13..f19c438469 100644 --- a/sdks/typescript/runner/src/mod.ts +++ b/sdks/typescript/runner/src/mod.ts @@ -28,6 +28,7 @@ export interface RunnerConfig { logger?: Logger; version: number; endpoint: string; + token?: string; pegboardEndpoint?: string; pegboardRelayEndpoint?: string; namespace: string; @@ -409,8 +410,10 @@ export class Runner { // MARK: Runner protocol async #openPegboardWebSocket() { - const WS = await importWebSocket(); const protocols = ["rivet", `rivet_target.runner`]; + if (this.config.token) protocols.push(`rivet_token.${this.config.token}`); + + const WS = await importWebSocket(); const ws = new WS(this.pegboardUrl, protocols) as any as WebSocket; this.#pegboardWebSocket = ws;