Skip to content

Commit 1a8a1e9

Browse files
committed
refactor: Clean up Outcome<->Result conversions
Stop defining a try_outcome! macro that shadows the same-named macro from rocket core. Provide a specialized version of `rocket::outcome::IntoOutcome` that works with `PointercrateError`s only, and use this to avoid manually specifying the status code when returning an Outcome::Error from a FromRequest or FromData impl. Signed-off-by: stadust <43299462+stadust@users.noreply.github.com>
1 parent 4969324 commit 1a8a1e9

File tree

3 files changed

+84
-67
lines changed

3 files changed

+84
-67
lines changed

pointercrate-core-api/src/error.rs

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use crate::response::Page;
22
use log::info;
33
use pointercrate_core::error::PointercrateError;
44
use pointercrate_core_pages::error::ErrorFragment;
5+
use rocket::outcome::Outcome;
56
use rocket::{
67
http::{MediaType, Status},
78
response::Responder,
@@ -60,3 +61,40 @@ impl<E: PointercrateError> From<E> for ErrorResponder {
6061
}
6162
}
6263
}
64+
65+
/// A version of [`IntoOutcome`](rocket::outcome::IntoOutcome) specially crafted for [`PointercrateError`]s
66+
pub trait IntoOutcome2<S, E> {
67+
fn into_outcome<F, E2: From<E>>(self) -> Outcome<S, (Status, E2), F>;
68+
}
69+
70+
impl<S, E: PointercrateError> IntoOutcome2<S, E> for std::result::Result<S, E> {
71+
fn into_outcome<F, E2: From<E>>(self) -> Outcome<S, (Status, E2), F> {
72+
self.map(Outcome::Success).unwrap_or_else(|e| e.into_outcome())
73+
}
74+
}
75+
76+
impl<S, E: PointercrateError> IntoOutcome2<S, E> for E {
77+
fn into_outcome<F, E2: From<E>>(self) -> Outcome<S, (Status, E2), F> {
78+
Outcome::Error((Status::new(self.status_code()), self.into()))
79+
}
80+
}
81+
82+
#[macro_export]
83+
macro_rules! tryo_result {
84+
($result: expr) => {
85+
rocket::outcome::try_outcome!($crate::error::IntoOutcome2::into_outcome($result))
86+
};
87+
}
88+
89+
#[macro_export]
90+
macro_rules! tryo_state {
91+
($request: expr, $typ: ty) => {
92+
$crate::tryo_result!($request
93+
.rocket()
94+
.state::<$typ>()
95+
.ok_or_else(|| pointercrate_core::error::CoreError::internal_server_error(format!(
96+
"Missing required state: '{}'",
97+
stringify!($typ)
98+
))))
99+
};
100+
}

pointercrate-core-api/src/etag.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use crate::error::IntoOutcome2;
12
use crate::response::Response2;
23
use pointercrate_core::{error::CoreError, etag::Taggable};
34
use rocket::{
@@ -35,10 +36,12 @@ impl<'r> FromRequest<'r> for Precondition {
3536
type Error = CoreError;
3637

3738
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
38-
match request.headers().get_one("if-match") {
39-
Some(if_match) => Outcome::Success(Precondition(if_match.split(',').map(ToString::to_string).collect())),
40-
None => Outcome::Error((Status::PreconditionRequired, CoreError::PreconditionRequired)),
41-
}
39+
request
40+
.headers()
41+
.get_one("if-match")
42+
.map(|if_match| Precondition(if_match.split(',').map(ToString::to_string).collect()))
43+
.ok_or(CoreError::PreconditionFailed)
44+
.into_outcome()
4245
}
4346
}
4447

pointercrate-user-api/src/auth.rs

Lines changed: 39 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,20 @@
11
use base64::{engine::general_purpose::STANDARD, Engine};
22
use log::warn;
33
use pointercrate_core::{
4-
error::{CoreError, PointercrateError},
4+
error::CoreError,
55
permission::{Permission, PermissionsManager},
66
pool::{audit_connection, PointercratePool},
77
};
8+
use pointercrate_core_api::error::IntoOutcome2;
9+
use pointercrate_core_api::{tryo_result, tryo_state};
810
use pointercrate_user::{
911
auth::{AccessClaims, ApiToken, AuthenticatedUser, NonMutating, PasswordOrBrowser},
1012
error::UserError,
1113
};
1214
use rocket::{
1315
http::{Method, Status},
1416
request::{FromRequest, Outcome},
15-
Request, State,
17+
Request,
1618
};
1719
use sqlx::{Postgres, Transaction};
1820
use std::collections::HashSet;
@@ -44,29 +46,6 @@ impl<A> Auth<A> {
4446
}
4547
}
4648

