Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 4 additions & 11 deletions packages/core/guard/core/src/custom_serve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ use anyhow::*;
use async_trait::async_trait;
use bytes::Bytes;
use http_body_util::Full;
use hyper::body::Incoming as BodyIncoming;
use hyper::{Request, Response};
use hyper_tungstenite::HyperWebsocket;

use crate::WebSocketHandle;
use crate::proxy_service::ResponseBody;
use crate::request_context::RequestContext;

Expand All @@ -19,19 +19,12 @@ pub trait CustomServeTrait: Send + Sync {
request_context: &mut RequestContext,
) -> Result<Response<ResponseBody>>;

/// Handle a WebSocket connection after upgrade.
///
/// Contract for retries:
/// - Return `Ok(())` after you have accepted (`await`ed) the client websocket and
/// completed the streaming lifecycle. No further retries are possible.
/// - Return `Err((client_ws, err))` if you have NOT accepted the websocket yet and
/// want the proxy to optionally re-resolve and retry with a different handler.
/// You must not `await` the websocket before returning this error.
/// Handle a WebSocket connection after upgrade. Supports connection retries.
async fn handle_websocket(
&self,
client_ws: HyperWebsocket,
websocket: WebSocketHandle,
headers: &hyper::HeaderMap,
path: &str,
request_context: &mut RequestContext,
) -> std::result::Result<(), (HyperWebsocket, anyhow::Error)>;
) -> Result<()>;
}
2 changes: 2 additions & 0 deletions packages/core/guard/core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@ pub mod proxy_service;
pub mod request_context;
mod server;
pub mod types;
pub mod websocket_handle;

pub use cert_resolver::CertResolverFn;
pub use custom_serve::CustomServeTrait;
pub use proxy_service::{
CacheKeyFn, MiddlewareFn, ProxyService, ProxyState, RouteTarget, RoutingFn, RoutingOutput,
};
pub use websocket_handle::WebSocketHandle;

