Skip to content

Commit 92f2de2

Browse files
MasterPtatoNathanFlurry
authored andcommitted
fix(api): fix unhandled auth handling, extractor errors
1 parent 7b241d8 commit 92f2de2

File tree

19 files changed

+189
-36
lines changed

19 files changed

+189
-36
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ glob = "0.3.1"
3232
governor = "0.6"
3333
heck = "0.5"
3434
hex = "0.4"
35+
http = "1.3.1"
3536
http-body = "1.0.0"
3637
http-body-util = "0.1.1"
3738
hyper-tls = "0.5.0"

out/errors/api.bad_request.json

Lines changed: 5 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

packages/common/api-builder/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ axum.workspace = true
1111
axum-extra.workspace = true
1212
gas.workspace = true
1313
chrono.workspace = true
14+
http.workspace = true
1415
hyper = { workspace = true, features = ["full"] }
1516
lazy_static.workspace = true
1617
opentelemetry.workspace = true

packages/common/api-builder/src/errors.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use rivet_error::*;
2+
use serde::Serialize;
23

34
#[derive(RivetError)]
45
#[error("api", "not_found", "The requested resource was not found")]
@@ -19,3 +20,14 @@ pub struct ApiForbidden;
1920
#[derive(RivetError)]
2021
#[error("api", "internal_error", "An internal server error occurred")]
2122
pub struct ApiInternalError;
23+
24+
#[derive(RivetError, Serialize)]
25+
#[error(
26+
"api",
27+
"bad_request",
28+
"Request is invalid",
29+
"Request is invalid: {reason}"
30+
)]
31+
pub struct ApiBadRequest {
32+
pub reason: String,
33+
}
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
use anyhow::anyhow;
2+
use axum::{
3+
extract::{
4+
Request,
5+
rejection::{ExtensionRejection, JsonRejection},
6+
{FromRequest, FromRequestParts},
7+
},
8+
response::IntoResponse,
9+
};
10+
use axum_extra::extract::QueryRejection;
11+
use http::request::Parts;
12+
use serde::Serialize;
13+
14+
use crate::{error_response::ApiError, errors::ApiBadRequest};
15+
16+
pub struct ExtractorError(ApiError);
17+
18+
impl IntoResponse for ExtractorError {
19+
fn into_response(self) -> axum::response::Response {
20+
let mut res = self.0.into_response();
21+
22+
res.extensions_mut().insert(FailedExtraction);
23+
24+
res
25+
}
26+
}
27+
28+
#[derive(Clone, Copy)]
29+
pub struct FailedExtraction;
30+
31+
pub struct Json<T>(pub T);
32+
33+
impl<S, T> FromRequest<S> for Json<T>
34+
where
35+
axum::extract::Json<T>: FromRequest<S, Rejection = JsonRejection>,
36+
S: Send + Sync,
37+
{
38+
type Rejection = ExtractorError;
39+
40+
async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
41+
axum::extract::Json::<T>::from_request(req, state)
42+
.await
43+
.map(|json| Json(json.0))
44+
.map_err(|err| {
45+
ExtractorError(
46+
ApiBadRequest {
47+
reason: err.body_text(),
48+
}
49+
.build()
50+
.into(),
51+
)
52+
})
53+
}
54+
}
55+
56+
impl<T: Serialize> IntoResponse for Json<T> {
57+
fn into_response(self) -> axum::response::Response {
58+
let Self(value) = self;
59+
axum::extract::Json(value).into_response()
60+
}
61+
}
62+
63+
pub struct Query<T>(pub T);
64+
65+
impl<S, T> FromRequestParts<S> for Query<T>
66+
where
67+
axum_extra::extract::Query<T>: FromRequestParts<S, Rejection = QueryRejection>,
68+
S: Send + Sync,
69+
{
70+
type Rejection = ExtractorError;
71+
72+
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
73+
let res = axum_extra::extract::Query::<T>::from_request_parts(parts, state)
74+
.await
75+
.map(|query| Query(query.0))
76+
.map_err(|err| {
77+
ExtractorError(
78+
ApiBadRequest {
79+
reason: err.body_text(),
80+
}
81+
.build()
82+
.into(),
83+
)
84+
});
85+
86+
res
87+
}
88+
}
89+
90+
pub struct Extension<T>(pub T);
91+
92+
impl<S, T> FromRequestParts<S> for Extension<T>
93+
where
94+
axum::extract::Extension<T>: FromRequestParts<S, Rejection = ExtensionRejection>,
95+
S: Send + Sync,
96+
{
97+
type Rejection = ExtractorError;
98+
99+
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
100+
axum::extract::Extension::<T>::from_request_parts(parts, state)
101+
.await
102+
.map(|ext| Extension(ext.0))
103+
.map_err(|err| {
104+
ExtractorError(
105+
anyhow!("developer error: extension error: {}", err.body_text()).into(),
106+
)
107+
})
108+
}
109+
}

packages/common/api-builder/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
pub mod context;
22
pub mod error_response;
33
pub mod errors;
4+
pub mod extract;
45
pub mod global_context;
56
pub mod metrics;
67
pub mod middleware;

packages/common/api-builder/src/wrappers.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,21 @@
11
use anyhow::Result;
22
use axum::{
33
body::Bytes,
4-
extract::{Extension, Path},
5-
response::{IntoResponse, Json},
4+
extract::Path,
5+
response::IntoResponse,
66
routing::{
77
delete as axum_delete, get as axum_get, patch as axum_patch, post as axum_post,
88
put as axum_put,
99
},
1010
};
11-
use axum_extra::extract::Query;
1211
use serde::{Serialize, de::DeserializeOwned};
1312
use std::future::Future;
1413

15-
use crate::{context::ApiCtx, error_response::ApiError};
14+
use crate::{
15+
context::ApiCtx,
16+
error_response::ApiError,
17+
extract::{Extension, Json, Query},
18+
};
1619

1720
/// Macro to generate wrapper functions for HTTP methods
1821
macro_rules! create_method_wrapper {

packages/common/config/src/config/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ pub struct Root {
100100
impl Default for Root {
101101
fn default() -> Self {
102102
Root {
103-
auth: None,
103+
auth: Some(Auth::default()),
104104
guard: None,
105105
api_public: None,
106106
api_peer: None,

packages/core/api-public/src/actors/create.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
use anyhow::Result;
22
use axum::{
3-
extract::{Extension, Query},
43
http::HeaderMap,
5-
response::{IntoResponse, Json, Response},
4+
response::{IntoResponse, Response},
5+
};
6+
use rivet_api_builder::{
7+
ApiError,
8+
extract::{Extension, Json, Query},
69
};
7-
use rivet_api_builder::ApiError;
810
use rivet_api_types::actors::create::{CreateRequest, CreateResponse};
911
use rivet_api_util::request_remote_datacenter;
1012
use serde::{Deserialize, Serialize};

0 commit comments

Comments
 (0)