47-
macro_rules! try_outcome {
48-
($outcome:expr) => {
49-
match $outcome {
50-
Ok(success) => success,
51-
Err(error) => return Outcome::Error((Status::from_code(error.status_code()).unwrap(), error.into())),
52-
}
53-
};
54-
}
55-
56-
macro_rules! try_state {
57-
($request: expr, $typ: ty) => {
58-
match $request.guard::<&State<$typ>>().await {
59-
Outcome::Success(state) => state.inner(),
60-
_ => {
61-
return Outcome::Error((
62-
Status::InternalServerError,
63-
CoreError::internal_server_error(format!("Missing required state: '{}'", stringify!($typ))).into(),
64-
))
65-
},
66-
}
67-
};
68-
}
69-
7049
#[rocket::async_trait]
7150
impl<'r> FromRequest<'r> for Auth<NonMutating> {
7251
type Error = UserError;
@@ -76,17 +55,17 @@ impl<'r> FromRequest<'r> for Auth<NonMutating> {
7655
return Outcome::Forward(Status::NotFound);
7756
}
7857

79-
let pool = try_state!(request, PointercratePool);
80-
let permission_manager = try_state!(request, PermissionsManager).clone();
58+
let pool = tryo_state!(request, PointercratePool);
59+
let permission_manager = tryo_state!(request, PermissionsManager).clone();
8160

82-
let mut connection = try_outcome!(pool.transaction().await);
61+
let mut connection = tryo_result!(pool.transaction().await);
8362

8463
if let Some(access_token) = request.cookies().get("access_token") {
85-
let access_claims = try_outcome!(AccessClaims::decode(access_token.value()));
86-
let user = try_outcome!(AuthenticatedUser::by_id(try_outcome!(access_claims.id()), &mut connection).await);
87-
let authenticated_for_get = try_outcome!(user.validate_cookie_claims(access_claims));
64+
let access_claims = tryo_result!(AccessClaims::decode(access_token.value()));
65+
let user = tryo_result!(AuthenticatedUser::by_id(tryo_result!(access_claims.id()), &mut connection).await);
66+
let authenticated_for_get = tryo_result!(user.validate_cookie_claims(access_claims));
8867

89-
try_outcome!(audit_connection(&mut connection, authenticated_for_get.user().id).await);
68+
tryo_result!(audit_connection(&mut connection, authenticated_for_get.user().id).await);
9069

9170
return Outcome::Success(Auth {
9271
user: authenticated_for_get,
@@ -95,7 +74,7 @@ impl<'r> FromRequest<'r> for Auth<NonMutating> {
9574
});
9675
}
9776

98-
Outcome::Error((Status::Unauthorized, CoreError::Unauthorized.into()))
77+
CoreError::Unauthorized.into_outcome()
9978
}
10079
}
10180

@@ -109,18 +88,18 @@ impl<'r> FromRequest<'r> for Auth<ApiToken> {
10988
return Outcome::Forward(Status::NotFound);
11089
}
11190

112-
let pool = try_state!(request, PointercratePool);
113-
let permission_manager = try_state!(request, PermissionsManager).clone();
91+
let pool = tryo_state!(request, PointercratePool);
92+
let permission_manager = tryo_state!(request, PermissionsManager).clone();
11493

115-
let mut connection = try_outcome!(pool.transaction().await);
94+
let mut connection = tryo_result!(pool.transaction().await);
11695

