diff --git a/Cargo.lock b/Cargo.lock index 06f0628d02..28af3d7aad 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4048,6 +4048,7 @@ dependencies = [ "axum-test", "chrono", "gasoline", + "http 1.3.1", "hyper 0.14.32", "lazy_static", "opentelemetry", diff --git a/Cargo.toml b/Cargo.toml index 4fed60dcbf..bbd557c51a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,6 +32,7 @@ glob = "0.3.1" governor = "0.6" heck = "0.5" hex = "0.4" +http = "1.3.1" http-body = "1.0.0" http-body-util = "0.1.1" hyper-tls = "0.5.0" diff --git a/out/errors/api.bad_request.json b/out/errors/api.bad_request.json new file mode 100644 index 0000000000..54871df310 --- /dev/null +++ b/out/errors/api.bad_request.json @@ -0,0 +1,5 @@ +{ + "code": "bad_request", + "group": "api", + "message": "Request is invalid" +} \ No newline at end of file diff --git a/packages/common/api-builder/Cargo.toml b/packages/common/api-builder/Cargo.toml index 5655703737..b35bb4c8b8 100644 --- a/packages/common/api-builder/Cargo.toml +++ b/packages/common/api-builder/Cargo.toml @@ -11,6 +11,7 @@ axum.workspace = true axum-extra.workspace = true gas.workspace = true chrono.workspace = true +http.workspace = true hyper = { workspace = true, features = ["full"] } lazy_static.workspace = true opentelemetry.workspace = true diff --git a/packages/common/api-builder/src/errors.rs b/packages/common/api-builder/src/errors.rs index f8dd3e0993..4da070c1e2 100644 --- a/packages/common/api-builder/src/errors.rs +++ b/packages/common/api-builder/src/errors.rs @@ -1,4 +1,5 @@ use rivet_error::*; +use serde::Serialize; #[derive(RivetError)] #[error("api", "not_found", "The requested resource was not found")] @@ -19,3 +20,14 @@ pub struct ApiForbidden; #[derive(RivetError)] #[error("api", "internal_error", "An internal server error occurred")] pub struct ApiInternalError; + +#[derive(RivetError, Serialize)] +#[error( + "api", + "bad_request", + "Request is invalid", + "Request is invalid: {reason}" +)] +pub struct ApiBadRequest { + pub reason: String, +} diff --git a/packages/common/api-builder/src/extract.rs b/packages/common/api-builder/src/extract.rs new file mode 100644 index 0000000000..1275526ce6 --- /dev/null +++ b/packages/common/api-builder/src/extract.rs @@ -0,0 +1,109 @@ +use anyhow::anyhow; +use axum::{ + extract::{ + Request, + rejection::{ExtensionRejection, JsonRejection}, + {FromRequest, FromRequestParts}, + }, + response::IntoResponse, +}; +use axum_extra::extract::QueryRejection; +use http::request::Parts; +use serde::Serialize; + +use crate::{error_response::ApiError, errors::ApiBadRequest}; + +pub struct ExtractorError(ApiError); + +impl IntoResponse for ExtractorError { + fn into_response(self) -> axum::response::Response { + let mut res = self.0.into_response(); + + res.extensions_mut().insert(FailedExtraction); + + res + } +} + +#[derive(Clone, Copy)] +pub struct FailedExtraction; + +pub struct Json(pub T); + +impl FromRequest for Json +where + axum::extract::Json: FromRequest, + S: Send + Sync, +{ + type Rejection = ExtractorError; + + async fn from_request(req: Request, state: &S) -> Result { + axum::extract::Json::::from_request(req, state) + .await + .map(|json| Json(json.0)) + .map_err(|err| { + ExtractorError( + ApiBadRequest { + reason: err.body_text(), + } + .build() + .into(), + ) + }) + } +} + +impl IntoResponse for Json { + fn into_response(self) -> axum::response::Response { + let Self(value) = self; + axum::extract::Json(value).into_response() + } +} + +pub struct Query(pub T); + +impl FromRequestParts for Query +where + axum_extra::extract::Query: FromRequestParts, + S: Send + Sync, +{ + type Rejection = ExtractorError; + + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + let res = axum_extra::extract::Query::::from_request_parts(parts, state) + .await + .map(|query| Query(query.0)) + .map_err(|err| { + ExtractorError( + ApiBadRequest { + reason: err.body_text(), + } + .build() + .into(), + ) + }); + + res + } +} + +pub struct Extension(pub T); + +impl FromRequestParts for Extension +where + axum::extract::Extension: FromRequestParts, + S: Send + Sync, +{ + type Rejection = ExtractorError; + + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + axum::extract::Extension::::from_request_parts(parts, state) + .await + .map(|ext| Extension(ext.0)) + .map_err(|err| { + ExtractorError( + anyhow!("developer error: extension error: {}", err.body_text()).into(), + ) + }) + } +} diff --git a/packages/common/api-builder/src/lib.rs b/packages/common/api-builder/src/lib.rs index fb5e04543e..c9d16e1f0c 100644 --- a/packages/common/api-builder/src/lib.rs +++ b/packages/common/api-builder/src/lib.rs @@ -1,6 +1,7 @@ pub mod context; pub mod error_response; pub mod errors; +pub mod extract; pub mod global_context; pub mod metrics; pub mod middleware; diff --git a/packages/common/api-builder/src/wrappers.rs b/packages/common/api-builder/src/wrappers.rs index 926b6ece48..7ce98b3669 100644 --- a/packages/common/api-builder/src/wrappers.rs +++ b/packages/common/api-builder/src/wrappers.rs @@ -1,18 +1,21 @@ use anyhow::Result; use axum::{ body::Bytes, - extract::{Extension, Path}, - response::{IntoResponse, Json}, + extract::Path, + response::IntoResponse, routing::{ delete as axum_delete, get as axum_get, patch as axum_patch, post as axum_post, put as axum_put, }, }; -use axum_extra::extract::Query; use serde::{Serialize, de::DeserializeOwned}; use std::future::Future; -use crate::{context::ApiCtx, error_response::ApiError}; +use crate::{ + context::ApiCtx, + error_response::ApiError, + extract::{Extension, Json, Query}, +}; /// Macro to generate wrapper functions for HTTP methods macro_rules! create_method_wrapper { diff --git a/packages/common/config/src/config/mod.rs b/packages/common/config/src/config/mod.rs index 667e05f2da..b1d8dd82f5 100644 --- a/packages/common/config/src/config/mod.rs +++ b/packages/common/config/src/config/mod.rs @@ -100,7 +100,7 @@ pub struct Root { impl Default for Root { fn default() -> Self { Root { - auth: None, + auth: Some(Auth::default()), guard: None, api_public: None, api_peer: None, diff --git a/packages/core/api-public/src/actors/create.rs b/packages/core/api-public/src/actors/create.rs index 3d13b0f631..77f313a09e 100644 --- a/packages/core/api-public/src/actors/create.rs +++ b/packages/core/api-public/src/actors/create.rs @@ -1,10 +1,12 @@ use anyhow::Result; use axum::{ - extract::{Extension, Query}, http::HeaderMap, - response::{IntoResponse, Json, Response}, + response::{IntoResponse, Response}, +}; +use rivet_api_builder::{ + ApiError, + extract::{Extension, Json, Query}, }; -use rivet_api_builder::ApiError; use rivet_api_types::actors::create::{CreateRequest, CreateResponse}; use rivet_api_util::request_remote_datacenter; use serde::{Deserialize, Serialize}; diff --git a/packages/core/api-public/src/actors/delete.rs b/packages/core/api-public/src/actors/delete.rs index 3bba8d2c09..9cc24ab2ce 100644 --- a/packages/core/api-public/src/actors/delete.rs +++ b/packages/core/api-public/src/actors/delete.rs @@ -1,10 +1,13 @@ use anyhow::Result; use axum::{ - extract::{Extension, Path, Query}, + extract::Path, http::HeaderMap, - response::{IntoResponse, Json, Response}, + response::{IntoResponse, Response}, +}; +use rivet_api_builder::{ + ApiError, + extract::{Extension, Json, Query}, }; -use rivet_api_builder::ApiError; use rivet_api_util::request_remote_datacenter_raw; use rivet_util::Id; use serde::{Deserialize, Serialize}; diff --git a/packages/core/api-public/src/actors/get_or_create.rs b/packages/core/api-public/src/actors/get_or_create.rs index 09690fbe8a..88faa1b4a7 100644 --- a/packages/core/api-public/src/actors/get_or_create.rs +++ b/packages/core/api-public/src/actors/get_or_create.rs @@ -1,10 +1,12 @@ use anyhow::Result; use axum::{ - extract::{Extension, Query}, http::HeaderMap, - response::{IntoResponse, Json, Response}, + response::{IntoResponse, Response}, +}; +use rivet_api_builder::{ + ApiError, + extract::{Extension, Json, Query}, }; -use rivet_api_builder::ApiError; use rivet_types::actors::CrashPolicy; use rivet_util::Id; use serde::{Deserialize, Serialize}; diff --git a/packages/core/api-public/src/actors/list.rs b/packages/core/api-public/src/actors/list.rs index f35f2f696d..e6804fe54f 100644 --- a/packages/core/api-public/src/actors/list.rs +++ b/packages/core/api-public/src/actors/list.rs @@ -1,10 +1,12 @@ use anyhow::{Context, Result}; use axum::{ - extract::{Extension, Query}, http::HeaderMap, - response::{IntoResponse, Json, Response}, + response::{IntoResponse, Response}, +}; +use rivet_api_builder::{ + ApiError, + extract::{Extension, Json, Query}, }; -use rivet_api_builder::ApiError; use rivet_api_types::pagination::Pagination; use rivet_api_util::fanout_to_datacenters; use serde::{Deserialize, Serialize}; diff --git a/packages/core/api-public/src/actors/list_names.rs b/packages/core/api-public/src/actors/list_names.rs index 440acbdb87..db6b2f2713 100644 --- a/packages/core/api-public/src/actors/list_names.rs +++ b/packages/core/api-public/src/actors/list_names.rs @@ -1,10 +1,12 @@ use anyhow::Result; use axum::{ - extract::{Extension, Query}, http::HeaderMap, - response::{IntoResponse, Json, Response}, + response::{IntoResponse, Response}, +}; +use rivet_api_builder::{ + ApiError, + extract::{Extension, Json, Query}, }; -use rivet_api_builder::ApiError; use rivet_api_types::{actors::list_names::*, pagination::Pagination}; use rivet_api_util::fanout_to_datacenters; use rivet_types::actors::ActorName; diff --git a/packages/core/api-public/src/datacenters.rs b/packages/core/api-public/src/datacenters.rs index 5c70c78cec..00c578cb74 100644 --- a/packages/core/api-public/src/datacenters.rs +++ b/packages/core/api-public/src/datacenters.rs @@ -1,9 +1,6 @@ use anyhow::Result; -use axum::{ - extract::Extension, - response::{IntoResponse, Json, Response}, -}; -use rivet_api_builder::ApiError; +use axum::response::{IntoResponse, Json, Response}; +use rivet_api_builder::{ApiError, extract::Extension}; use rivet_api_types::{datacenters::list::*, pagination::Pagination}; use rivet_types::datacenters::Datacenter; diff --git a/packages/core/api-public/src/namespaces.rs b/packages/core/api-public/src/namespaces.rs index 7ade087ca7..a5b5f66353 100644 --- a/packages/core/api-public/src/namespaces.rs +++ b/packages/core/api-public/src/namespaces.rs @@ -1,10 +1,12 @@ use anyhow::Result; use axum::{ - extract::{Extension, Query}, http::HeaderMap, - response::{IntoResponse, Json, Response}, + response::{IntoResponse, Response}, +}; +use rivet_api_builder::{ + ApiError, + extract::{Extension, Json, Query}, }; -use rivet_api_builder::ApiError; use rivet_api_peer::namespaces::*; use rivet_api_types::namespaces::list::*; use rivet_api_util::request_remote_datacenter; diff --git a/packages/core/api-public/src/router.rs b/packages/core/api-public/src/router.rs index fb009de915..59b6c77962 100644 --- a/packages/core/api-public/src/router.rs +++ b/packages/core/api-public/src/router.rs @@ -4,7 +4,7 @@ use axum::{ response::{Redirect, Response}, }; use reqwest::header::{AUTHORIZATION, HeaderMap}; -use rivet_api_builder::create_router; +use rivet_api_builder::{create_router, extract::FailedExtraction}; use utoipa::OpenApi; use crate::{actors, ctx, datacenters, namespaces, runner_configs, runners, ui}; @@ -116,7 +116,12 @@ async fn auth_middleware( let res = next.run(req).await; // Verify auth was handled - if !ctx.is_auth_handled() { + if res.extensions().get::().is_none() + && path != "/" + && path != "/ui" + && !path.starts_with("/ui/") + && !ctx.is_auth_handled() + { return Err(format!( "developer error: must explicitly handle auth in all endpoints (path: {path})" )); diff --git a/packages/core/api-public/src/runner_configs.rs b/packages/core/api-public/src/runner_configs.rs index 6aaaff06c6..0e9433b390 100644 --- a/packages/core/api-public/src/runner_configs.rs +++ b/packages/core/api-public/src/runner_configs.rs @@ -1,10 +1,13 @@ use anyhow::Result; use axum::{ - extract::{Extension, Path, Query}, + extract::Path, http::HeaderMap, - response::{IntoResponse, Json, Response}, + response::{IntoResponse, Response}, +}; +use rivet_api_builder::{ + ApiError, + extract::{Extension, Json, Query}, }; -use rivet_api_builder::ApiError; use rivet_api_peer::runner_configs::*; use rivet_api_util::request_remote_datacenter; diff --git a/packages/core/api-public/src/runners.rs b/packages/core/api-public/src/runners.rs index a5b544f55a..b4a22454e2 100644 --- a/packages/core/api-public/src/runners.rs +++ b/packages/core/api-public/src/runners.rs @@ -1,10 +1,12 @@ use anyhow::Result; use axum::{ - extract::{Extension, Query}, http::HeaderMap, - response::{IntoResponse, Json, Response}, + response::{IntoResponse, Response}, +}; +use rivet_api_builder::{ + ApiError, + extract::{Extension, Json, Query}, }; -use rivet_api_builder::ApiError; use rivet_api_types::{pagination::Pagination, runners::list::*}; use rivet_api_util::fanout_to_datacenters; use serde::{Deserialize, Serialize};