diff --git a/Cargo.lock b/Cargo.lock index 42dc4f8cf8..06f0628d02 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3352,6 +3352,7 @@ dependencies = [ "gasoline", "namespace", "pegboard", + "reqwest", "reqwest-eventsource", "rivet-config", "rivet-runner-protocol", @@ -4370,6 +4371,7 @@ dependencies = [ "pegboard-gateway", "pegboard-runner", "regex", + "rivet-api-builder", "rivet-api-public", "rivet-cache", "rivet-config", diff --git a/out/openapi.json b/out/openapi.json index 176eb09c96..2b39a4fb33 100644 --- a/out/openapi.json +++ b/out/openapi.json @@ -1114,6 +1114,7 @@ "type": "object", "required": [ "url", + "headers", "request_lifespan", "slots_per_runner", "min_runners", @@ -1121,6 +1122,15 @@ "runners_margin" ], "properties": { + "headers": { + "type": "object", + "additionalProperties": { + "type": "string" + }, + "propertyNames": { + "type": "string" + } + }, "max_runners": { "type": "integer", "format": "int32", @@ -1199,6 +1209,7 @@ "type": "object", "required": [ "url", + "headers", "request_lifespan", "slots_per_runner", "min_runners", @@ -1206,6 +1217,15 @@ "runners_margin" ], "properties": { + "headers": { + "type": "object", + "additionalProperties": { + "type": "string" + }, + "propertyNames": { + "type": "string" + } + }, "max_runners": { "type": "integer", "format": "int32", diff --git a/packages/common/api-types/src/runners/list.rs b/packages/common/api-types/src/runners/list.rs index 1dc80a3c63..0b95489915 100644 --- a/packages/common/api-types/src/runners/list.rs +++ b/packages/common/api-types/src/runners/list.rs @@ -1,4 +1,3 @@ -use rivet_util::Id; use serde::{Deserialize, Serialize}; use utoipa::{IntoParams, ToSchema}; diff --git a/packages/common/config/src/config/auth.rs b/packages/common/config/src/config/auth.rs new file mode 100644 index 0000000000..931f493cdd --- /dev/null +++ b/packages/common/config/src/config/auth.rs @@ -0,0 +1,16 @@ +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] +#[serde(deny_unknown_fields)] +pub struct Auth { + pub admin_token: String, +} + +impl Default for Auth { + fn default() -> Self { + Auth { + admin_token: "admin".to_string(), + } + } +} diff --git a/packages/common/config/src/config/mod.rs b/packages/common/config/src/config/mod.rs index 71dc02792c..667e05f2da 100644 --- a/packages/common/config/src/config/mod.rs +++ b/packages/common/config/src/config/mod.rs @@ -5,6 +5,7 @@ use std::sync::LazyLock; pub mod api_peer; pub mod api_public; +pub mod auth; pub mod cache; pub mod clickhouse; pub mod db; @@ -17,6 +18,7 @@ pub mod vector; pub use api_peer::*; pub use api_public::*; +pub use auth::*; pub use cache::*; pub use clickhouse::*; pub use db::Database; @@ -58,6 +60,9 @@ pub use vector::*; #[derive(Debug, Serialize, Deserialize, Clone, JsonSchema)] #[serde(rename_all = "snake_case", deny_unknown_fields)] pub struct Root { + #[serde(default)] + pub auth: Option, + #[serde(default)] pub guard: Option, @@ -95,6 +100,7 @@ pub struct Root { impl Default for Root { fn default() -> Self { Root { + auth: None, guard: None, api_public: None, api_peer: None, diff --git a/packages/common/types/src/namespaces.rs b/packages/common/types/src/namespaces.rs index 70c2dfc8e4..5dd220e655 100644 --- a/packages/common/types/src/namespaces.rs +++ b/packages/common/types/src/namespaces.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + use gas::prelude::*; use utoipa::ToSchema; @@ -9,11 +11,12 @@ pub struct Namespace { pub create_ts: i64, } -#[derive(Debug, Clone, Serialize, Deserialize, Hash, ToSchema)] +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)] #[serde(rename_all = "snake_case")] pub enum RunnerConfig { Serverless { url: String, + headers: HashMap, /// Seconds. request_lifespan: u32, slots_per_runner: u32, @@ -28,6 +31,7 @@ impl From for rivet_data::generated::namespace_runner_config_v1::D match value { RunnerConfig::Serverless { url, + headers, request_lifespan, slots_per_runner, min_runners, @@ -36,6 +40,7 @@ impl From for rivet_data::generated::namespace_runner_config_v1::D } => rivet_data::generated::namespace_runner_config_v1::Data::Serverless( rivet_data::generated::namespace_runner_config_v1::Serverless { url, + headers: headers.into(), request_lifespan, slots_per_runner, min_runners, @@ -53,6 +58,7 @@ impl From for RunnerCon rivet_data::generated::namespace_runner_config_v1::Data::Serverless(o) => { RunnerConfig::Serverless { url: o.url, + headers: o.headers.into(), request_lifespan: o.request_lifespan, slots_per_runner: o.slots_per_runner, min_runners: o.min_runners, diff --git a/packages/core/api-peer/src/runner_configs.rs b/packages/core/api-peer/src/runner_configs.rs index 838792b5ee..31f38a0fd0 100644 --- a/packages/core/api-peer/src/runner_configs.rs +++ b/packages/core/api-peer/src/runner_configs.rs @@ -4,7 +4,6 @@ use anyhow::Result; use namespace::utils::runner_config_variant; use rivet_api_builder::ApiCtx; use rivet_api_types::pagination::Pagination; -use rivet_util::Id; use serde::{Deserialize, Serialize}; use utoipa::{IntoParams, ToSchema}; diff --git a/packages/core/api-public/src/actors/create.rs b/packages/core/api-public/src/actors/create.rs index 5ac97bc45d..3d13b0f631 100644 --- a/packages/core/api-public/src/actors/create.rs +++ b/packages/core/api-public/src/actors/create.rs @@ -4,13 +4,14 @@ use axum::{ http::HeaderMap, response::{IntoResponse, Json, Response}, }; -use rivet_api_builder::{ApiCtx, ApiError}; +use rivet_api_builder::ApiError; use rivet_api_types::actors::create::{CreateRequest, CreateResponse}; use rivet_api_util::request_remote_datacenter; -use rivet_types::actors::CrashPolicy; use serde::{Deserialize, Serialize}; use utoipa::IntoParams; +use crate::ctx::ApiCtx; + #[derive(Debug, Serialize, Deserialize, IntoParams)] #[serde(deny_unknown_fields)] #[into_params(parameter_in = Query)] @@ -63,6 +64,8 @@ async fn create_inner( query: CreateQuery, body: CreateRequest, ) -> Result { + ctx.skip_auth(); + // Determine which datacenter to create the actor in let target_dc_label = if let Some(dc_name) = &query.datacenter { ctx.config() @@ -78,7 +81,7 @@ async fn create_inner( }; if target_dc_label == ctx.config().dc_label() { - rivet_api_peer::actors::create::create(ctx, (), query, body).await + rivet_api_peer::actors::create::create(ctx.into(), (), query, body).await } else { request_remote_datacenter::( ctx.config(), diff --git a/packages/core/api-public/src/actors/delete.rs b/packages/core/api-public/src/actors/delete.rs index 12417a9165..922a14caf0 100644 --- a/packages/core/api-public/src/actors/delete.rs +++ b/packages/core/api-public/src/actors/delete.rs @@ -4,12 +4,14 @@ use axum::{ http::HeaderMap, response::{IntoResponse, Json, Response}, }; -use rivet_api_builder::{ApiCtx, ApiError}; +use rivet_api_builder::ApiError; use rivet_api_util::request_remote_datacenter_raw; use rivet_util::Id; use serde::{Deserialize, Serialize}; use utoipa::{IntoParams, ToSchema}; +use crate::ctx::ApiCtx; + #[derive(Debug, Deserialize, Serialize, IntoParams)] #[serde(deny_unknown_fields)] #[into_params(parameter_in = Query)] @@ -62,6 +64,8 @@ async fn delete_inner( path: DeletePath, query: DeleteQuery, ) -> Result { + ctx.auth().await?; + if path.actor_id.label() == ctx.config().dc_label() { let peer_path = rivet_api_peer::actors::delete::DeletePath { actor_id: path.actor_id, @@ -69,7 +73,7 @@ async fn delete_inner( let peer_query = rivet_api_peer::actors::delete::DeleteQuery { namespace: query.namespace, }; - let res = rivet_api_peer::actors::delete::delete(ctx, peer_path, peer_query).await?; + let res = rivet_api_peer::actors::delete::delete(ctx.into(), peer_path, peer_query).await?; Ok(Json(res).into_response()) } else { diff --git a/packages/core/api-public/src/actors/get_or_create.rs b/packages/core/api-public/src/actors/get_or_create.rs index 290e69a019..09690fbe8a 100644 --- a/packages/core/api-public/src/actors/get_or_create.rs +++ b/packages/core/api-public/src/actors/get_or_create.rs @@ -4,13 +4,14 @@ use axum::{ http::HeaderMap, response::{IntoResponse, Json, Response}, }; -use rivet_api_builder::{ApiCtx, ApiError}; +use rivet_api_builder::ApiError; use rivet_types::actors::CrashPolicy; use rivet_util::Id; use serde::{Deserialize, Serialize}; use utoipa::{IntoParams, ToSchema}; use crate::actors::utils; +use crate::ctx::ApiCtx; use crate::errors; #[derive(Debug, Deserialize, IntoParams)] @@ -91,6 +92,8 @@ async fn get_or_create_inner( query: GetOrCreateQuery, body: GetOrCreateRequest, ) -> Result { + ctx.skip_auth(); + // Resolve namespace let namespace = ctx .op(namespace::ops::resolve_for_name_global::Input { diff --git a/packages/core/api-public/src/actors/list.rs b/packages/core/api-public/src/actors/list.rs index e597bbc959..f35f2f696d 100644 --- a/packages/core/api-public/src/actors/list.rs +++ b/packages/core/api-public/src/actors/list.rs @@ -4,13 +4,13 @@ use axum::{ http::HeaderMap, response::{IntoResponse, Json, Response}, }; -use rivet_api_builder::{ApiCtx, ApiError}; +use rivet_api_builder::ApiError; use rivet_api_types::pagination::Pagination; use rivet_api_util::fanout_to_datacenters; use serde::{Deserialize, Serialize}; use utoipa::{IntoParams, ToSchema}; -use crate::{actors::utils::fetch_actors_by_ids, errors}; +use crate::{actors::utils::fetch_actors_by_ids, ctx::ApiCtx, errors}; #[derive(Debug, Serialize, Deserialize, Clone, IntoParams)] #[serde(deny_unknown_fields)] @@ -76,6 +76,8 @@ pub async fn list( } async fn list_inner(ctx: ApiCtx, headers: HeaderMap, query: ListQuery) -> Result { + ctx.skip_auth(); + // Parse query let actor_ids = query.actor_ids.as_ref().map(|x| { x.split(',') @@ -221,7 +223,7 @@ async fn list_inner(ctx: ApiCtx, headers: HeaderMap, query: ListQuery) -> Result _, Vec, >( - ctx, + ctx.into(), headers, "/actors", peer_query, diff --git a/packages/core/api-public/src/actors/list_names.rs b/packages/core/api-public/src/actors/list_names.rs index 4480ceafa1..15342d9e82 100644 --- a/packages/core/api-public/src/actors/list_names.rs +++ b/packages/core/api-public/src/actors/list_names.rs @@ -4,11 +4,13 @@ use axum::{ http::HeaderMap, response::{IntoResponse, Json, Response}, }; -use rivet_api_builder::{ApiCtx, ApiError}; +use rivet_api_builder::ApiError; use rivet_api_types::{actors::list_names::*, pagination::Pagination}; use rivet_api_util::fanout_to_datacenters; use rivet_types::actors::ActorName; +use crate::ctx::ApiCtx; + /// ## Datacenter Round Trips /// /// 2 round trips: @@ -39,6 +41,8 @@ async fn list_names_inner( headers: HeaderMap, query: ListNamesQuery, ) -> Result { + ctx.auth().await?; + // Prepare peer query for local handler let peer_query = ListNamesQuery { namespace: query.namespace.clone(), @@ -49,7 +53,7 @@ async fn list_names_inner( // Fanout to all datacenters let mut all_names = fanout_to_datacenters::>( - ctx, + ctx.into(), headers, "/actors/names", peer_query, diff --git a/packages/core/api-public/src/actors/utils.rs b/packages/core/api-public/src/actors/utils.rs index e98bd4815a..840a83d6c6 100644 --- a/packages/core/api-public/src/actors/utils.rs +++ b/packages/core/api-public/src/actors/utils.rs @@ -23,7 +23,7 @@ pub async fn fetch_actor_by_id( if actor_id.label() == ctx.config().dc_label() { // Local datacenter - use peer API directly - let res = rivet_api_peer::actors::list::list(ctx.clone(), (), list_query).await?; + let res = rivet_api_peer::actors::list::list(ctx.clone().into(), (), list_query).await?; let actor = res .actors .into_iter() @@ -105,7 +105,7 @@ pub async fn fetch_actors_by_ids( if dc_label == ctx.config().dc_label() { // Local datacenter - use peer API directly - let res = rivet_api_peer::actors::list::list(ctx, (), peer_query).await?; + let res = rivet_api_peer::actors::list::list(ctx.into(), (), peer_query).await?; Ok::, anyhow::Error>(res.actors) } else { // Remote datacenter - make HTTP request diff --git a/packages/core/api-public/src/ctx.rs b/packages/core/api-public/src/ctx.rs new file mode 100644 index 0000000000..6a2439802c --- /dev/null +++ b/packages/core/api-public/src/ctx.rs @@ -0,0 +1,70 @@ +use std::{ + ops::Deref, + sync::{ + Arc, + atomic::{AtomicBool, Ordering}, + }, +}; + +use anyhow::Result; + +#[derive(Clone)] +pub struct ApiCtx { + inner: rivet_api_builder::ApiCtx, + token: Option, + authentication_handled: Arc, +} + +impl ApiCtx { + pub fn new(inner: rivet_api_builder::ApiCtx, token: Option) -> Self { + ApiCtx { + inner, + token, + authentication_handled: Arc::new(AtomicBool::new(false)), + } + } + + pub async fn auth(&self) -> Result<()> { + let Some(auth) = &self.config().auth else { + return Ok(()); + }; + + self.authentication_handled.store(true, Ordering::Relaxed); + + if self.token.as_ref() == Some(&auth.admin_token) { + Ok(()) + } else { + Err(rivet_api_builder::ApiForbidden.build()) + } + } + + pub fn skip_auth(&self) { + self.authentication_handled.store(true, Ordering::Relaxed); + } + + pub fn is_auth_handled(&self) -> bool { + if self.config().auth.is_none() { + return true; + } + + self.authentication_handled.load(Ordering::Relaxed) + } + + pub fn token(&self) -> Option<&str> { + self.token.as_deref() + } +} + +impl Deref for ApiCtx { + type Target = rivet_api_builder::ApiCtx; + + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +impl From for rivet_api_builder::ApiCtx { + fn from(value: ApiCtx) -> rivet_api_builder::ApiCtx { + value.inner + } +} diff --git a/packages/core/api-public/src/datacenters.rs b/packages/core/api-public/src/datacenters.rs index f3a507f826..d8a1f9544b 100644 --- a/packages/core/api-public/src/datacenters.rs +++ b/packages/core/api-public/src/datacenters.rs @@ -1,8 +1,14 @@ use anyhow::Result; -use rivet_api_builder::ApiCtx; +use axum::{ + extract::Extension, + response::{IntoResponse, Json, Response}, +}; +use rivet_api_builder::ApiError; use rivet_api_types::{datacenters::list::*, pagination::Pagination}; use rivet_types::datacenters::Datacenter; +use crate::ctx::ApiCtx; + #[utoipa::path( get, operation_id = "datacenters_list", @@ -11,7 +17,16 @@ use rivet_types::datacenters::Datacenter; (status = 200, body = ListResponse), ), )] -pub async fn list(ctx: ApiCtx, _path: (), _query: ()) -> Result { +pub async fn list(Extension(ctx): Extension) -> Response { + match list_inner(ctx).await { + Ok(response) => Json(response).into_response(), + Err(err) => ApiError::from(err).into_response(), + } +} + +async fn list_inner(ctx: ApiCtx) -> Result { + ctx.auth().await?; + Ok(ListResponse { datacenters: ctx .config() diff --git a/packages/core/api-public/src/lib.rs b/packages/core/api-public/src/lib.rs index 3d8a55bfb6..1ca95a4b40 100644 --- a/packages/core/api-public/src/lib.rs +++ b/packages/core/api-public/src/lib.rs @@ -1,4 +1,5 @@ pub mod actors; +pub mod ctx; pub mod datacenters; mod errors; pub mod namespaces; diff --git a/packages/core/api-public/src/namespaces.rs b/packages/core/api-public/src/namespaces.rs index 62d8276724..367621e4e2 100644 --- a/packages/core/api-public/src/namespaces.rs +++ b/packages/core/api-public/src/namespaces.rs @@ -1,14 +1,16 @@ use anyhow::Result; use axum::{ - extract::{Extension, Path, Query}, + extract::{Extension, Query}, http::HeaderMap, response::{IntoResponse, Json, Response}, }; -use rivet_api_builder::{ApiCtx, ApiError}; +use rivet_api_builder::ApiError; use rivet_api_peer::namespaces::*; use rivet_api_types::namespaces::list::*; use rivet_api_util::request_remote_datacenter; +use crate::ctx::ApiCtx; + #[utoipa::path( get, operation_id = "namespaces_list", @@ -30,8 +32,10 @@ pub async fn list( } async fn list_inner(ctx: ApiCtx, headers: HeaderMap, query: ListQuery) -> Result { + ctx.auth().await?; + if ctx.config().is_leader() { - rivet_api_peer::namespaces::list(ctx, (), query).await + rivet_api_peer::namespaces::list(ctx.into(), (), query).await } else { let leader_dc = ctx.config().leader_dc()?; request_remote_datacenter::( @@ -72,8 +76,10 @@ async fn create_inner( headers: HeaderMap, body: CreateRequest, ) -> Result { + ctx.auth().await?; + if ctx.config().is_leader() { - rivet_api_peer::namespaces::create(ctx, (), (), body).await + rivet_api_peer::namespaces::create(ctx.into(), (), (), body).await } else { let leader_dc = ctx.config().leader_dc()?; request_remote_datacenter::( diff --git a/packages/core/api-public/src/router.rs b/packages/core/api-public/src/router.rs index 85c45af3d2..a59de53cb7 100644 --- a/packages/core/api-public/src/router.rs +++ b/packages/core/api-public/src/router.rs @@ -1,8 +1,13 @@ -use axum::response::Redirect; -use rivet_api_builder::{create_router, wrappers::get}; +use axum::{ + extract::Request, + middleware::{self, Next}, + response::{Redirect, Response}, +}; +use reqwest::header::{AUTHORIZATION, HeaderMap}; +use rivet_api_builder::create_router; use utoipa::OpenApi; -use crate::{actors, datacenters, namespaces, runner_configs, runners, ui}; +use crate::{actors, ctx, datacenters, namespaces, runner_configs, runners, ui}; #[derive(OpenApi)] #[openapi(paths( @@ -66,11 +71,50 @@ pub async fn router( .route("/runners", axum::routing::get(runners::list)) .route("/runners/names", axum::routing::get(runners::list_names)) // MARK: Datacenters - .route("/datacenters", get(datacenters::list)) + .route("/datacenters", axum::routing::get(datacenters::list)) // MARK: UI .route("/ui", axum::routing::get(ui::serve_index)) .route("/ui/", axum::routing::get(ui::serve_index)) .route("/ui/{*path}", axum::routing::get(ui::serve_ui)) + // MARK: Middleware (must go after all routes) + .layer(middleware::from_fn(auth_middleware)) }) .await } + +/// Middleware to wrap ApiCtx with auth handling capabilities and to throw an error if auth was not explicitly +// handled in an endpoint +async fn auth_middleware( + headers: HeaderMap, + mut req: Request, + next: Next, +) -> std::result::Result { + let ctx = req + .extensions() + .get::() + .ok_or_else(|| "ctx should exist".to_string())?; + + // Extract token + let token = headers + .get(AUTHORIZATION) + .and_then(|x| x.to_str().ok().and_then(|x| x.strip_prefix("Bearer "))) + .map(|x| x.to_string()); + + // Insert the new ApiCtx into request extensions + let ctx = ctx::ApiCtx::new(ctx.clone(), token); + req.extensions_mut().insert(ctx.clone()); + + let path = req.uri().path().to_string(); + + // Run endpoint + let res = next.run(req).await; + + // Verify auth was handled + if !ctx.is_auth_handled() { + return Err(format!( + "developer error: must explicitly handle auth in all endpoints (path: {path})" + )); + } + + Ok(res) +} diff --git a/packages/core/api-public/src/runner_configs.rs b/packages/core/api-public/src/runner_configs.rs index 1f51b6d9d4..d83da137f4 100644 --- a/packages/core/api-public/src/runner_configs.rs +++ b/packages/core/api-public/src/runner_configs.rs @@ -4,12 +4,13 @@ use axum::{ http::HeaderMap, response::{IntoResponse, Json, Response}, }; -use rivet_api_builder::{ApiCtx, ApiError}; -use rivet_util::Id; +use rivet_api_builder::ApiError; use rivet_api_peer::runner_configs::*; use rivet_api_util::request_remote_datacenter; +use crate::ctx::ApiCtx; + #[utoipa::path( get, operation_id = "runner_configs_list", @@ -39,8 +40,10 @@ async fn list_inner( path: ListPath, query: ListQuery, ) -> Result { + ctx.auth().await?; + if ctx.config().is_leader() { - rivet_api_peer::runner_configs::list(ctx, path, query).await + rivet_api_peer::runner_configs::list(ctx.into(), path, query).await } else { let leader_dc = ctx.config().leader_dc()?; request_remote_datacenter::( @@ -89,8 +92,10 @@ async fn upsert_inner( query: UpsertQuery, body: UpsertRequest, ) -> Result { + ctx.auth().await?; + if ctx.config().is_leader() { - rivet_api_peer::runner_configs::upsert(ctx, path, query, body).await + rivet_api_peer::runner_configs::upsert(ctx.into(), path, query, body).await } else { let leader_dc = ctx.config().leader_dc()?; request_remote_datacenter::( @@ -136,8 +141,10 @@ async fn delete_inner( path: DeletePath, query: DeleteQuery, ) -> Result { + ctx.auth().await?; + if ctx.config().is_leader() { - rivet_api_peer::runner_configs::delete(ctx, path, query).await + rivet_api_peer::runner_configs::delete(ctx.into(), path, query).await } else { let leader_dc = ctx.config().leader_dc()?; request_remote_datacenter::( diff --git a/packages/core/api-public/src/runners.rs b/packages/core/api-public/src/runners.rs index 364860fbe4..c5ef3b744a 100644 --- a/packages/core/api-public/src/runners.rs +++ b/packages/core/api-public/src/runners.rs @@ -4,12 +4,14 @@ use axum::{ http::HeaderMap, response::{IntoResponse, Json, Response}, }; -use rivet_api_builder::{ApiCtx, ApiError}; +use rivet_api_builder::ApiError; use rivet_api_types::{pagination::Pagination, runners::list::*}; use rivet_api_util::fanout_to_datacenters; use serde::{Deserialize, Serialize}; use utoipa::{IntoParams, ToSchema}; +use crate::ctx::ApiCtx; + #[utoipa::path( get, operation_id = "runners_list", @@ -31,10 +33,12 @@ pub async fn list( } async fn list_inner(ctx: ApiCtx, headers: HeaderMap, query: ListQuery) -> Result { + ctx.auth().await?; + // Fanout to all datacenters let mut runners = fanout_to_datacenters::>( - ctx, + ctx.into(), headers, "/runners", query.clone(), @@ -105,6 +109,8 @@ async fn list_names_inner( headers: HeaderMap, query: ListNamesQuery, ) -> Result { + ctx.auth().await?; + // Prepare peer query for local handler let peer_query = rivet_api_peer::runners::ListNamesQuery { namespace: query.namespace.clone(), @@ -121,7 +127,7 @@ async fn list_names_inner( _, Vec, >( - ctx, + ctx.into(), headers, "/runners/names", peer_query, diff --git a/packages/core/guard/server/Cargo.toml b/packages/core/guard/server/Cargo.toml index ffb414f6c9..e5832eb483 100644 --- a/packages/core/guard/server/Cargo.toml +++ b/packages/core/guard/server/Cargo.toml @@ -27,6 +27,7 @@ pegboard-gateway.workspace = true pegboard.workspace = true pegboard-runner.workspace = true regex.workspace = true +rivet-api-builder.workspace = true rivet-api-public.workspace = true rivet-cache.workspace = true rivet-config.workspace = true diff --git a/packages/core/guard/server/src/routing/api_public.rs b/packages/core/guard/server/src/routing/api_public.rs index a762aca276..43415122da 100644 --- a/packages/core/guard/server/src/routing/api_public.rs +++ b/packages/core/guard/server/src/routing/api_public.rs @@ -6,7 +6,6 @@ use bytes::Bytes; use gas::prelude::*; use http_body_util::{BodyExt, Full}; use hyper::{Request, Response}; -use hyper_tungstenite::HyperWebsocket; use rivet_guard_core::WebSocketHandle; use rivet_guard_core::proxy_service::{ResponseBody, RoutingOutput}; use rivet_guard_core::{CustomServeTrait, request_context::RequestContext}; diff --git a/packages/core/guard/server/src/routing/mod.rs b/packages/core/guard/server/src/routing/mod.rs index 02675ae8c9..ac61338bd1 100644 --- a/packages/core/guard/server/src/routing/mod.rs +++ b/packages/core/guard/server/src/routing/mod.rs @@ -12,6 +12,7 @@ pub mod pegboard_gateway; mod runner; pub(crate) const X_RIVET_TARGET: HeaderName = HeaderName::from_static("x-rivet-target"); +pub(crate) const X_RIVET_TOKEN: HeaderName = HeaderName::from_static("x-rivet-token"); pub(crate) const SEC_WEBSOCKET_PROTOCOL: HeaderName = HeaderName::from_static("sec-websocket-protocol"); pub(crate) const WS_PROTOCOL_TARGET: &str = "rivet_target."; @@ -62,7 +63,7 @@ pub fn create_routing_function(ctx: StandaloneCtx, shared_state: SharedState) -> // Read target if let Some(target) = target { if let Some(routing_output) = - runner::route_request(&ctx, target, host, path).await? + runner::route_request(&ctx, target, host, path, headers).await? { return Ok(routing_output); } diff --git a/packages/core/guard/server/src/routing/runner.rs b/packages/core/guard/server/src/routing/runner.rs index 4325229b22..be1a5a628d 100644 --- a/packages/core/guard/server/src/routing/runner.rs +++ b/packages/core/guard/server/src/routing/runner.rs @@ -1,20 +1,40 @@ use anyhow::*; use gas::prelude::*; -use rivet_guard_core::proxy_service::{RouteConfig, RouteTarget, RoutingOutput, RoutingTimeout}; +use rivet_guard_core::proxy_service::RoutingOutput; use std::sync::Arc; +use super::X_RIVET_TOKEN; + /// Route requests to the API service #[tracing::instrument(skip_all)] pub async fn route_request( ctx: &StandaloneCtx, target: &str, _host: &str, - path: &str, + _path: &str, + headers: &hyper::HeaderMap, ) -> Result> { if target != "runner" { return Ok(None); } + // Check auth (if enabled) + if let Some(auth) = &ctx.config().auth { + let token = headers + .get(X_RIVET_TOKEN) + .and_then(|x| x.to_str().ok()) + .ok_or_else(|| { + crate::errors::MissingHeader { + header: X_RIVET_TOKEN.to_string(), + } + .build() + })?; + + if token != auth.admin_token { + return Err(rivet_api_builder::ApiForbidden.build()); + } + } + let tunnel = pegboard_runner::PegboardRunnerWsCustomServe::new(ctx.clone()); Ok(Some(RoutingOutput::CustomServe(Arc::new(tunnel)))) } diff --git a/packages/core/pegboard-serverless/Cargo.toml b/packages/core/pegboard-serverless/Cargo.toml index 44b97cbea1..e93ccffd84 100644 --- a/packages/core/pegboard-serverless/Cargo.toml +++ b/packages/core/pegboard-serverless/Cargo.toml @@ -10,6 +10,7 @@ anyhow.workspace = true epoxy.workspace = true gas.workspace = true reqwest-eventsource.workspace = true +reqwest.workspace = true rivet-config.workspace = true rivet-runner-protocol.workspace = true rivet-types.workspace = true diff --git a/packages/core/pegboard-serverless/src/lib.rs b/packages/core/pegboard-serverless/src/lib.rs index baf1ff1671..ac54c2edd2 100644 --- a/packages/core/pegboard-serverless/src/lib.rs +++ b/packages/core/pegboard-serverless/src/lib.rs @@ -10,6 +10,7 @@ use anyhow::Result; use futures_util::{StreamExt, TryStreamExt}; use gas::prelude::*; use pegboard::keys; +use reqwest::header::{HeaderName, HeaderValue}; use reqwest_eventsource as sse; use rivet_runner_protocol as protocol; use rivet_types::namespaces::RunnerConfig; @@ -102,6 +103,7 @@ async fn tick( let RunnerConfig::Serverless { url, + headers, request_lifespan, slots_per_runner, min_runners, @@ -148,7 +150,8 @@ async fn tick( let starting_connections = std::iter::repeat_with(|| { spawn_connection( ctx.clone(), - url.to_string(), + url.clone(), + headers.clone(), Duration::from_secs(*request_lifespan as u64), ) }) @@ -173,6 +176,7 @@ async fn tick( fn spawn_connection( ctx: StandaloneCtx, url: String, + headers: HashMap, request_lifespan: Duration, ) -> OutboundConnection { let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); @@ -181,7 +185,7 @@ fn spawn_connection( let draining2 = draining.clone(); let handle = tokio::spawn(async move { if let Err(err) = - outbound_handler(&ctx, url, request_lifespan, shutdown_rx, draining2).await + outbound_handler(&ctx, url, headers, request_lifespan, shutdown_rx, draining2).await { tracing::error!(?err, "outbound req failed"); @@ -206,12 +210,23 @@ fn spawn_connection( async fn outbound_handler( ctx: &StandaloneCtx, url: String, + headers: HashMap, request_lifespan: Duration, shutdown_rx: oneshot::Receiver<()>, draining: Arc, ) -> Result<()> { let client = rivet_pools::reqwest::client_no_timeout().await?; - let mut es = sse::EventSource::new(client.get(url))?; + let headers = headers + .into_iter() + .flat_map(|(k, v)| { + // NOTE: This will filter out invalid headers without warning + Some(( + k.parse::().ok()?, + v.parse::().ok()?, + )) + }) + .collect(); + let mut es = sse::EventSource::new(client.get(url).headers(headers))?; let mut runner_id = None; let stream_handler = async { diff --git a/packages/services/epoxy/src/ops/kv/get_optimistic.rs b/packages/services/epoxy/src/ops/kv/get_optimistic.rs index 726b54656f..4b0ba9cf1d 100644 --- a/packages/services/epoxy/src/ops/kv/get_optimistic.rs +++ b/packages/services/epoxy/src/ops/kv/get_optimistic.rs @@ -36,7 +36,7 @@ pub struct Output { /// /// We cannot use quorum reads for the fanout read because of the constraints of Epaxos. #[operation] -pub async fn epoxy_get_optimistic(ctx: &OperationCtx, input: &Input) -> Result { +pub async fn epoxy_kv_get_optimistic(ctx: &OperationCtx, input: &Input) -> Result { // Try to read locally let kv_key = keys::keys::KvValueKey::new(input.key.clone()); let cache_key = keys::keys::KvOptimisticCacheKey::new(input.key.clone()); diff --git a/sdks/schemas/data/namespace.runner_config.v1.bare b/sdks/schemas/data/namespace.runner_config.v1.bare index a25e630013..c2854f06a1 100644 --- a/sdks/schemas/data/namespace.runner_config.v1.bare +++ b/sdks/schemas/data/namespace.runner_config.v1.bare @@ -1,5 +1,6 @@ type Serverless struct { url: str + headers: map request_lifespan: u32 slots_per_runner: u32 min_runners: u32