Skip to content

Commit d536aa4

Browse files
authored
refactor: simplify by getting rid of the unneeded User wrapper (#1722)
* refactor: simplify by getting rid of the unneeded User wrapper * refactor: use async_trait from axum
1 parent d155595 commit d536aa4

File tree

5 files changed

+44
-61
lines changed

5 files changed

+44
-61
lines changed

common/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ wiremock = { workspace = true, optional = true }
4444
[features]
4545
axum = ["dep:axum"]
4646
claims = [
47+
"axum",
4748
"bytes",
4849
"chrono/clock",
4950
"headers",

common/src/claims.rs

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@ use std::{
55
task::{Context, Poll},
66
};
77

8+
use axum::extract::FromRequestParts;
89
use bytes::Bytes;
910
use chrono::{Duration, Utc};
1011
use headers::{Authorization, HeaderMapExt};
11-
use http::{Request, StatusCode};
12+
use http::{request::Parts, Request, StatusCode};
1213
use http_body::combinators::UnsyncBoxBody;
1314
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
1415
use opentelemetry::global;
@@ -332,6 +333,26 @@ impl Claim {
332333
}
333334
}
334335

336+
/// Extract the claim from the request and fail with unauthorized if the claim doesn't exist
337+
#[axum::async_trait]
338+
impl<S> FromRequestParts<S> for Claim {
339+
type Rejection = StatusCode;
340+
341+
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
342+
let claim = parts
343+
.extensions
344+
.get::<Claim>()
345+
.ok_or(StatusCode::UNAUTHORIZED)?;
346+
347+
// Record current account name for tracing purposes
348+
Span::current().record("account.user_id", &claim.sub);
349+
350+
trace!(?claim, "got user");
351+
352+
Ok(claim.clone())
353+
}
354+
}
355+
335356
// Future for layers that just return the inner response
336357
#[pin_project]
337358
pub struct ResponseFuture<F>(#[pin] pub F);

gateway/src/api/latest.rs

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ use shuttle_backends::metrics::{Metrics, TraceLayer};
2525
use shuttle_backends::project_name::ProjectName;
2626
use shuttle_backends::request_span;
2727
use shuttle_backends::ClaimExt;
28-
use shuttle_common::claims::{Scope, EXP_MINUTES};
28+
use shuttle_common::claims::{Claim, Scope, EXP_MINUTES};
2929
use shuttle_common::models::error::ErrorKind;
3030
use shuttle_common::models::service;
3131
use shuttle_common::models::{admin::ProjectResponse, project, stats};
@@ -47,7 +47,7 @@ use x509_parser::time::ASN1Time;
4747

4848
use crate::acme::{AccountWrapper, AcmeClient, CustomDomain};
4949
use crate::api::tracing::project_name_tracing_layer;
50-
use crate::auth::{ScopedUser, User};
50+
use crate::auth::ScopedUser;
5151
use crate::service::{ContainerSettings, GatewayService};
5252
use crate::task::{self, BoxedTask};
5353
use crate::tls::{GatewayCertResolver, RENEWAL_VALIDITY_THRESHOLD_IN_DAYS};
@@ -131,12 +131,12 @@ async fn check_project_name(
131131
}
132132
async fn get_projects_list(
133133
State(RouterState { service, .. }): State<RouterState>,
134-
User { id, .. }: User,
134+
Claim { sub, .. }: Claim,
135135
) -> Result<AxumJson<Vec<project::Response>>, Error> {
136136
let mut projects = vec![];
137137
for p in service
138138
.permit_client
139-
.get_user_projects(&id)
139+
.get_user_projects(&sub)
140140
.await
141141
.map_err(|_| Error::from(ErrorKind::Internal))?
142142
{
@@ -163,7 +163,7 @@ async fn create_project(
163163
State(RouterState {
164164
service, sender, ..
165165
}): State<RouterState>,
166-
User { id, claim, .. }: User,
166+
claim: Claim,
167167
CustomErrorPath(project_name): CustomErrorPath<ProjectName>,
168168
AxumJson(config): AxumJson<project::Config>,
169169
) -> Result<AxumJson<project::Response>, Error> {
@@ -172,7 +172,7 @@ async fn create_project(
172172
// Check that the user is within their project limits.
173173
let can_create_project = claim.can_create_project(
174174
service
175-
.get_project_count(&id)
175+
.get_project_count(&claim.sub)
176176
.await?
177177
.saturating_sub(is_cch_project as u32),
178178
);
@@ -184,7 +184,7 @@ async fn create_project(
184184
let project = service
185185
.create_project(
186186
project_name.clone(),
187-
&id,
187+
&claim.sub,
188188
claim.is_admin(),
189189
can_create_project,
190190
if is_cch_project {
@@ -398,7 +398,7 @@ async fn override_create_service(
398398
scoped_user: ScopedUser,
399399
req: Request<Body>,
400400
) -> Result<Response<Body>, Error> {
401-
let user_id = scoped_user.user.id.clone();
401+
let user_id = scoped_user.claim.sub.clone();
402402
let posthog_client = state.posthog_client.clone();
403403
tokio::spawn(async move {
404404
let event = async_posthog::Event::new("shuttle_api_start_deployment", &user_id);
@@ -460,9 +460,9 @@ async fn route_project(
460460
let project_name = scoped_user.scope;
461461
let is_cch_project = project_name.is_cch_project();
462462

463-
if !scoped_user.user.claim.is_admin() {
463+
if !scoped_user.claim.is_admin() {
464464
service
465-
.has_capacity(is_cch_project, &scoped_user.user.claim.tier)
465+
.has_capacity(is_cch_project, &scoped_user.claim.tier)
466466
.await?;
467467
}
468468

@@ -471,7 +471,7 @@ async fn route_project(
471471
.await?
472472
.0;
473473
service
474-
.route(&project.state, &project_name, &scoped_user.user.id, req)
474+
.route(&project.state, &project_name, &scoped_user.claim.sub, req)
475475
.await
476476
}
477477

gateway/src/api/project_caller.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ impl ProjectCaller {
4343
Ok(Self {
4444
project: project.state,
4545
project_name,
46-
user_id: scoped_user.user.id,
46+
user_id: scoped_user.claim.sub,
4747
service,
4848
headers: headers.clone(),
4949
})

gateway/src/auth.rs

Lines changed: 9 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,63 +1,22 @@
1-
use std::fmt::Debug;
2-
31
use axum::extract::{FromRef, FromRequestParts, Path};
42
use axum::http::request::Parts;
5-
use serde::{Deserialize, Serialize};
63
use shuttle_backends::project_name::ProjectName;
74
use shuttle_backends::ClaimExt;
85
use shuttle_common::claims::Claim;
96
use shuttle_common::models::error::InvalidProjectName;
10-
use shuttle_common::models::user::UserId;
11-
use tracing::{error, trace, Span};
7+
use tracing::error;
128

139
use crate::api::latest::RouterState;
1410
use crate::{Error, ErrorKind};
1511

16-
/// A wrapper to enrich a token with user details
17-
///
18-
/// The `FromRequest` impl consumes the API claim and enriches it with project
19-
/// details. Generally you want to use [`ScopedUser`] instead to ensure the request
20-
/// is valid against the user's owned resources.
21-
#[derive(Clone, Deserialize, PartialEq, Eq, Serialize, Debug)]
22-
pub struct User {
23-
pub claim: Claim,
24-
pub id: UserId,
25-
}
26-
27-
#[async_trait]
28-
impl<S> FromRequestParts<S> for User
29-
where
30-
S: Send + Sync,
31-
RouterState: FromRef<S>,
32-
{
33-
type Rejection = Error;
34-
35-
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
36-
let claim = parts.extensions.get::<Claim>().ok_or(ErrorKind::Internal)?;
37-
let user_id = claim.sub.clone();
38-
39-
// Record current account name for tracing purposes
40-
Span::current().record("account.user_id", &user_id);
41-
42-
let user = User {
43-
claim: claim.clone(),
44-
id: user_id,
45-
};
46-
47-
trace!(?user, "got user");
48-
49-
Ok(user)
50-
}
51-
}
52-
5312
/// A wrapper for a guard that validates a user's API token *and*
5413
/// scopes the request to a project they own.
5514
///
5615
/// It is guaranteed that [`ScopedUser::scope`] exists and is owned
5716
/// by [`ScopedUser::name`].
5817
#[derive(Clone)]
5918
pub struct ScopedUser {
60-
pub user: User,
19+
pub claim: Claim,
6120
pub scope: ProjectName,
6221
}
6322

@@ -70,7 +29,9 @@ where
7029
type Rejection = Error;
7130

7231
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
73-
let user = User::from_request_parts(parts, state).await?;
32+
let claim = Claim::from_request_parts(parts, state)
33+
.await
34+
.map_err(|_| ErrorKind::Unauthorized)?;
7435

7536
let scope = match Path::<ProjectName>::from_request_parts(parts, state).await {
7637
Ok(Path(p)) => p,
@@ -82,12 +43,12 @@ where
8243

8344
let RouterState { service, .. } = RouterState::from_ref(state);
8445

85-
let allowed = user.claim.is_admin()
86-
|| user.claim.is_deployer()
46+
let allowed = claim.is_admin()
47+
|| claim.is_deployer()
8748
|| service
8849
.permit_client
8950
.allowed(
90-
&user.id,
51+
&claim.sub,
9152
&service.find_project_by_name(&scope).await?.id,
9253
"develop", // TODO: make this configurable per endpoint?
9354
)
@@ -98,7 +59,7 @@ where
9859
})?;
9960

10061
if allowed {
101-
Ok(Self { user, scope })
62+
Ok(Self { claim, scope })
10263
} else {
10364
Err(Error::from(ErrorKind::ProjectNotFound(scope.to_string())))
10465
}

0 commit comments

Comments
 (0)