// Re-export hyper StatusCode for use in other crates
pub mod status {
Expand Down
154 changes: 91 additions & 63 deletions packages/core/guard/core/src/proxy_service.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,3 @@
use std::{
borrow::Cow,
collections::HashMap as StdHashMap,
net::SocketAddr,
sync::Arc,
time::{Duration, Instant},
};

use anyhow::*;
use bytes::Bytes;
use futures_util::{SinkExt, StreamExt};
Expand All @@ -16,17 +8,30 @@ use hyper_util::{client::legacy::Client, rt::TokioExecutor};
use moka::future::Cache;
use rand;
use rivet_api_builder::{ErrorResponse, RawErrorResponse};
use rivet_error::RivetError;
use rivet_error::{INTERNAL_ERROR, RivetError};
use rivet_metrics::KeyValue;
use rivet_util::Id;
use serde_json;
use std::{
borrow::Cow,
collections::HashMap as StdHashMap,
net::SocketAddr,
sync::Arc,
time::{Duration, Instant},
};
use tokio::sync::Mutex;
use tokio::time::timeout;
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
use tokio_tungstenite::tungstenite::{
client::IntoClientRequest,
protocol::{CloseFrame, frame::coding::CloseCode},
};
use tracing::Instrument;
use url::Url;

use crate::{custom_serve::CustomServeTrait, errors, metrics, request_context::RequestContext};
use crate::{
WebSocketHandle, custom_serve::CustomServeTrait, errors, metrics,
request_context::RequestContext,
};

pub const X_FORWARDED_FOR: HeaderName = HeaderName::from_static("x-forwarded-for");
pub const X_RIVET_ERROR: HeaderName = HeaderName::from_static("x-rivet-error");
Expand Down Expand Up @@ -1432,9 +1437,9 @@ impl ProxyService {

// Close the WebSocket connection with the response message
let _ = client_ws.close(Some(tokio_tungstenite::tungstenite::protocol::CloseFrame {
code: tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode::Error,
reason: response.message.as_ref().into(),
})).await;
code: tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode::Error,
reason: response.message.as_ref().into(),
})).await;
return;
}
Result::Ok(ResolveRouteOutput::CustomServe(_)) => {
Expand Down Expand Up @@ -1813,31 +1818,42 @@ impl ProxyService {
let mut attempts = 0u32;
let mut client_ws = client_websocket;

let ws_handle = WebSocketHandle::new(client_ws);

loop {
match handlers
.handle_websocket(
client_ws,
ws_handle.clone(),
&req_headers,
&req_path,
&mut request_context,
)
.await
{
Result::Ok(()) => break,
Result::Err((returned_client_ws, err)) => {
Result::Ok(()) => {
tracing::debug!("websocket closed");

// Send graceful close
ws_handle.send(hyper_tungstenite::tungstenite::Message::Close(Some(
hyper_tungstenite::tungstenite::protocol::CloseFrame {
code: hyper_tungstenite::tungstenite::protocol::frame::coding::CloseCode::Normal,
reason: format!("Closed").into(),
},
)));

break;
}
Result::Err(err) => {
attempts += 1;
if attempts > max_attempts || !is_retryable_ws_error(&err) {
// Accept and close the client websocket with an error reason
if let Result::Ok(mut ws) = returned_client_ws.await {
let _ = ws
.send(hyper_tungstenite::tungstenite::Message::Close(Some(
hyper_tungstenite::tungstenite::protocol::CloseFrame {
code: hyper_tungstenite::tungstenite::protocol::frame::coding::CloseCode::Error,
reason: format!("{}", err).into(),
},
)))
.await;
}
// Close WebSocket with error
ws_handle
.accept_and_send(
hyper_tungstenite::tungstenite::Message::Close(
Some(err_to_close_frame(err)),
),
)
.await?;

break;
} else {
Expand All @@ -1861,49 +1877,38 @@ impl ProxyService {
new_handlers,
)) => {
handlers = new_handlers;
client_ws = returned_client_ws;
continue;
}
Result::Ok(ResolveRouteOutput::Response(response)) => {
if let Result::Ok(mut ws) = returned_client_ws.await
{
let _ = ws
.send(hyper_tungstenite::tungstenite::Message::Close(Some(
hyper_tungstenite::tungstenite::protocol::CloseFrame {
code: hyper_tungstenite::tungstenite::protocol::frame::coding::CloseCode::Error,
reason: response.message.as_ref().into(),
},
)))
.await;
}
break;
ws_handle
.accept_and_send(hyper_tungstenite::tungstenite::Message::Close(Some(
hyper_tungstenite::tungstenite::protocol::CloseFrame {
code: hyper_tungstenite::tungstenite::protocol::frame::coding::CloseCode::Error,
reason: response.message.as_ref().into(),
},
)))
.await;
}
Result::Ok(ResolveRouteOutput::Target(_)) => {
if let Result::Ok(mut ws) = returned_client_ws.await
{
let _ = ws
.send(hyper_tungstenite::tungstenite::Message::Close(Some(
hyper_tungstenite::tungstenite::protocol::CloseFrame {
code: hyper_tungstenite::tungstenite::protocol::frame::coding::CloseCode::Error,
reason: "Cannot retry WebSocket with non-custom serve route".into(),
},
)))
.await;
}
ws_handle
.accept_and_send(hyper_tungstenite::tungstenite::Message::Close(Some(
hyper_tungstenite::tungstenite::protocol::CloseFrame {
code: hyper_tungstenite::tungstenite::protocol::frame::coding::CloseCode::Error,
reason: "Cannot retry WebSocket with non-custom serve route".into(),
},
)))
.await;
break;
}
Err(res_err) => {
if let Result::Ok(mut ws) = returned_client_ws.await
{
let _ = ws
.send(hyper_tungstenite::tungstenite::Message::Close(Some(
hyper_tungstenite::tungstenite::protocol::CloseFrame {
code: hyper_tungstenite::tungstenite::protocol::frame::coding::CloseCode::Error,
reason: format!("Routing error: {}", res_err).into(),
},
)))
.await;
}
ws_handle
.accept_and_send(hyper_tungstenite::tungstenite::Message::Close(Some(
hyper_tungstenite::tungstenite::protocol::CloseFrame {
code: hyper_tungstenite::tungstenite::protocol::frame::coding::CloseCode::Error,
reason: format!("Routing error: {}", res_err).into(),
},
)))
.await;
break;
}
}
Expand Down Expand Up @@ -2242,3 +2247,26 @@ fn is_retryable_ws_error(err: &anyhow::Error) -> bool {
false
}
}

pub fn err_to_close_frame(err: anyhow::Error) -> CloseFrame {
let rivet_err = err
.chain()
.find_map(|x| x.downcast_ref::<RivetError>())
.cloned()
.unwrap_or_else(|| RivetError::from(&INTERNAL_ERROR));

let code = match (rivet_err.group(), rivet_err.code()) {
("ws", "connection_closed") => CloseCode::Normal,
_ => CloseCode::Error,
};

// NOTE: reason cannot be more than 123 bytes as per the WS protocol
let reason = rivet_util::safe_slice(
&format!("{}.{}", rivet_err.group(), rivet_err.code()),
0,
123,
)
.into();

CloseFrame { code, reason }
}
104 changes: 104 additions & 0 deletions packages/core/guard/core/src/websocket_handle.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
use anyhow::*;
use futures_util::{SinkExt, StreamExt};
use hyper::upgrade::Upgraded;
use hyper_tungstenite::HyperWebsocket;
use hyper_tungstenite::tungstenite::Message as WsMessage;
use hyper_util::rt::TokioIo;
use std::ops::Deref;
use std::sync::Arc;
use tokio::sync::Mutex;
use tokio_tungstenite::WebSocketStream;

pub type WebSocketReceiver = futures_util::stream::SplitStream<WebSocketStream<TokioIo<Upgraded>>>;

pub type WebSocketSender =
futures_util::stream::SplitSink<WebSocketStream<TokioIo<Upgraded>>, WsMessage>;

enum WebSocketState {
Unaccepted { websocket: HyperWebsocket },
Accepting,
Split { ws_tx: WebSocketSender },
}

#[derive(Clone)]
pub struct WebSocketHandle(Arc<WebSocketHandleInner>);

impl WebSocketHandle {
pub fn new(websocket: HyperWebsocket) -> Self {
Self(Arc::new(WebSocketHandleInner {
state: Mutex::new(WebSocketState::Unaccepted { websocket }),
}))
}
}

impl Deref for WebSocketHandle {
type Target = WebSocketHandleInner;

fn deref(&self) -> &Self::Target {
&*self.0
}
}

pub struct WebSocketHandleInner {
state: Mutex<WebSocketState>,
}

impl WebSocketHandleInner {
pub async fn accept(&self) -> Result<WebSocketReceiver> {
let mut state = self.state.lock().await;
Self::accept_inner(&mut *state).await
}

pub async fn send(&self, message: WsMessage) -> Result<()> {
let mut state = self.state.lock().await;
match &mut *state {
WebSocketState::Unaccepted { .. } | WebSocketState::Accepting => {
bail!("websocket has not been accepted")
}
WebSocketState::Split { ws_tx } => {
ws_tx.send(message).await?;
Ok(())
}
}
}

pub async fn accept_and_send(&self, message: WsMessage) -> Result<()> {
let mut state = self.state.lock().await;
match &mut *state {
WebSocketState::Unaccepted { .. } => {
let _ = Self::accept_inner(&mut *state).await?;
let WebSocketState::Split { ws_tx } = &mut *state else {
bail!("websocket should be accepted");
};
ws_tx.send(message).await?;
Ok(())
}
WebSocketState::Accepting => {
bail!("in accepting state")
}
WebSocketState::Split { ws_tx } => {
ws_tx.send(message).await?;
Ok(())
}
}
}

async fn accept_inner(state: &mut WebSocketState) -> Result<WebSocketReceiver> {
if !matches!(*state, WebSocketState::Unaccepted { .. }) {
bail!("websocket already accepted")
}

// Accept websocket
let old_state = std::mem::replace(&mut *state, WebSocketState::Accepting);
let WebSocketState::Unaccepted { websocket } = old_state else {
bail!("should be in unaccepted state");
};

// Accept WS
let ws_stream = websocket.await?;
let (ws_tx, ws_rx) = ws_stream.split();
*state = WebSocketState::Split { ws_tx };

Ok(ws_rx)
}
}
Loading
Loading