diff --git a/engine/packages/gasoline/src/ctx/standalone.rs b/engine/packages/gasoline/src/ctx/standalone.rs index a24d2b096c..2a9bf05d1f 100644 --- a/engine/packages/gasoline/src/ctx/standalone.rs +++ b/engine/packages/gasoline/src/ctx/standalone.rs @@ -66,6 +66,23 @@ impl StandaloneCtx { }) } + #[tracing::instrument(skip_all)] + pub fn with_ray(&self, ray_id: Id, req_id: Id) -> WorkflowResult { + let mut ctx = StandaloneCtx::new( + self.db.clone(), + self.config.clone(), + self.pools.clone(), + self.cache.clone(), + &self.name, + ray_id, + req_id, + )?; + + ctx.from_workflow = self.from_workflow; + + Ok(ctx) + } + #[tracing::instrument(skip_all)] pub fn new_from_activity(ctx: &ActivityCtx, req_id: Id) -> WorkflowResult { let mut ctx = StandaloneCtx::new( diff --git a/engine/packages/guard-core/src/custom_serve.rs b/engine/packages/guard-core/src/custom_serve.rs index 99e2719dc4..461f901d6b 100644 --- a/engine/packages/guard-core/src/custom_serve.rs +++ b/engine/packages/guard-core/src/custom_serve.rs @@ -4,11 +4,11 @@ use bytes::Bytes; use http_body_util::Full; use hyper::{Request, Response}; use rivet_runner_protocol as protocol; +use rivet_util::Id; use tokio_tungstenite::tungstenite::protocol::frame::CloseFrame; use crate::WebSocketHandle; use crate::proxy_service::ResponseBody; -use crate::request_context::RequestContext; pub enum HibernationResult { Continue, @@ -22,7 +22,8 @@ pub trait CustomServeTrait: Send + Sync { async fn handle_request( &self, req: Request>, - request_context: &mut RequestContext, + ray_id: Id, + req_id: Id, request_id: protocol::RequestId, ) -> Result>; @@ -32,9 +33,10 @@ pub trait CustomServeTrait: Send + Sync { _websocket: WebSocketHandle, _headers: &hyper::HeaderMap, _path: &str, - _request_context: &mut RequestContext, + _ray_id: Id, + _req_id: Id, // Identifies the websocket across retries. - _unique_request_id: protocol::RequestId, + _request_id: protocol::RequestId, // True if this websocket is reconnecting after hibernation. _after_hibernation: bool, ) -> Result> { @@ -45,7 +47,9 @@ pub trait CustomServeTrait: Send + Sync { async fn handle_websocket_hibernation( &self, _websocket: WebSocketHandle, - _unique_request_id: protocol::RequestId, + _ray_id: Id, + _req_id: Id, + _request_id: protocol::RequestId, ) -> Result { bail!("service does not support websocket hibernation"); } diff --git a/engine/packages/guard-core/src/errors.rs b/engine/packages/guard-core/src/errors.rs index ebfa809da5..72aa979efa 100644 --- a/engine/packages/guard-core/src/errors.rs +++ b/engine/packages/guard-core/src/errors.rs @@ -1,5 +1,4 @@ use rivet_error::*; -use rivet_util::Id; use serde::{Deserialize, Serialize}; #[derive(RivetError, Serialize, Deserialize)] @@ -7,10 +6,9 @@ use serde::{Deserialize, Serialize}; "guard", "rate_limit", "Too many requests. Try again later.", - "Too many requests to '{method} {path}' (actor_id: {actor_id:?}) from IP {ip}." + "Too many requests to '{method} {path}' from IP {ip}." )] pub struct RateLimit { - pub actor_id: Option, pub method: String, pub path: String, pub ip: String, diff --git a/engine/packages/guard-core/src/lib.rs b/engine/packages/guard-core/src/lib.rs index 25f0d4a954..37bda09c07 100644 --- a/engine/packages/guard-core/src/lib.rs +++ b/engine/packages/guard-core/src/lib.rs @@ -4,7 +4,6 @@ pub mod custom_serve; pub mod errors; pub mod metrics; pub mod proxy_service; -pub mod request_context; mod server; mod task_group; pub mod types; diff --git a/engine/packages/guard-core/src/proxy_service.rs b/engine/packages/guard-core/src/proxy_service.rs index f0fb127b4f..f8560dd7d6 100644 --- a/engine/packages/guard-core/src/proxy_service.rs +++ b/engine/packages/guard-core/src/proxy_service.rs @@ -32,7 +32,6 @@ use crate::{ WebSocketHandle, custom_serve::{CustomServeTrait, HibernationResult}, errors, metrics, - request_context::RequestContext, task_group::TaskGroup, }; @@ -111,7 +110,6 @@ impl http_body::Body for ResponseBody { // Routing types #[derive(Clone, Debug)] pub struct RouteTarget { - pub actor_id: Option, pub host: String, pub port: u16, pub path: String, @@ -153,6 +151,8 @@ pub type RoutingFn = Arc< dyn for<'a> Fn( &'a str, &'a str, + Id, + Id, PortType, &'a hyper::HeaderMap, ) -> futures::future::BoxFuture<'a, Result> @@ -210,7 +210,6 @@ pub enum MiddlewareResponse { pub type MiddlewareFn = Arc< dyn for<'a> Fn( - &'a Id, &'a hyper::HeaderMap, ) -> futures::future::BoxFuture<'a, Result> + Send @@ -315,11 +314,10 @@ pub struct ProxyState { middleware_fn: MiddlewareFn, route_cache: RouteCache, // We use moka::Cache instead of scc::HashMap because it automatically handles TTL and capacity - rate_limiters: Cache<(Id, std::net::IpAddr), Arc>>, - in_flight_counters: Cache<(Id, std::net::IpAddr), Arc>>, + rate_limiters: Cache>>, + in_flight_counters: Cache>>, in_flight_requests: Cache, port_type: PortType, - clickhouse_inserter: Option, tasks: Arc, } @@ -330,7 +328,6 @@ impl ProxyState { cache_key_fn: CacheKeyFn, middleware_fn: MiddlewareFn, port_type: PortType, - clickhouse_inserter: Option, ) -> Self { Self { config, @@ -348,7 +345,6 @@ impl ProxyState { .build(), in_flight_requests: Cache::builder().max_capacity(10_000_000).build(), port_type, - clickhouse_inserter, tasks: TaskGroup::new(), } } @@ -358,6 +354,8 @@ impl ProxyState { &self, hostname: &str, path: &str, + ray_id: Id, + req_id: Id, method: &hyper::Method, port_type: PortType, headers: &hyper::HeaderMap, @@ -401,7 +399,7 @@ impl ProxyState { let res = timeout( default_timeout, - (self.routing_fn)(hostname_only, path, port_type, headers), + (self.routing_fn)(hostname_only, path, ray_id, req_id, port_type, headers), ) .await .map_err(|_| { @@ -433,7 +431,6 @@ impl ProxyState { target_host = %target.host, target_port = target.port, target_path = %target.path, - actor_id = ?target.actor_id, "Selected target for request" ); Ok(ResolveRouteOutput::Target(target.clone())) @@ -458,33 +455,28 @@ impl ProxyState { } #[tracing::instrument(skip_all)] - async fn get_middleware_config( - &self, - actor_id: &Id, - headers: &hyper::HeaderMap, - ) -> Result { + async fn get_middleware_config(&self, headers: &hyper::HeaderMap) -> Result { // Call the middleware function with a timeout let default_timeout = Duration::from_secs(5); // Default 5 seconds - let middleware_result = - timeout(default_timeout, (self.middleware_fn)(actor_id, headers)).await; + let middleware_result = timeout(default_timeout, (self.middleware_fn)(headers)).await; match middleware_result { Ok(result) => match result? { MiddlewareResponse::Ok(config) => Ok(config), MiddlewareResponse::NotFound => { - // Default values if middleware not found for this actor + // Default values Ok(MiddlewareConfig { rate_limit: RateLimitConfig { - requests: 100, // 100 requests - period: 60, // per 60 seconds + requests: 10000, // 10000 requests + period: 60, // per 60 seconds }, max_in_flight: MaxInFlightConfig { - amount: 20, // 20 concurrent requests + amount: 2000, // 2000 concurrent requests }, retry: RetryConfig { - max_attempts: 3, // 3 retry attempts - initial_interval: 100, // 100ms initial interval + max_attempts: 7, // 7 retry attempts + initial_interval: 150, // 150ms initial interval }, timeout: TimeoutConfig { request_timeout: 30, // 30 seconds for requests @@ -518,21 +510,12 @@ impl ProxyState { async fn check_rate_limit( &self, ip_addr: std::net::IpAddr, - actor_id: &Option, headers: &hyper::HeaderMap, ) -> Result { - let Some(actor_id) = *actor_id else { - // No rate limiting when actor_id is None - return Ok(true); - }; - - // Get actor-specific middleware config - let middleware_config = self.get_middleware_config(&actor_id, headers).await?; - - let cache_key = (actor_id, ip_addr); + let middleware_config = self.get_middleware_config(headers).await?; // Get existing limiter or create a new one - let limiter_arc = if let Some(existing_limiter) = self.rate_limiters.get(&cache_key).await { + let limiter_arc = if let Some(existing_limiter) = self.rate_limiters.get(&ip_addr).await { existing_limiter } else { let new_limiter = Arc::new(Mutex::new(RateLimiter::new( @@ -540,7 +523,7 @@ impl ProxyState { middleware_config.rate_limit.period, ))); self.rate_limiters - .insert(cache_key, new_limiter.clone()) + .insert(ip_addr, new_limiter.clone()) .await; metrics::RATE_LIMITER_COUNT.set(self.rate_limiters.entry_count() as i64); new_limiter @@ -559,20 +542,15 @@ impl ProxyState { async fn acquire_in_flight( &self, ip_addr: std::net::IpAddr, - actor_id: &Option, headers: &hyper::HeaderMap, ) -> Result> { - // Check in-flight limit if actor_id is present - if let Some(actor_id) = *actor_id { - // Get actor-specific middleware config - let middleware_config = self.get_middleware_config(&actor_id, headers).await?; + let middleware_config = self.get_middleware_config(headers).await?; - let cache_key = (actor_id, ip_addr); + let cache_key = ip_addr; - // Get existing counter or create a new one - let counter_arc = if let Some(existing_counter) = - self.in_flight_counters.get(&cache_key).await - { + // Get existing counter or create a new one + let counter_arc = + if let Some(existing_counter) = self.in_flight_counters.get(&cache_key).await { existing_counter } else { let new_counter = Arc::new(Mutex::new(InFlightCounter::new( @@ -585,15 +563,14 @@ impl ProxyState { new_counter }; - // Try to acquire from the counter - let acquired = { - let mut counter = counter_arc.lock().await; - counter.try_acquire() - }; + // Try to acquire from the counter + let acquired = { + let mut counter = counter_arc.lock().await; + counter.try_acquire() + }; - if !acquired { - return Ok(None); // Rate limited - } + if !acquired { + return Ok(None); // Rate limited } // Generate unique request ID @@ -602,19 +579,10 @@ impl ProxyState { } #[tracing::instrument(skip_all)] - async fn release_in_flight( - &self, - ip_addr: std::net::IpAddr, - actor_id: &Option, - request_id: protocol::RequestId, - ) { - // Release in-flight counter if actor_id is present - if let Some(actor_id) = *actor_id { - let cache_key = (actor_id, ip_addr); - if let Some(counter_arc) = self.in_flight_counters.get(&cache_key).await { - let mut counter = counter_arc.lock().await; - counter.release(); - } + async fn release_in_flight(&self, ip_addr: std::net::IpAddr, request_id: protocol::RequestId) { + if let Some(counter_arc) = self.in_flight_counters.get(&ip_addr).await { + let mut counter = counter_arc.lock().await; + counter.release(); } // Release request ID @@ -703,8 +671,9 @@ impl ProxyService { async fn handle_request( &self, req: Request, + ray_id: Id, + req_id: Id, start_time: Instant, - request_context: &mut RequestContext, ) -> Result> { let host = req .headers() @@ -725,6 +694,8 @@ impl ProxyService { .resolve_route( host, &path, + ray_id, + req_id, req.method(), self.state.port_type.clone(), req.headers(), @@ -738,12 +709,6 @@ impl ProxyService { // Resolve target let target = target_res?; - let actor_id = if let ResolveRouteOutput::Target(target) = &target { - target.actor_id - } else { - None - }; - // Extract IP address from X-Forwarded-For header or fall back to remote_addr let client_ip = req .headers() @@ -759,11 +724,10 @@ impl ProxyService { // Apply rate limiting if !self .state - .check_rate_limit(client_ip, &actor_id, req.headers()) + .check_rate_limit(client_ip, req.headers()) .await? { return Err(errors::RateLimit { - actor_id, method: req.method().to_string(), path: path.clone(), ip: client_ip.to_string(), @@ -771,16 +735,15 @@ impl ProxyService { .build()); } - // Acquire in-flight limit and generate request ID + // Acquire in-flight limit and generate protocol request ID let request_id = match self .state - .acquire_in_flight(client_ip, &actor_id, req.headers()) + .acquire_in_flight(client_ip, req.headers()) .await? { Some(id) => id, None => { return Err(errors::RateLimit { - actor_id, method: req.method().to_string(), path: path.clone(), ip: client_ip.to_string(), @@ -793,16 +756,11 @@ impl ProxyService { metrics::PROXY_REQUEST_PENDING.inc(); metrics::PROXY_REQUEST_TOTAL.inc(); - // Update request context with target info - if let Some(actor_id) = actor_id { - request_context.service_actor_id = Some(actor_id); - } - let res = if hyper_tungstenite::is_upgrade_request(&req) { - self.handle_websocket_upgrade(req, target, request_context, client_ip, actor_id) + self.handle_websocket_upgrade(req, target, client_ip, ray_id, req_id) .await } else { - self.handle_http_request(req, target, request_context, client_ip, actor_id) + self.handle_http_request(req, target, client_ip, ray_id, req_id) .await }; @@ -823,9 +781,7 @@ impl ProxyService { let state_clone = self.state.clone(); tokio::spawn( async move { - state_clone - .release_in_flight(client_ip, &actor_id, request_id) - .await; + state_clone.release_in_flight(client_ip, request_id).await; } .instrument(tracing::info_span!("release_in_flight_task")), ); @@ -838,36 +794,11 @@ impl ProxyService { &self, req: Request, resolved_route: ResolveRouteOutput, - request_context: &mut RequestContext, client_ip: std::net::IpAddr, - actor_id: Option, + ray_id: Id, + req_id: Id, ) -> Result> { - // Get middleware config for this actor if it exists - let middleware_config = if let ResolveRouteOutput::Target(target) = &resolved_route - && let Some(actor_id) = &target.actor_id - { - self.state - .get_middleware_config(actor_id, req.headers()) - .await? - } else { - // Default middleware config for targets without actor_id - MiddlewareConfig { - rate_limit: RateLimitConfig { - requests: 100, // 100 requests - period: 60, // per 60 seconds - }, - max_in_flight: MaxInFlightConfig { - amount: 20, // 20 concurrent requests - }, - retry: RetryConfig { - max_attempts: 3, // 3 retry attempts - initial_interval: 100, // 100ms initial interval - }, - timeout: TimeoutConfig { - request_timeout: 30, // 30 seconds for requests - }, - } - }; + let middleware_config = self.state.get_middleware_config(req.headers()).await?; let host = req .headers() @@ -889,13 +820,6 @@ impl ProxyService { match resolved_route { ResolveRouteOutput::Target(mut target) => { - // Set service IP from target - if let Ok(target_ip) = - format!("{}:{}", target.host, target.port).parse::() - { - request_context.service_ip = Some(target_ip.ip()); - } - // Read the request body before proceeding with retries let (req_parts, body) = req.into_parts(); let req_body = match http_body_util::BodyExt::collect(body).await { @@ -906,9 +830,6 @@ impl ProxyService { } }; - // Set actual request body size in analytics - request_context.client_request_body_bytes = Some(req_body.len() as u64); - // Use a value-returning loop to handle both errors and successful responses let mut attempts = 0; while attempts < max_attempts { @@ -954,6 +875,8 @@ impl ProxyService { .resolve_route( &host, &path, + ray_id, + req_id, &req_parts.method, self.state.port_type.clone(), &req_parts.headers, @@ -981,9 +904,6 @@ impl ProxyService { // For streaming responses, pass through the body without buffering tracing::debug!("Detected streaming response, preserving stream"); - // We can't easily calculate response size for streaming, so set it to None - request_context.guard_response_body_bytes = None; - let streaming_body = ResponseBody::Incoming(body); return Ok(Response::from_parts(parts, streaming_body)); } else { @@ -993,10 +913,6 @@ impl ProxyService { Err(_) => Bytes::new(), }; - // Set actual response body size in analytics - request_context.guard_response_body_bytes = - Some(body_bytes.len() as u64); - let full_body = ResponseBody::Full(Full::new(body_bytes)); return Ok(Response::from_parts(parts, full_body)); } @@ -1029,6 +945,8 @@ impl ProxyService { .resolve_route( &host, &path, + ray_id, + req_id, &req_parts.method, self.state.port_type.clone(), &req_parts.headers, @@ -1059,13 +977,12 @@ impl ProxyService { // Acquire in-flight limit and generate request ID let request_id = match self .state - .acquire_in_flight(client_ip, &actor_id, &req_headers) + .acquire_in_flight(client_ip, &req_headers) .await? { Some(id) => id, None => { return Err(errors::RateLimit { - actor_id, method: req_method.to_string(), path: path.clone(), ip: client_ip.to_string(), @@ -1094,7 +1011,7 @@ impl ProxyService { attempts += 1; let res = handler - .handle_request(req_collected.clone(), request_context, request_id) + .handle_request(req_collected.clone(), ray_id, req_id, request_id) .await; if should_retry_request(&res) { // Request connect error, might retry @@ -1110,6 +1027,8 @@ impl ProxyService { .resolve_route( &host, &path, + ray_id, + req_id, req_collected.method(), self.state.port_type.clone(), &req_headers, @@ -1125,17 +1044,13 @@ impl ProxyService { } // Release in-flight counter and request ID before returning - self.state - .release_in_flight(client_ip, &actor_id, request_id) - .await; + self.state.release_in_flight(client_ip, request_id).await; return res; } // If we get here, all attempts failed // Release in-flight counter and request ID before returning error - self.state - .release_in_flight(client_ip, &actor_id, request_id) - .await; + self.state.release_in_flight(client_ip, request_id).await; return Err(errors::RetryAttemptsExceeded { attempts: max_attempts, } @@ -1193,9 +1108,9 @@ impl ProxyService { &self, req: Request, target: ResolveRouteOutput, - request_context: &mut RequestContext, client_ip: std::net::IpAddr, - actor_id: Option, + ray_id: Id, + req_id: Id, ) -> Result> { // Parsed for retries later let req_host = req @@ -1209,40 +1124,10 @@ impl ProxyService { .path_and_query() .map(|x| x.to_string()) .unwrap_or_else(|| req.uri().path().to_string()); - - // Capture headers and method before request is consumed let req_headers = req.headers().clone(); let req_method = req.method().clone(); - let ray_id = req.extensions().get::().map(|x| x.ray_id); - // Get middleware config for this actor if it exists - let middleware_config = match &actor_id { - Some(actor_id) => { - self.state - .get_middleware_config(actor_id, &req_headers) - .await? - } - None => { - // Default middleware config for targets without actor_id - tracing::debug!("Using default middleware config (no actor_id)"); - MiddlewareConfig { - rate_limit: RateLimitConfig { - requests: 10000, // 10000 requests - period: 60, // per 60 seconds - }, - max_in_flight: MaxInFlightConfig { - amount: 2000, // 2000 concurrent requests - }, - retry: RetryConfig { - max_attempts: 7, // 7 retry attempts - initial_interval: 150, // 150ms initial interval - }, - timeout: TimeoutConfig { - request_timeout: 30, // 30 seconds for requests - }, - } - } - }; + let middleware_config = self.state.get_middleware_config(req.headers()).await?; // Set up retry with backoff from middleware config let max_attempts = middleware_config.retry.max_attempts; @@ -1504,6 +1389,8 @@ impl ProxyService { .resolve_route( &req_host, &req_path, + ray_id, + req_id, &req_method, state.port_type.clone(), &req_headers, @@ -1856,31 +1743,27 @@ impl ProxyService { } ResolveRouteOutput::CustomServe(mut handler) => { tracing::debug!(%req_path, "Spawning task to handle WebSocket communication"); - let mut request_context = request_context.clone(); - let req_headers = req_headers.clone(); let state = self.state.clone(); + let req_headers = req_headers.clone(); let req_path = req_path.clone(); let req_host = req_host.clone(); let req_method = req_method.clone(); self.state.tasks.spawn( async move { - let request_id = match state - .acquire_in_flight(client_ip, &actor_id, &req_headers) - .await? - { - Some(id) => id, - None => { - return Err(errors::RateLimit { - actor_id, - method: req_method.to_string(), - path: req_path.clone(), - ip: client_ip.to_string(), + let request_id = + match state.acquire_in_flight(client_ip, &req_headers).await? { + Some(id) => id, + None => { + return Err(errors::RateLimit { + method: req_method.to_string(), + path: req_path.clone(), + ip: client_ip.to_string(), + } + .build() + .into()); } - .build() - .into()); - } - }; + }; let mut ws_hibernation_close = false; let mut after_hibernation = false; let mut attempts = 0u32; @@ -1895,7 +1778,8 @@ impl ProxyService { ws_handle.clone(), &req_headers, &req_path, - &mut request_context, + ray_id, + req_id, request_id, after_hibernation, ) @@ -1967,6 +1851,8 @@ impl ProxyService { let res = handler .handle_websocket_hibernation( ws_handle.clone(), + ray_id, + req_id, request_id, ) .await?; @@ -2025,6 +1911,8 @@ impl ProxyService { .resolve_route( &req_host, &req_path, + ray_id, + req_id, &req_method, state.port_type.clone(), &req_headers, @@ -2081,9 +1969,7 @@ impl ProxyService { } // Release in-flight counter and request ID when task completes - state - .release_in_flight(client_ip, &actor_id, request_id) - .await; + state.release_in_flight(client_ip, request_id).await; Ok(()) } @@ -2127,10 +2013,6 @@ impl ProxyService { .record("req_id", request_ids.req_id.to_string()) .record("ray_id", request_ids.ray_id.to_string()); - // Create request context for analytics tracking - let mut request_context = - RequestContext::new(self.state.clickhouse_inserter.clone(), request_ids); - // Extract request information for logging and analytics before consuming the request let incoming_ray_id = req .headers() @@ -2157,37 +2039,6 @@ impl ProxyService { .and_then(|h| h.to_str().ok()) .map(|s| s.to_string()); - // Populate request context with available data - request_context.client_ip = Some(self.remote_addr.ip()); - request_context.client_request_host = Some(host.clone()); - request_context.client_request_method = Some(method.to_string()); - request_context.client_request_path = Some(req.uri().path().to_string()); - request_context.client_request_protocol = Some(format!("{:?}", req.version())); - request_context.client_request_scheme = - Some(req.uri().scheme_str().unwrap_or("http").to_string()); - request_context.client_request_uri = Some(path.clone()); - request_context.client_src_port = Some(self.remote_addr.port()); - - if let Some(referer) = req - .headers() - .get(hyper::header::REFERER) - .and_then(|h| h.to_str().ok()) - { - request_context.client_request_referer = Some(referer.to_string()); - } - - if let Some(ua) = &user_agent { - request_context.client_request_user_agent = Some(ua.clone()); - } - - if let Some(requested_with) = req - .headers() - .get("x-requested-with") - .and_then(|h| h.to_str().ok()) - { - request_context.client_x_requested_with = Some(requested_with.to_string()); - } - // TLS information would be set here if available (for HTTPS connections) // This requires TLS connection introspection and is marked for future enhancement @@ -2223,7 +2074,7 @@ impl ProxyService { // Process the request let mut res = match self - .handle_request(req, start_time, &mut request_context) + .handle_request(req, request_ids.ray_id, request_ids.req_id, start_time) .await { Ok(res) => res, @@ -2256,7 +2107,7 @@ impl ProxyService { return; } }; - let frame = err_to_close_frame(err, Some(request_ids.ray_id)); + let frame = err_to_close_frame(err, request_ids.ray_id); // Manual conversion to handle different tungstenite versions let code_num: u16 = frame.code.into(); @@ -2362,49 +2213,6 @@ impl ProxyService { let status = res.status().as_u16(); - // Update request context with response details - request_context.guard_response_status = Some(status); - request_context.service_response_status = Some(status); - - if let Some(content_type) = res - .headers() - .get(hyper::header::CONTENT_TYPE) - .and_then(|h| h.to_str().ok()) - { - request_context.guard_response_content_type = Some(content_type.to_string()); - } - - if let Some(expires) = res - .headers() - .get(hyper::header::EXPIRES) - .and_then(|h| h.to_str().ok()) - { - request_context.service_response_http_expires = Some(expires.to_string()); - } - - if let Some(last_modified) = res - .headers() - .get(hyper::header::LAST_MODIFIED) - .and_then(|h| h.to_str().ok()) - { - request_context.service_response_http_last_modified = Some(last_modified.to_string()); - } - - // Set timing information - request_context.service_response_duration_ms = - Some(start_time.elapsed().as_millis() as u32); - - // Insert analytics event asynchronously - let mut context_clone = request_context.clone(); - tokio::spawn( - async move { - if let Err(error) = context_clone.insert_event().await { - tracing::warn!(?error, "failed to insert guard analytics event"); - } - } - .instrument(tracing::info_span!("insert_event_task")), - ); - let content_length = res .headers() .get(hyper::header::CONTENT_LENGTH) @@ -2452,7 +2260,6 @@ impl ProxyServiceFactory { cache_key_fn: CacheKeyFn, middleware_fn: MiddlewareFn, port_type: PortType, - clickhouse_inserter: Option, ) -> Self { let state = Arc::new(ProxyState::new( config, @@ -2460,7 +2267,6 @@ impl ProxyServiceFactory { cache_key_fn, middleware_fn, port_type, - clickhouse_inserter, )); Self { state } } @@ -2527,8 +2333,7 @@ fn err_into_response(err: anyhow::Error) -> Result> { ("guard", "routing_error") => StatusCode::BAD_GATEWAY, ("guard", "request_timeout") => StatusCode::GATEWAY_TIMEOUT, ("guard", "retry_attempts_exceeded") => StatusCode::BAD_GATEWAY, - ("guard", "actor_not_found") => StatusCode::NOT_FOUND, - ("guard", "actor_destroyed") => StatusCode::NOT_FOUND, + ("actor", "not_found") => StatusCode::NOT_FOUND, ("guard", "service_unavailable") => StatusCode::SERVICE_UNAVAILABLE, ("guard", "actor_ready_timeout") => StatusCode::SERVICE_UNAVAILABLE, ("guard", "no_route") => StatusCode::NOT_FOUND, @@ -2597,7 +2402,7 @@ pub fn is_ws_hibernate(err: &anyhow::Error) -> bool { } } -fn err_to_close_frame(err: anyhow::Error, ray_id: Option) -> CloseFrame { +fn err_to_close_frame(err: anyhow::Error, ray_id: Id) -> CloseFrame { let rivet_err = err .chain() .find_map(|x| x.downcast_ref::()) @@ -2614,11 +2419,7 @@ fn err_to_close_frame(err: anyhow::Error, ray_id: Option) -> CloseFrame { _ => tracing::error!(?err, "websocket failed"), } - let reason = if let Some(ray_id) = ray_id { - format!("{}.{}#{}", rivet_err.group(), rivet_err.code(), ray_id) - } else { - format!("{}.{}", rivet_err.group(), rivet_err.code()) - }; + let reason = format!("{}.{}#{}", rivet_err.group(), rivet_err.code(), ray_id); // NOTE: reason cannot be more than 123 bytes as per the WS protocol let reason = rivet_util::safe_slice(&reason, 0, 123).into(); diff --git a/engine/packages/guard-core/src/request_context.rs b/engine/packages/guard-core/src/request_context.rs deleted file mode 100644 index c519649b5a..0000000000 --- a/engine/packages/guard-core/src/request_context.rs +++ /dev/null @@ -1,184 +0,0 @@ -use std::{net::IpAddr, time::SystemTime}; - -use anyhow::Result; -use rivet_api_builder::RequestIds; -use rivet_util::Id; - -use crate::analytics::GuardHttpRequest; - -// Properties not currently tracked but should be added in future iterations: -// - client_ssl_cipher: Requires TLS connection introspection -// - client_ssl_protocol: Requires TLS connection introspection -// - client_tcp_rtt_ms: Requires network-level measurements -// - client_request_bytes: Total request size including headers -// - service_dns_response_time_ms: Requires DNS timing instrumentation -// - service_ssl_protocol: Requires upstream TLS introspection -// - service_tcp_handshake_duration_ms: Requires connection-level timing -// - service_tls_handshake_duration_ms: Requires TLS handshake timing -// - service_request_header_send_duration_ms: Requires granular timing -// - service_response_header_receive_duration_ms: Requires granular timing -// - guard_response_bytes: Total response size including headers -// - guard_time_to_first_byte_ms: Requires granular timing -// - security_rule_id: Requires security/firewall rule integration - -#[derive(Clone)] -pub struct RequestContext { - // Request tracking data - // TODO: - // pub request_id: Id, - // pub ray_id: Id, - pub client_ip: Option, - pub client_request_body_bytes: Option, - pub client_request_host: Option, - pub client_request_method: Option, - pub client_request_path: Option, - pub client_request_protocol: Option, - pub client_request_referer: Option, - pub client_request_scheme: Option, - pub client_request_uri: Option, - pub client_request_user_agent: Option, - pub client_src_port: Option, - pub client_x_requested_with: Option, - - // Guard tracking data - // pub guard_datacenter_id: Option, - // pub guard_cluster_id: Option, - // pub guard_server_id: Option, - pub guard_end_timestamp: Option, - pub guard_response_body_bytes: Option, - pub guard_response_content_type: Option, - pub guard_response_status: Option, - pub guard_start_timestamp: SystemTime, - - // Service tracking data - pub service_ip: Option, - pub service_response_duration_ms: Option, - pub service_response_http_expires: Option, - pub service_response_http_last_modified: Option, - pub service_response_status: Option, - pub service_actor_id: Option, - - // ClickHouse inserter handle - clickhouse_inserter: Option, -} - -impl RequestContext { - pub fn new( - clickhouse_inserter: Option, - _request_ids: RequestIds, - ) -> Self { - Self { - client_ip: None, - client_request_body_bytes: None, - client_request_host: None, - client_request_method: None, - client_request_path: None, - client_request_protocol: None, - client_request_referer: None, - client_request_scheme: None, - client_request_uri: None, - client_request_user_agent: None, - client_src_port: None, - client_x_requested_with: None, - guard_end_timestamp: None, - guard_response_body_bytes: None, - guard_response_content_type: None, - guard_response_status: None, - guard_start_timestamp: SystemTime::now(), - service_ip: None, - service_response_duration_ms: None, - service_response_http_expires: None, - service_response_http_last_modified: None, - service_response_status: None, - service_actor_id: None, - clickhouse_inserter, - } - } - - // Finalize the request and insert analytics event - pub async fn insert_event(&mut self) -> Result<()> { - let Some(inserter) = &self.clickhouse_inserter else { - return Ok(()); // No inserter available - }; - - // Set end timestamp - self.guard_end_timestamp = Some(SystemTime::now()); - - // Convert IP addresses to strings for ClickHouse IPv4 type - let client_ip = match self.client_ip { - Some(IpAddr::V4(ip)) => ip.to_string(), - Some(IpAddr::V6(_)) => "0.0.0.0".to_string(), // Fallback for IPv6 addresses - None => "0.0.0.0".to_string(), // Default fallback - }; - - let service_ip = match self.service_ip { - Some(IpAddr::V4(ip)) => ip.to_string(), - Some(IpAddr::V6(_)) => "0.0.0.0".to_string(), // Fallback for IPv6 addresses - None => "127.0.0.1".to_string(), // Default fallback - }; - - // Convert SystemTime to nanoseconds since Unix epoch for ClickHouse DateTime64(9) - let guard_start_timestamp = self - .guard_start_timestamp - .duration_since(std::time::UNIX_EPOCH) - .unwrap_or_default() - .as_nanos() as u64; - - let guard_end_timestamp = self - .guard_end_timestamp - .unwrap_or_else(SystemTime::now) - .duration_since(std::time::UNIX_EPOCH) - .unwrap_or_default() - .as_nanos() as u64; - - // Build the analytics event inline with defaults for missing values - let analytics_event = GuardHttpRequest { - client_ip, - client_request_body_bytes: self.client_request_body_bytes.unwrap_or_default(), - client_request_host: self.client_request_host.clone().unwrap_or_default(), - client_request_method: self.client_request_method.clone().unwrap_or_default(), - client_request_path: self.client_request_path.clone().unwrap_or_default(), - client_request_protocol: self.client_request_protocol.clone().unwrap_or_default(), - client_request_referer: self.client_request_referer.clone().unwrap_or_default(), - client_request_scheme: self.client_request_scheme.clone().unwrap_or_default(), - client_request_uri: self.client_request_uri.clone().unwrap_or_default(), - client_request_user_agent: self.client_request_user_agent.clone().unwrap_or_default(), - client_src_port: self.client_src_port.unwrap_or_default(), - client_x_requested_with: self.client_x_requested_with.clone().unwrap_or_default(), - guard_end_timestamp, - guard_response_body_bytes: self.guard_response_body_bytes.unwrap_or_default(), - guard_response_content_type: self - .guard_response_content_type - .clone() - .unwrap_or_default(), - guard_response_status: self.guard_response_status.unwrap_or_default(), - guard_start_timestamp, - service_ip, - service_response_duration_ms: self.service_response_duration_ms.unwrap_or_default(), - service_response_http_expires: self - .service_response_http_expires - .clone() - .unwrap_or_default(), - service_response_http_last_modified: self - .service_response_http_last_modified - .clone() - .unwrap_or_default(), - service_response_status: self.service_response_status.unwrap_or_default(), - service_actor_id: self - .service_actor_id - .map(|x| x.to_string()) - .unwrap_or_default(), - }; - - // Insert the event asynchronously - inserter.insert("db_guard_analytics", "http_requests", analytics_event)?; - - Ok(()) - } -} - -impl std::fmt::Debug for RequestContext { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("RequestContext").finish_non_exhaustive() - } -} diff --git a/engine/packages/guard-core/src/server.rs b/engine/packages/guard-core/src/server.rs index ef285cca7f..13e4a28efd 100644 --- a/engine/packages/guard-core/src/server.rs +++ b/engine/packages/guard-core/src/server.rs @@ -28,7 +28,6 @@ pub async fn run_server( cache_key_fn: CacheKeyFn, middleware_fn: MiddlewareFn, cert_resolver_fn: Option, - clickhouse_inserter: Option, ) -> Result<()> { // Set up HTTP server let http_addr: std::net::SocketAddr = (config.guard().host(), config.guard().port()).into(); @@ -38,7 +37,6 @@ pub async fn run_server( cache_key_fn.clone(), middleware_fn.clone(), crate::proxy_service::PortType::Http, - clickhouse_inserter.clone(), )); let http_listener = tokio::net::TcpListener::bind(http_addr).await?; @@ -53,7 +51,6 @@ pub async fn run_server( cache_key_fn.clone(), middleware_fn.clone(), crate::proxy_service::PortType::Https, - clickhouse_inserter.clone(), )); let listener = tokio::net::TcpListener::bind(https_addr).await?; diff --git a/engine/packages/guard/src/cache/mod.rs b/engine/packages/guard/src/cache/mod.rs index b4b0e48125..75a6514cfd 100644 --- a/engine/packages/guard/src/cache/mod.rs +++ b/engine/packages/guard/src/cache/mod.rs @@ -14,7 +14,7 @@ use crate::routing::X_RIVET_TARGET; /// Creates the main cache key function that handles all incoming requests #[tracing::instrument(skip_all)] -pub fn create_cache_key_function(_ctx: StandaloneCtx) -> CacheKeyFn { +pub fn create_cache_key_function() -> CacheKeyFn { Arc::new(move |hostname, path, method, _port_type, headers| { tracing::debug!("building cache key"); diff --git a/engine/packages/guard/src/lib.rs b/engine/packages/guard/src/lib.rs index aa4b8099d5..1a45bb3495 100644 --- a/engine/packages/guard/src/lib.rs +++ b/engine/packages/guard/src/lib.rs @@ -33,9 +33,9 @@ pub async fn start(config: rivet_config::Config, pools: rivet_pools::Pools) -> R shared_state.start().await?; // Create handlers - let routing_fn = routing::create_routing_function(ctx.clone(), shared_state.clone()); - let cache_key_fn = cache::create_cache_key_function(ctx.clone()); - let middleware_fn = middleware::create_middleware_function(ctx.clone()); + let routing_fn = routing::create_routing_function(&ctx, shared_state.clone()); + let cache_key_fn = cache::create_cache_key_function(); + let middleware_fn = middleware::create_middleware_function(); let cert_resolver = tls::create_cert_resolver(&ctx).await?; if let Some(_) = &cert_resolver { @@ -46,14 +46,12 @@ pub async fn start(config: rivet_config::Config, pools: rivet_pools::Pools) -> R // Start the server tracing::info!("starting proxy server"); - let clickhouse_inserter = ctx.clickhouse_inserter().ok(); rivet_guard_core::run_server( config, routing_fn, cache_key_fn, middleware_fn, cert_resolver, - clickhouse_inserter, ) .await } diff --git a/engine/packages/guard/src/middleware.rs b/engine/packages/guard/src/middleware.rs index 3f8567f7d3..c69ff69179 100644 --- a/engine/packages/guard/src/middleware.rs +++ b/engine/packages/guard/src/middleware.rs @@ -1,7 +1,5 @@ use std::sync::Arc; -use anyhow::*; -use gas::prelude::*; use rivet_guard_core::{ MiddlewareFn, proxy_service::{ @@ -11,10 +9,8 @@ use rivet_guard_core::{ }; /// Creates a middleware function that can use config and pools -pub fn create_middleware_function(ctx: StandaloneCtx) -> MiddlewareFn { - Arc::new(move |_actor_id: &Id, _headers: &hyper::HeaderMap| { - let _ctx = ctx.clone(); - +pub fn create_middleware_function() -> MiddlewareFn { + Arc::new(move |_headers: &hyper::HeaderMap| { Box::pin(async move { // In a real implementation, you would look up actor-specific middleware settings // For now, we'll just return a standard configuration @@ -30,8 +26,8 @@ pub fn create_middleware_function(ctx: StandaloneCtx) -> MiddlewareFn { amount: 2000, // 2000 concurrent requests }, retry: RetryConfig { - max_attempts: 7, - initial_interval: 150, + max_attempts: 7, // 7 retry attempts + initial_interval: 150, // 150ms initial interval }, timeout: TimeoutConfig { request_timeout: 30, // 30 seconds for requests diff --git a/engine/packages/guard/src/routing/api_public.rs b/engine/packages/guard/src/routing/api_public.rs index b1f701c4f3..fd9bd8baf9 100644 --- a/engine/packages/guard/src/routing/api_public.rs +++ b/engine/packages/guard/src/routing/api_public.rs @@ -6,8 +6,8 @@ use bytes::Bytes; use gas::prelude::*; use http_body_util::{BodyExt, Full}; use hyper::{Request, Response}; +use rivet_guard_core::CustomServeTrait; use rivet_guard_core::proxy_service::{ResponseBody, RoutingOutput}; -use rivet_guard_core::{CustomServeTrait, request_context::RequestContext}; use rivet_runner_protocol as protocol; use tower::Service; @@ -20,7 +20,8 @@ impl CustomServeTrait for ApiPublicService { async fn handle_request( &self, req: Request>, - _request_context: &mut RequestContext, + _ray_id: Id, + _req_id: Id, _request_id: protocol::RequestId, ) -> Result> { // Clone the router to get a mutable service diff --git a/engine/packages/guard/src/routing/mod.rs b/engine/packages/guard/src/routing/mod.rs index a370666c44..ce13621e53 100644 --- a/engine/packages/guard/src/routing/mod.rs +++ b/engine/packages/guard/src/routing/mod.rs @@ -28,13 +28,16 @@ pub struct ActorPathInfo { /// Creates the main routing function that handles all incoming requests #[tracing::instrument(skip_all)] -pub fn create_routing_function(ctx: StandaloneCtx, shared_state: SharedState) -> RoutingFn { +pub fn create_routing_function(ctx: &StandaloneCtx, shared_state: SharedState) -> RoutingFn { + let ctx = ctx.clone(); Arc::new( move |hostname: &str, path: &str, + ray_id: Id, + req_id: Id, port_type: rivet_guard_core::proxy_service::PortType, headers: &hyper::HeaderMap| { - let ctx = ctx.clone(); + let ctx = ctx.with_ray(ray_id, req_id).unwrap(); let shared_state = shared_state.clone(); Box::pin( diff --git a/engine/packages/guard/src/routing/pegboard_gateway.rs b/engine/packages/guard/src/routing/pegboard_gateway.rs index 33d2f16bf0..99fa990858 100644 --- a/engine/packages/guard/src/routing/pegboard_gateway.rs +++ b/engine/packages/guard/src/routing/pegboard_gateway.rs @@ -175,7 +175,6 @@ async fn route_request_inner( return Ok(Some(RoutingOutput::Route(RouteConfig { targets: vec![RouteTarget { - actor_id: Some(actor_id), host: peer_dc .proxy_url_host() .context("bad peer dc proxy url host")? diff --git a/engine/packages/pegboard-gateway/src/lib.rs b/engine/packages/pegboard-gateway/src/lib.rs index d0a5a14aa0..e3ad0ccc93 100644 --- a/engine/packages/pegboard-gateway/src/lib.rs +++ b/engine/packages/pegboard-gateway/src/lib.rs @@ -11,7 +11,6 @@ use rivet_guard_core::{ custom_serve::{CustomServeTrait, HibernationResult}, errors::{ServiceUnavailable, WebSocketServiceUnavailable}, proxy_service::{ResponseBody, is_ws_hibernate}, - request_context::RequestContext, websocket_handle::WebSocketReceiver, }; use rivet_runner_protocol::{self as protocol, PROTOCOL_MK1_VERSION}; @@ -85,9 +84,12 @@ impl CustomServeTrait for PegboardGateway { async fn handle_request( &self, req: Request>, - _request_context: &mut RequestContext, + ray_id: Id, + req_id: Id, request_id: protocol::RequestId, ) -> Result> { + let ctx = self.ctx.with_ray(ray_id, req_id)?; + // Use the actor ID from the gateway instance let actor_id = self.actor_id.to_string(); @@ -154,11 +156,10 @@ impl CustomServeTrait for PegboardGateway { .context("failed to read body")? .to_bytes(); - let udb = self.ctx.udb()?; + let udb = ctx.udb()?; let runner_id = self.runner_id; let (mut stopped_sub, runner_protocol_version) = tokio::try_join!( - self.ctx - .subscribe::(("actor_id", self.actor_id)), + ctx.subscribe::(("actor_id", self.actor_id)), // Read runner protocol version udb.run(|tx| async move { let tx = tx.with_subspace(pegboard::keys::subspace()); @@ -289,10 +290,13 @@ impl CustomServeTrait for PegboardGateway { client_ws: WebSocketHandle, headers: &hyper::HeaderMap, _path: &str, - _request_context: &mut RequestContext, + ray_id: Id, + req_id: Id, request_id: protocol::RequestId, after_hibernation: bool, ) -> Result> { + let ctx = self.ctx.with_ray(ray_id, req_id)?; + // Extract headers let mut request_headers = HashableMap::new(); for (name, value) in headers { @@ -301,11 +305,10 @@ impl CustomServeTrait for PegboardGateway { } } - let udb = self.ctx.udb()?; + let udb = ctx.udb()?; let runner_id = self.runner_id; let (mut stopped_sub, runner_protocol_version) = tokio::try_join!( - self.ctx - .subscribe::(("actor_id", self.actor_id)), + ctx.subscribe::(("actor_id", self.actor_id)), // Read runner protocol version udb.run(|tx| async move { let tx = tx.with_subspace(pegboard::keys::subspace()); @@ -452,7 +455,7 @@ impl CustomServeTrait for PegboardGateway { let keepalive = if can_hibernate { Some(tokio::spawn(keepalive_task::task( self.shared_state.clone(), - self.ctx.clone(), + ctx.clone(), self.actor_id, self.shared_state.gateway_id(), request_id, @@ -603,8 +606,12 @@ impl CustomServeTrait for PegboardGateway { async fn handle_websocket_hibernation( &self, client_ws: WebSocketHandle, + ray_id: Id, + req_id: Id, request_id: protocol::RequestId, ) -> Result { + let ctx = self.ctx.with_ray(ray_id, req_id)?; + // Immediately rewake if we have pending messages if self .shared_state @@ -620,7 +627,7 @@ impl CustomServeTrait for PegboardGateway { let (keepalive_abort_tx, keepalive_abort_rx) = watch::channel(()); let keepalive_handle = tokio::spawn(keepalive_task::task( self.shared_state.clone(), - self.ctx.clone(), + ctx.clone(), self.actor_id, self.shared_state.gateway_id(), request_id, @@ -636,13 +643,12 @@ impl CustomServeTrait for PegboardGateway { Ok(HibernationResult::Continue) => {} Ok(HibernationResult::Close) | Err(_) => { // No longer an active hibernating request, delete entry - self.ctx - .op(pegboard::ops::actor::hibernating_request::delete::Input { - actor_id: self.actor_id, - gateway_id: self.shared_state.gateway_id(), - request_id, - }) - .await?; + ctx.op(pegboard::ops::actor::hibernating_request::delete::Input { + actor_id: self.actor_id, + gateway_id: self.shared_state.gateway_id(), + request_id, + }) + .await?; } } diff --git a/engine/packages/pegboard-runner/src/lib.rs b/engine/packages/pegboard-runner/src/lib.rs index 8a550fcfce..2d4748cfc0 100644 --- a/engine/packages/pegboard-runner/src/lib.rs +++ b/engine/packages/pegboard-runner/src/lib.rs @@ -7,7 +7,6 @@ use hyper::{Response, StatusCode}; use pegboard::ops::runner::update_alloc_idx::Action; use rivet_guard_core::{ WebSocketHandle, custom_serve::CustomServeTrait, proxy_service::ResponseBody, - request_context::RequestContext, }; use rivet_runner_protocol as protocol; use tokio::sync::watch; @@ -46,7 +45,8 @@ impl CustomServeTrait for PegboardRunnerWsCustomServe { async fn handle_request( &self, _req: hyper::Request>, - _request_context: &mut RequestContext, + _ray_id: Id, + _req_id: Id, _request_id: protocol::RequestId, ) -> Result> { // Pegboard runner ws doesn't handle regular HTTP requests @@ -66,12 +66,15 @@ impl CustomServeTrait for PegboardRunnerWsCustomServe { ws_handle: WebSocketHandle, _headers: &hyper::HeaderMap, path: &str, - _request_context: &mut RequestContext, - _unique_request_id: protocol::RequestId, + ray_id: Id, + req_id: Id, + _request_id: protocol::RequestId, _after_hibernation: bool, ) -> Result> { + let ctx = self.ctx.with_ray(ray_id, req_id)?; + // Get UPS - let ups = self.ctx.ups().context("failed to get UPS instance")?; + let ups = ctx.ups().context("failed to get UPS instance")?; // Parse URL to extract parameters let url = url::Url::parse(&format!("ws://placeholder/{path}")) @@ -82,7 +85,7 @@ impl CustomServeTrait for PegboardRunnerWsCustomServe { tracing::debug!(?path, "tunnel ws connection established"); // Create connection - let conn = conn::init_conn(&self.ctx, ws_handle.clone(), url_data) + let conn = conn::init_conn(&ctx, ws_handle.clone(), url_data) .await .context("failed to initialize runner connection")?; @@ -145,7 +148,7 @@ impl CustomServeTrait for PegboardRunnerWsCustomServe { let (ping_abort_tx, ping_abort_rx) = watch::channel(()); let tunnel_to_ws = tokio::spawn(tunnel_to_ws_task::task( - self.ctx.clone(), + ctx.clone(), conn.clone(), sub, eviction_sub, @@ -153,7 +156,7 @@ impl CustomServeTrait for PegboardRunnerWsCustomServe { )); let ws_to_tunnel = tokio::spawn(ws_to_tunnel_task::task( - self.ctx.clone(), + ctx.clone(), conn.clone(), ws_handle.recv(), eviction_sub2, @@ -161,11 +164,7 @@ impl CustomServeTrait for PegboardRunnerWsCustomServe { )); // Update pings - let ping = tokio::spawn(ping_task::task( - self.ctx.clone(), - conn.clone(), - ping_abort_rx, - )); + let ping = tokio::spawn(ping_task::task(ctx.clone(), conn.clone(), ping_abort_rx)); let tunnel_to_ws_abort_tx2 = tunnel_to_ws_abort_tx.clone(); let ws_to_tunnel_abort_tx2 = ws_to_tunnel_abort_tx.clone(); let ping_abort_tx2 = ping_abort_tx.clone();