Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions packages/core/guard/core/src/proxy_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1169,7 +1169,6 @@ impl ProxyService {
}

// Handle WebSocket upgrade properly with hyper_tungstenite
// First, upgrade the client connection
tracing::debug!("Upgrading client connection to WebSocket");
let (client_response, client_websocket) = match hyper_tungstenite::upgrade(req, None) {
Result::Ok(x) => {
Expand Down Expand Up @@ -1928,7 +1927,15 @@ impl ProxyService {
// structure but convert it to our expected return type without modifying its content
tracing::debug!("Returning WebSocket upgrade response to client");
// Extract the parts from the response but preserve all headers and status
let (parts, _) = client_response.into_parts();
let (mut parts, _) = client_response.into_parts();

// Add Sec-WebSocket-Protocol header to the response
// Many WebSocket clients (e.g. node-ws & Cloudflare) require a protocol in the response
parts.headers.insert(
"sec-websocket-protocol",
hyper::header::HeaderValue::from_static("rivet"),
);

// Create a new response with an empty body - WebSocket upgrades don't need a body
Ok(Response::from_parts(
parts,
Expand Down
32 changes: 30 additions & 2 deletions packages/core/guard/server/src/routing/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ pub mod pegboard_gateway;
mod runner;

pub(crate) const X_RIVET_TARGET: HeaderName = HeaderName::from_static("x-rivet-target");
pub(crate) const SEC_WEBSOCKET_PROTOCOL: HeaderName =
HeaderName::from_static("sec-websocket-protocol");
pub(crate) const WS_PROTOCOL_TARGET: &str = "rivet_target.";

/// Creates the main routing function that handles all incoming requests
#[tracing::instrument(skip_all)]
Expand All @@ -31,9 +34,33 @@ pub fn create_routing_function(ctx: StandaloneCtx, shared_state: SharedState) ->

tracing::debug!("Routing request for hostname: {host}, path: {path}");

// Check if this is a WebSocket upgrade request
let is_websocket = headers
.get("upgrade")
.and_then(|v| v.to_str().ok())
.map(|v| v.eq_ignore_ascii_case("websocket"))
.unwrap_or(false);

// Extract target from WebSocket protocol or HTTP header
let target = if is_websocket {
// For WebSocket, parse the sec-websocket-protocol header
headers
.get(SEC_WEBSOCKET_PROTOCOL)
.and_then(|protocols| protocols.to_str().ok())
.and_then(|protocols| {
// Parse protocols to find target.{value}
protocols
.split(',')
.map(|p| p.trim())
.find_map(|p| p.strip_prefix(WS_PROTOCOL_TARGET))
})
} else {
// For HTTP, use the x-rivet-target header
headers.get(X_RIVET_TARGET).and_then(|x| x.to_str().ok())
};

// Read target
if let Some(target) = headers.get(X_RIVET_TARGET).and_then(|x| x.to_str().ok())
{
if let Some(target) = target {
if let Some(routing_output) =
runner::route_request(&ctx, target, host, path).await?
{
Expand All @@ -47,6 +74,7 @@ pub fn create_routing_function(ctx: StandaloneCtx, shared_state: SharedState) ->
host,
path,
headers,
is_websocket,
)
.await?
{
Expand Down
43 changes: 36 additions & 7 deletions packages/core/guard/server/src/routing/pegboard_gateway.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ use crate::{errors, shared_state::SharedState};

const ACTOR_READY_TIMEOUT: Duration = Duration::from_secs(10);
pub const X_RIVET_ACTOR: HeaderName = HeaderName::from_static("x-rivet-actor");
const SEC_WEBSOCKET_PROTOCOL: HeaderName = HeaderName::from_static("sec-websocket-protocol");
const WS_PROTOCOL_ACTOR: &str = "rivet_actor.";

/// Route requests to actor services based on hostname and path
#[tracing::instrument(skip_all)]
Expand All @@ -20,20 +22,47 @@ pub async fn route_request(
_host: &str,
path: &str,
headers: &hyper::HeaderMap,
is_websocket: bool,
) -> Result<Option<RoutingOutput>> {
// Check target
if target != "actor" {
return Ok(None);
}

// Extract actor ID from WebSocket protocol or HTTP header
let actor_id_str = if is_websocket {
// For WebSocket, parse the sec-websocket-protocol header
headers
.get(SEC_WEBSOCKET_PROTOCOL)
.and_then(|protocols| protocols.to_str().ok())
.and_then(|protocols| {
// Parse protocols to find actor.{id}
protocols
.split(',')
.map(|p| p.trim())
.find_map(|p| p.strip_prefix(WS_PROTOCOL_ACTOR))
})
.ok_or_else(|| {
crate::errors::MissingHeader {
header: "actor protocol in sec-websocket-protocol".to_string(),
}
.build()
})?
} else {
// For HTTP, use the x-rivet-actor header
headers
.get(X_RIVET_ACTOR)
.and_then(|x| x.to_str().ok())
.ok_or_else(|| {
crate::errors::MissingHeader {
header: X_RIVET_ACTOR.to_string(),
}
.build()
})?
};

// Find actor to route to
let actor_id_str = headers.get(X_RIVET_ACTOR).ok_or_else(|| {
crate::errors::MissingHeader {
header: X_RIVET_ACTOR.to_string(),
}
.build()
})?;
let actor_id = Id::parse(actor_id_str.to_str()?)?;
let actor_id = Id::parse(actor_id_str)?;

// Route to peer dc where the actor lives
if actor_id.label() != ctx.config().dc_label() {
Expand Down
21 changes: 15 additions & 6 deletions packages/core/pegboard-gateway/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use bytes::Bytes;
use futures_util::StreamExt;
use gas::prelude::*;
use http_body_util::{BodyExt, Full};
use hyper::{Request, Response, StatusCode};
use hyper::{Request, Response, StatusCode, header::HeaderName};
use rivet_guard_core::{
WebSocketHandle,
custom_serve::CustomServeTrait,
Expand All @@ -22,6 +22,8 @@ use crate::shared_state::{SharedState, TunnelMessageData};
pub mod shared_state;

const UPS_REQ_TIMEOUT: Duration = Duration::from_secs(2);
const SEC_WEBSOCKET_PROTOCOL: HeaderName = HeaderName::from_static("sec-websocket-protocol");
const WS_PROTOCOL_ACTOR: &str = "rivet_actor.";

pub struct PegboardGateway {
ctx: StandaloneCtx,
Expand Down Expand Up @@ -94,7 +96,7 @@ impl PegboardGateway {
req: Request<Full<Bytes>>,
_request_context: &mut RequestContext,
) -> Result<Response<ResponseBody>> {
// Extract actor ID for the message
// Extract actor ID for the message (HTTP requests use x-rivet-actor header)
let actor_id = req
.headers()
.get("x-rivet-actor")
Expand Down Expand Up @@ -200,11 +202,18 @@ impl PegboardGateway {
path: &str,
_request_context: &mut RequestContext,
) -> Result<()> {
// Extract actor ID for the message
// Extract actor ID from WebSocket protocol
let actor_id = headers
.get("x-rivet-actor")
.context("missing x-rivet-actor")?
.to_str()?
.get(SEC_WEBSOCKET_PROTOCOL)
.and_then(|protocols| protocols.to_str().ok())
.and_then(|protocols| {
// Parse protocols to find actor.{id}
protocols
.split(',')
.map(|p| p.trim())
.find_map(|p| p.strip_prefix(WS_PROTOCOL_ACTOR))
})
.context("missing actor protocol in sec-websocket-protocol")?
.to_string();

// Extract headers
Expand Down
16 changes: 8 additions & 8 deletions packages/infra/engine/tests/common/actors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -427,20 +427,20 @@ pub async fn ping_actor_websocket_via_guard(guard_port: u16, actor_id: &str) ->
"testing websocket connection to actor via guard"
);

// Build WebSocket URL and request
// Build WebSocket URL and request with protocols for routing
let ws_url = format!("ws://127.0.0.1:{}/ws", guard_port);
let mut request = ws_url
.clone()
.into_client_request()
.expect("Failed to create WebSocket request");

// Add headers for routing through guard to actor
request
.headers_mut()
.insert("X-Rivet-Target", "actor".parse().unwrap());
request
.headers_mut()
.insert("X-Rivet-Actor", actor_id.parse().unwrap());
// Add protocols for routing through guard to actor
request.headers_mut().insert(
"Sec-WebSocket-Protocol",
format!("rivet, rivet_target.actor, rivet_actor.{}", actor_id)
.parse()
.unwrap(),
);

// Connect to WebSocket
let (ws_stream, response) = connect_async(request)
Expand Down
8 changes: 2 additions & 6 deletions scripts/tests/actor_e2e.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,8 @@ function testWebSocket(actorId: string): Promise<void> {

console.log(`Connecting WebSocket to: ${wsUrl}`);

const ws = new WebSocket(wsUrl, {
headers: {
"X-Rivet-Target": "actor",
"X-Rivet-Actor": actorId,
},
});
const protocols = ["rivet", "rivet_target.actor", `rivet_actor.${actorId}`];
const ws = new WebSocket(wsUrl, protocols);

let pingReceived = false;
let echoReceived = false;
Expand Down
9 changes: 3 additions & 6 deletions sdks/typescript/runner/src/mod.ts

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading