Skip to content

Commit ab4ac1d

Browse files
NathanFlurryjog1t
authored andcommitted
chore(guard): simplify error handling for websockets
1 parent b45d0f4 commit ab4ac1d

File tree

11 files changed

+400
-316
lines changed

11 files changed

+400
-316
lines changed

packages/core/guard/core/src/custom_serve.rs

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@ use anyhow::*;
22
use async_trait::async_trait;
33
use bytes::Bytes;
44
use http_body_util::Full;
5-
use hyper::body::Incoming as BodyIncoming;
65
use hyper::{Request, Response};
76
use hyper_tungstenite::HyperWebsocket;
87

8+
use crate::WebSocketHandle;
99
use crate::proxy_service::ResponseBody;
1010
use crate::request_context::RequestContext;
1111

@@ -19,19 +19,12 @@ pub trait CustomServeTrait: Send + Sync {
1919
request_context: &mut RequestContext,
2020
) -> Result<Response<ResponseBody>>;
2121

22-
/// Handle a WebSocket connection after upgrade.
23-
///
24-
/// Contract for retries:
25-
/// - Return `Ok(())` after you have accepted (`await`ed) the client websocket and
26-
/// completed the streaming lifecycle. No further retries are possible.
27-
/// - Return `Err((client_ws, err))` if you have NOT accepted the websocket yet and
28-
/// want the proxy to optionally re-resolve and retry with a different handler.
29-
/// You must not `await` the websocket before returning this error.
22+
/// Handle a WebSocket connection after upgrade. Supports connection retries.
3023
async fn handle_websocket(
3124
&self,
32-
client_ws: HyperWebsocket,
25+
websocket: WebSocketHandle,
3326
headers: &hyper::HeaderMap,
3427
path: &str,
3528
request_context: &mut RequestContext,
36-
) -> std::result::Result<(), (HyperWebsocket, anyhow::Error)>;
29+
) -> Result<()>;
3730
}

packages/core/guard/core/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,14 @@ pub mod proxy_service;
77
pub mod request_context;
88
mod server;
99
pub mod types;
10+
pub mod websocket_handle;
1011

1112
pub use cert_resolver::CertResolverFn;
1213
pub use custom_serve::CustomServeTrait;
1314
pub use proxy_service::{
1415
CacheKeyFn, MiddlewareFn, ProxyService, ProxyState, RouteTarget, RoutingFn, RoutingOutput,
1516
};
17+
pub use websocket_handle::WebSocketHandle;
1618