11796
for authorization in request.headers().get("Authorization") {
11897
if let ["Bearer", token] = authorization.split(' ').collect::<Vec<_>>()[..] {
119-
let access_claims = try_outcome!(AccessClaims::decode(token));
120-
let user = try_outcome!(AuthenticatedUser::by_id(try_outcome!(access_claims.id()), &mut connection).await);
121-
let authenticated_user = try_outcome!(user.validate_api_access(access_claims));
98+
let access_claims = tryo_result!(AccessClaims::decode(token));
99+
let user = tryo_result!(AuthenticatedUser::by_id(tryo_result!(access_claims.id()), &mut connection).await);
100+
let authenticated_user = tryo_result!(user.validate_api_access(access_claims));
122101

123-
try_outcome!(audit_connection(&mut connection, authenticated_user.user().id).await);
102+
tryo_result!(audit_connection(&mut connection, authenticated_user.user().id).await);
124103

125104
return Outcome::Success(Auth {
126105
user: authenticated_user,
@@ -132,12 +111,12 @@ impl<'r> FromRequest<'r> for Auth<ApiToken> {
132111

133112
// no matching auth header, lets try the cookie
134113
if let (Some(access_token), Some(csrf_token)) = (request.cookies().get("access_token"), request.headers().get_one("X-CSRF-TOKEN")) {
135-
let access_claims = try_outcome!(AccessClaims::decode(access_token.value()));
136-
let user = try_outcome!(AuthenticatedUser::by_id(try_outcome!(access_claims.id()), &mut connection).await);
137-
let authenticated_for_get = try_outcome!(user.validate_cookie_claims(access_claims));
138-
let authenticated = try_outcome!(authenticated_for_get.validate_csrf_token(csrf_token));
114+
let access_claims = tryo_result!(AccessClaims::decode(access_token.value()));
115+
let user = tryo_result!(AuthenticatedUser::by_id(tryo_result!(access_claims.id()), &mut connection).await);
116+
let authenticated_for_get = tryo_result!(user.validate_cookie_claims(access_claims));
117+
let authenticated = tryo_result!(authenticated_for_get.validate_csrf_token(csrf_token));
139118

140-
try_outcome!(audit_connection(&mut connection, authenticated.user().id).await);
119+
tryo_result!(audit_connection(&mut connection, authenticated.user().id).await);
141120

142121
return Outcome::Success(Auth {
143122
user: authenticated.downgrade_auth_type().unwrap(), // cannot fail: we are not password authenticated
@@ -146,7 +125,7 @@ impl<'r> FromRequest<'r> for Auth<ApiToken> {
146125
});
147126
}
148127

149-
Outcome::Error((Status::Unauthorized, CoreError::Unauthorized.into()))
128+
CoreError::Unauthorized.into_outcome()
150129
}
151130
}
152131

@@ -156,25 +135,22 @@ impl<'r> FromRequest<'r> for Auth<PasswordOrBrowser> {
156135

157136
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
158137
if request.method() == Method::Get {
159-
return Outcome::Error((
160-
Status::InternalServerError,
161-
CoreError::internal_server_error("Requiring higher authentication on a GET request. This is nonsense").into(),
162-
));
138+
return CoreError::internal_server_error("Requiring higher authentication on a GET request. This is nonsense").into_outcome();
163139
}
164140

165141
// No auth header set, forward to the request handler that doesnt require authorization (if one exists)
166142
if request.headers().get_one("Authorization").is_none() && request.cookies().get("access_token").is_none() {
167143
return Outcome::Forward(Status::NotFound);
168144
}
169145

170-
let pool = try_state!(request, PointercratePool);
171-
let permission_manager = try_state!(request, PermissionsManager).clone();
146+
let pool = tryo_state!(request, PointercratePool);
147+
let permission_manager = tryo_state!(request, PermissionsManager).clone();
172148

173-
let mut connection = try_outcome!(pool.transaction().await);
149+
let mut connection = tryo_result!(pool.transaction().await);
174150

175151
for authorization in request.headers().get("Authorization") {
176152
if let ["Basic", basic_auth] = authorization.split(' ').collect::<Vec<_>>()[..] {
177-
let decoded = try_outcome!(STANDARD
153+
let decoded = tryo_result!(STANDARD
178154
.decode(basic_auth)
179155
.map_err(|_| ())
180156
.and_then(|bytes| String::from_utf8(bytes).map_err(|_| ()))
@@ -185,10 +161,10 @@ impl<'r> FromRequest<'r> for Auth<PasswordOrBrowser> {
185161
}));
186162

187163
if let [username, password] = &decoded.splitn(2, ':').collect::<Vec<_>>()[..] {
188-
let user = try_outcome!(AuthenticatedUser::by_name(username, &mut connection).await);
189-
let authenticated = try_outcome!(user.verify_password(password));
164+
let user = tryo_result!(AuthenticatedUser::by_name(username, &mut connection).await);
165+
let authenticated = tryo_result!(user.verify_password(password));
190166

191-
try_outcome!(audit_connection(&mut connection, authenticated.user().id).await);
167+
tryo_result!(audit_connection(&mut connection, authenticated.user().id).await);
192168

193169
return Outcome::Success(Auth {
194170
user: authenticated,
@@ -200,12 +176,12 @@ impl<'r> FromRequest<'r> for Auth<PasswordOrBrowser> {
200176
}
201177
// no matching auth header, lets try the cookie
202178
if let (Some(access_token), Some(csrf_token)) = (request.cookies().get("access_token"), request.headers().get_one("X-CSRF-TOKEN")) {
203-
let access_claims = try_outcome!(AccessClaims::decode(access_token.value()));
204-
let user = try_outcome!(AuthenticatedUser::by_id(try_outcome!(access_claims.id()), &mut connection).await);
205-
let authenticated_for_get = try_outcome!(user.validate_cookie_claims(access_claims));
206-
let authenticated = try_outcome!(authenticated_for_get.validate_csrf_token(csrf_token));
179+
let access_claims = tryo_result!(AccessClaims::decode(access_token.value()));
180+
let user = tryo_result!(AuthenticatedUser::by_id(tryo_result!(access_claims.id()), &mut connection).await);
181+
let authenticated_for_get = tryo_result!(user.validate_cookie_claims(access_claims));
182+
let authenticated = tryo_result!(authenticated_for_get.validate_csrf_token(csrf_token));
207183

208-
try_outcome!(audit_connection(&mut connection, authenticated.user().id).await);
184+
tryo_result!(audit_connection(&mut connection, authenticated.user().id).await);
209185

210186
return Outcome::Success(Auth {
211187
user: authenticated,
@@ -214,6 +190,6 @@ impl<'r> FromRequest<'r> for Auth<PasswordOrBrowser> {
214190
});
215191
}
216192

217-
Outcome::Error((Status::Unauthorized, CoreError::Unauthorized.into()))
193+
CoreError::Unauthorized.into_outcome()
218194
}
219195
}

0 commit comments

Comments
 (0)