1719
// Re-export hyper StatusCode for use in other crates
1820
pub mod status {

packages/core/guard/core/src/proxy_service.rs

Lines changed: 91 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,3 @@
1-
use std::{
2-
borrow::Cow,
3-
collections::HashMap as StdHashMap,
4-
net::SocketAddr,
5-
sync::Arc,
6-
time::{Duration, Instant},
7-
};
8-
91
use anyhow::*;
102
use bytes::Bytes;
113
use futures_util::{SinkExt, StreamExt};
@@ -16,17 +8,30 @@ use hyper_util::{client::legacy::Client, rt::TokioExecutor};
168
use moka::future::Cache;
179
use rand;
1810
use rivet_api_builder::{ErrorResponse, RawErrorResponse};
19-
use rivet_error::RivetError;
11+
use rivet_error::{INTERNAL_ERROR, RivetError};
2012
use rivet_metrics::KeyValue;
2113
use rivet_util::Id;
2214
use serde_json;
15+
use std::{
16+
borrow::Cow,
17+
collections::HashMap as StdHashMap,
18+
net::SocketAddr,
19+
sync::Arc,
20+
time::{Duration, Instant},
21+
};
2322
use tokio::sync::Mutex;
2423
use tokio::time::timeout;
25-
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
24+
use tokio_tungstenite::tungstenite::{
25+
client::IntoClientRequest,
26+
protocol::{CloseFrame, frame::coding::CloseCode},
27+
};
2628
use tracing::Instrument;
2729
use url::Url;
2830

29-
use crate::{custom_serve::CustomServeTrait, errors, metrics, request_context::RequestContext};
31+
use crate::{
32+
WebSocketHandle, custom_serve::CustomServeTrait, errors, metrics,
33+
request_context::RequestContext,
34+
};
3035

3136
pub const X_FORWARDED_FOR: HeaderName = HeaderName::from_static("x-forwarded-for");
3237
pub const X_RIVET_ERROR: HeaderName = HeaderName::from_static("x-rivet-error");
@@ -1432,9 +1437,9 @@ impl ProxyService {
14321437

14331438
// Close the WebSocket connection with the response message
14341439
let _ = client_ws.close(Some(tokio_tungstenite::tungstenite::protocol::CloseFrame {
1435-
code: tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode::Error,
1436-
reason: response.message.as_ref().into(),
1437-
})).await;
1440+
code: tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode::Error,
1441+
reason: response.message.as_ref().into(),
1442+
})).await;
14381443
return;
14391444
}
14401445
Result::Ok(ResolveRouteOutput::CustomServe(_)) => {
@@ -1813,31 +1818,42 @@ impl ProxyService {
18131818
let mut attempts = 0u32;
18141819
let mut client_ws = client_websocket;
18151820

1821+
let ws_handle = WebSocketHandle::new(client_ws);
1822+
18161823
loop {
18171824
match handlers
18181825
.handle_websocket(
1819-
client_ws,
1826+
ws_handle.clone(),
18201827
&req_headers,
18211828
&req_path,
18221829
&mut request_context,
18231830
)
18241831
.await
18251832
{
1826-
Result::Ok(()) => break,
1827-
Result::Err((returned_client_ws, err)) => {
1833+
Result::Ok(()) => {
1834+
tracing::debug!("websocket closed");
1835+
1836+
// Send graceful close
1837+
ws_handle.send(hyper_tungstenite::tungstenite::Message::Close(Some(
1838+
hyper_tungstenite::tungstenite::protocol::CloseFrame {
1839+
code: hyper_tungstenite::tungstenite::protocol::frame::coding::CloseCode::Normal,
1840+
reason: format!("Closed").into(),
1841+
},
1842+
)));
1843+
1844+
break;
1845+
}
1846+
Result::Err(err) => {
18281847
attempts += 1;
18291848
if attempts > max_attempts || !is_retryable_ws_error(&err) {
1830-
// Accept and close the client websocket with an error reason
1831-
if let Result::Ok(mut ws) = returned_client_ws.await {
1832-
let _ = ws
1833-
.send(hyper_tungstenite::tungstenite::Message::Close(Some(
1834-
hyper_tungstenite::tungstenite::protocol::CloseFrame {
1835-
code: hyper_tungstenite::tungstenite::protocol::frame::coding::CloseCode::Error,
1836-
reason: format!("{}", err).into(),
1837-
},
1838-
)))
1839-
.await;
1840-
}
1849+
// Close WebSocket with error
1850+
ws_handle
1851+
.accept_and_send(
1852+
hyper_tungstenite::tungstenite::Message::Close(
1853+
Some(err_to_close_frame(err)),
1854+
),
1855+
)
1856+
.await?;
18411857

18421858
break;
18431859
} else {
@@ -1861,49 +1877,38 @@ impl ProxyService {
18611877
new_handlers,
18621878
)) => {
18631879
handlers = new_handlers;
1864-
client_ws = returned_client_ws;
18651880
continue;
18661881
}
18671882
Result::Ok(ResolveRouteOutput::Response(response)) => {
1868-
if let Result::Ok(mut ws) = returned_client_ws.await
1869-
{
1870-
let _ = ws
1871-
.send(hyper_tungstenite::tungstenite::Message::Close(Some(
1872-
hyper_tungstenite::tungstenite::protocol::CloseFrame {
1873-
code: hyper_tungstenite::tungstenite::protocol::frame::coding::CloseCode::Error,
1874-
reason: response.message.as_ref().into(),
1875-
},
1876-
)))
1877-
.await;
1878-
}
1879-
break;
1883+
ws_handle
1884+
.accept_and_send(hyper_tungstenite::tungstenite::Message::Close(Some(
1885+
hyper_tungstenite::tungstenite::protocol::CloseFrame {
1886+
code: hyper_tungstenite::tungstenite::protocol::frame::coding::CloseCode::Error,
1887+
reason: response.message.as_ref().into(),
1888+
},
1889+
)))
1890+
.await;
18801891
}
18811892
Result::Ok(ResolveRouteOutput::Target(_)) => {
1882-
if let Result::Ok(mut ws) = returned_client_ws.await
1883-
{
1884-
let _ = ws
1885-
.send(hyper_tungstenite::tungstenite::Message::Close(Some(
1886-
hyper_tungstenite::tungstenite::protocol::CloseFrame {
1887-
code: hyper_tungstenite::tungstenite::protocol::frame::coding::CloseCode::Error,
1888-
reason: "Cannot retry WebSocket with non-custom serve route".into(),
1889-
},
1890-
)))
1891-
.await;
1892-
}
1893+
ws_handle
1894+
.accept_and_send(hyper_tungstenite::tungstenite::Message::Close(Some(
1895+
hyper_tungstenite::tungstenite::protocol::CloseFrame {
1896+
code: hyper_tungstenite::tungstenite::protocol::frame::coding::CloseCode::Error,
1897+
reason: "Cannot retry WebSocket with non-custom serve route".into(),
1898+
},
1899+
)))
1900+
.await;
18931901
break;
18941902
}
18951903
Err(res_err) => {
1896-
if let Result::Ok(mut ws) = returned_client_ws.await
1897-
{
1898-
let _ = ws
1899-
.send(hyper_tungstenite::tungstenite::Message::Close(Some(
1900-
hyper_tungstenite::tungstenite::protocol::CloseFrame {
1901-
code: hyper_tungstenite::tungstenite::protocol::frame::coding::CloseCode::Error,
1902-
reason: format!("Routing error: {}", res_err).into(),
1903-
},
1904-
)))
1905-
.await;
1906-
}
1904+
ws_handle
1905+
.accept_and_send(hyper_tungstenite::tungstenite::Message::Close(Some(
1906+
hyper_tungstenite::tungstenite::protocol::CloseFrame {
1907+
code: hyper_tungstenite::tungstenite::protocol::frame::coding::CloseCode::Error,
1908+
reason: format!("Routing error: {}", res_err).into(),
1909+
},
1910+
)))
1911+
.await;
19071912
break;
19081913
}
19091914
}
@@ -2242,3 +2247,26 @@ fn is_retryable_ws_error(err: &anyhow::Error) -> bool {
22422247
false
22432248
}
22442249
}
2250+
2251+
pub fn err_to_close_frame(err: anyhow::Error) -> CloseFrame {
2252+
let rivet_err = err
2253+
.chain()
2254+
.find_map(|x| x.downcast_ref::<RivetError>())
2255+
.cloned()
2256+
.unwrap_or_else(|| RivetError::from(&INTERNAL_ERROR));
2257+
2258+
let code = match (rivet_err.group(), rivet_err.code()) {
2259+
("ws", "connection_closed") => CloseCode::Normal,
2260+
_ => CloseCode::Error,
2261+
};
2262+
2263+
// NOTE: reason cannot be more than 123 bytes as per the WS protocol
2264+
let reason = rivet_util::safe_slice(
2265+
&format!("{}.{}", rivet_err.group(), rivet_err.code()),
2266+
0,
2267+
123,
2268+
)
2269+
.into();
2270+
2271+
CloseFrame { code, reason }
2272+
}
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
use anyhow::*;
2+
use futures_util::{SinkExt, StreamExt};
3+
use hyper::upgrade::Upgraded;
4+
use hyper_tungstenite::HyperWebsocket;
5+
use hyper_tungstenite::tungstenite::Message as WsMessage;
6+
use hyper_util::rt::TokioIo;
7+
use std::ops::Deref;
8+
use std::sync::Arc;
9+
use tokio::sync::Mutex;
10+
use tokio_tungstenite::WebSocketStream;
11+
12+
pub type WebSocketReceiver = futures_util::stream::SplitStream<WebSocketStream<TokioIo<Upgraded>>>;
13+
14+
pub type WebSocketSender =
15+
futures_util::stream::SplitSink<WebSocketStream<TokioIo<Upgraded>>, WsMessage>;
16+
17+
enum WebSocketState {
18+
Unaccepted { websocket: HyperWebsocket },
19+
Accepting,
20+
Split { ws_tx: WebSocketSender },
21+
}
22+
23+
#[derive(Clone)]
24+
pub struct WebSocketHandle(Arc<WebSocketHandleInner>);
25+
26+
impl WebSocketHandle {
27+
pub fn new(websocket: HyperWebsocket) -> Self {
28+
Self(Arc::new(WebSocketHandleInner {
29+
state: Mutex::new(WebSocketState::Unaccepted { websocket }),
30+
}))
31+
}
32+
}
33+
34+
impl Deref for WebSocketHandle {
35+
type Target = WebSocketHandleInner;
36+
37+
fn deref(&self) -> &Self::Target {
38+
&*self.0
39+
}
40+
}
41+
42+
pub struct WebSocketHandleInner {
43+
state: Mutex<WebSocketState>,
44+
}
45+
46+
impl WebSocketHandleInner {
47+
pub async fn accept(&self) -> Result<WebSocketReceiver> {
48+
let mut state = self.state.lock().await;
49+
Self::accept_inner(&mut *state).await
50+
}
51+
52+
pub async fn send(&self, message: WsMessage) -> Result<()> {
53+
let mut state = self.state.lock().await;
54+
match &mut *state {
55+
WebSocketState::Unaccepted { .. } | WebSocketState::Accepting => {
56+
bail!("websocket has not been accepted")
57+
}
58+
WebSocketState::Split { ws_tx } => {
59+
ws_tx.send(message).await?;
60+
Ok(())
61+
}
62+
}
63+
}
64+
65+
pub async fn accept_and_send(&self, message: WsMessage) -> Result<()> {
66+
let mut state = self.state.lock().await;
67+
match &mut *state {
68+
WebSocketState::Unaccepted { .. } => {
69+
let _ = Self::accept_inner(&mut *state).await?;
70+
let WebSocketState::Split { ws_tx } = &mut *state else {
71+
bail!("websocket should be accepted");
72+
};
73+
ws_tx.send(message).await?;
74+
Ok(())
75+
}
76+
WebSocketState::Accepting => {
77+
bail!("in accepting state")
78+
}
79+
WebSocketState::Split { ws_tx } => {
80+
ws_tx.send(message).await?;
81+
Ok(())
82+
}
83+
}
84+
}
85+
86+
async fn accept_inner(state: &mut WebSocketState) -> Result<WebSocketReceiver> {
87+
if !matches!(*state, WebSocketState::Unaccepted { .. }) {
88+
bail!("websocket already accepted")
89+
}
90+
91+
// Accept websocket
92+
let old_state = std::mem::replace(&mut *state, WebSocketState::Accepting);
93+
let WebSocketState::Unaccepted { websocket } = old_state else {
94+
bail!("should be in unaccepted state");
95+
};
96+
97+
// Accept WS
98+
let ws_stream = websocket.await?;
99+
let (ws_tx, ws_rx) = ws_stream.split();
100+
*state = WebSocketState::Split { ws_tx };
101+
102+
Ok(ws_rx)
103+
}
104+
}

0 commit comments

Comments
 (0)