Skip to content

Commit a2baad2

Browse files
authored
Merge pull request #9872 from eth3lbert/async-auth
auth: Convert `check()` and related blocking fns to async
2 parents d5e11d3 + 5fcdaf4 commit a2baad2

File tree

13 files changed

+165
-79
lines changed

13 files changed

+165
-79
lines changed

src/auth.rs

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@ use crate::middleware::log_request::RequestLogExt;
44
use crate::middleware::session::RequestSession;
55
use crate::models::token::{CrateScope, EndpointScope};
66
use crate::models::{ApiToken, User};
7-
use crate::util::diesel::Conn;
87
use crate::util::errors::{
98
account_locked, forbidden, internal, AppResult, InsecurelyGeneratedTokenRevoked,
109
};
1110
use crate::util::token::HashedToken;
1211
use chrono::Utc;
12+
use diesel_async::AsyncPgConnection;
1313
use http::header;
1414
use http::request::Parts;
1515

@@ -58,8 +58,12 @@ impl AuthCheck {
5858
}
5959

6060
#[instrument(name = "auth.check", skip_all)]
61-
pub fn check(&self, parts: &Parts, conn: &mut impl Conn) -> AppResult<Authentication> {
62-
let auth = authenticate(parts, conn)?;
61+
pub async fn check(
62+
&self,
63+
parts: &Parts,
64+
conn: &mut AsyncPgConnection,
65+
) -> AppResult<Authentication> {
66+
let auth = authenticate(parts, conn).await?;
6367

6468
if let Some(token) = auth.api_token() {
6569
if !self.allow_token {
@@ -168,9 +172,9 @@ impl Authentication {
168172
}
169173

170174
#[instrument(skip_all)]
171-
fn authenticate_via_cookie(
175+
async fn authenticate_via_cookie(
172176
parts: &Parts,
173-
conn: &mut impl Conn,
177+
conn: &mut AsyncPgConnection,
174178
) -> AppResult<Option<CookieAuthentication>> {
175179
let user_id_from_session = parts
176180
.session()
@@ -181,7 +185,7 @@ fn authenticate_via_cookie(
181185
return Ok(None);
182186
};
183187

184-
let user = User::find(conn, id).map_err(|err| {
188+
let user = User::async_find(conn, id).await.map_err(|err| {
185189
parts.request_log().add("cause", err);
186190
internal("user_id from cookie not found in database")
187191
})?;
@@ -194,9 +198,9 @@ fn authenticate_via_cookie(
194198
}
195199

196200
#[instrument(skip_all)]
197-
fn authenticate_via_token(
201+
async fn authenticate_via_token(
198202
parts: &Parts,
199-
conn: &mut impl Conn,
203+
conn: &mut AsyncPgConnection,
200204
) -> AppResult<Option<TokenAuthentication>> {
201205
let maybe_authorization = parts
202206
.headers()
@@ -210,14 +214,16 @@ fn authenticate_via_token(
210214
let token =
211215
HashedToken::parse(header_value).map_err(|_| InsecurelyGeneratedTokenRevoked::boxed())?;
212216

213-
let token = ApiToken::find_by_api_token(conn, &token).map_err(|e| {
214-
let cause = format!("invalid token caused by {e}");
215-
parts.request_log().add("cause", cause);
217+
let token = ApiToken::async_find_by_api_token(conn, &token)
218+
.await
219+
.map_err(|e| {
220+
let cause = format!("invalid token caused by {e}");
221+
parts.request_log().add("cause", cause);
216222

217-
forbidden("authentication failed")
218-
})?;
223+
forbidden("authentication failed")
224+
})?;
219225

220-
let user = User::find(conn, token.user_id).map_err(|err| {
226+
let user = User::async_find(conn, token.user_id).await.map_err(|err| {
221227
parts.request_log().add("cause", err);
222228
internal("user_id from token not found in database")
223229
})?;
@@ -231,16 +237,16 @@ fn authenticate_via_token(
231237
}
232238

233239
#[instrument(skip_all)]
234-
fn authenticate(parts: &Parts, conn: &mut impl Conn) -> AppResult<Authentication> {
240+
async fn authenticate(parts: &Parts, conn: &mut AsyncPgConnection) -> AppResult<Authentication> {
235241
controllers::util::verify_origin(parts)?;
236242

237-
match authenticate_via_cookie(parts, conn) {
243+
match authenticate_via_cookie(parts, conn).await {
238244
Ok(None) => {}
239245
Ok(Some(auth)) => return Ok(Authentication::Cookie(auth)),
240246
Err(err) => return Err(err),
241247
}
242248

243-
match authenticate_via_token(parts, conn) {
249+
match authenticate_via_token(parts, conn).await {
244250
Ok(None) => {}
245251
Ok(Some(auth)) => return Ok(Authentication::Token(auth)),
246252
Err(err) => return Err(err),

src/controllers/crate_owner_invitation.rs

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,11 @@ use tokio::runtime::Handle;
2727

2828
/// Handles the `GET /api/v1/me/crate_owner_invitations` route.
2929
pub async fn list(app: AppState, req: Parts) -> AppResult<Json<Value>> {
30-
let conn = app.db_read().await?;
30+
let mut conn = app.db_read().await?;
31+
let auth = AuthCheck::only_cookie().check(&req, &mut conn).await?;
3132
spawn_blocking(move || {
3233
let conn: &mut AsyncConnectionWrapper<_> = &mut conn.into();
3334

34-
let auth = AuthCheck::only_cookie().check(&req, conn)?;
3535
let user_id = auth.user_id();
3636

3737
let PrivateListResponse {
@@ -69,12 +69,11 @@ pub async fn list(app: AppState, req: Parts) -> AppResult<Json<Value>> {
6969

7070
/// Handles the `GET /api/private/crate_owner_invitations` route.
7171
pub async fn private_list(app: AppState, req: Parts) -> AppResult<Json<PrivateListResponse>> {
72-
let conn = app.db_read().await?;
72+
let mut conn = app.db_read().await?;
73+
let auth = AuthCheck::only_cookie().check(&req, &mut conn).await?;
7374
spawn_blocking(move || {
7475
let conn: &mut AsyncConnectionWrapper<_> = &mut conn.into();
7576

76-
let auth = AuthCheck::only_cookie().check(&req, conn)?;
77-
7877
let filter = if let Some(crate_name) = req.query().get("crate_name") {
7978
ListFilter::CrateName(crate_name.clone())
8079
} else if let Some(id) = req.query().get("invitee_id").and_then(|i| i.parse().ok()) {
@@ -284,11 +283,11 @@ pub async fn handle_invite(state: AppState, req: BytesRequest) -> AppResult<Json
284283

285284
let crate_invite = crate_invite.crate_owner_invite;
286285

287-
let conn = state.db_write().await?;
286+
let mut conn = state.db_write().await?;
287+
let auth = AuthCheck::default().check(&parts, &mut conn).await?;
288288
spawn_blocking(move || {
289289
let conn: &mut AsyncConnectionWrapper<_> = &mut conn.into();
290290

291-
let auth = AuthCheck::default().check(&parts, conn)?;
292291
let user_id = auth.user_id();
293292

294293
let config = &state.config;

src/controllers/krate/follow.rs

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,13 @@ pub async fn follow(
3434
Path(crate_name): Path<String>,
3535
req: Parts,
3636
) -> AppResult<Response> {
37-
let conn = app.db_write().await?;
37+
let mut conn = app.db_write().await?;
38+
let user_id = AuthCheck::default().check(&req, &mut conn).await?.user_id();
3839
spawn_blocking(move || {
3940
use diesel::RunQueryDsl;
4041

4142
let conn: &mut AsyncConnectionWrapper<_> = &mut conn.into();
4243

43-
let user_id = AuthCheck::default().check(&req, conn)?.user_id();
4444
let follow = follow_target(&crate_name, conn, user_id)?;
4545
diesel::insert_into(follows::table)
4646
.values(&follow)
@@ -58,13 +58,13 @@ pub async fn unfollow(
5858
Path(crate_name): Path<String>,
5959
req: Parts,
6060
) -> AppResult<Response> {
61-
let conn = app.db_write().await?;
61+
let mut conn = app.db_write().await?;
62+
let user_id = AuthCheck::default().check(&req, &mut conn).await?.user_id();
6263
spawn_blocking(move || {
6364
use diesel::RunQueryDsl;
6465

6566
let conn: &mut AsyncConnectionWrapper<_> = &mut conn.into();
6667

67-
let user_id = AuthCheck::default().check(&req, conn)?.user_id();
6868
let follow = follow_target(&crate_name, conn, user_id)?;
6969
diesel::delete(&follow).execute(conn)?;
7070

@@ -79,15 +79,18 @@ pub async fn following(
7979
Path(crate_name): Path<String>,
8080
req: Parts,
8181
) -> AppResult<Json<Value>> {
82-
let conn = app.db_read_prefer_primary().await?;
82+
let mut conn = app.db_read_prefer_primary().await?;
83+
let user_id = AuthCheck::only_cookie()
84+
.check(&req, &mut conn)
85+
.await?
86+
.user_id();
8387
spawn_blocking(move || {
8488
use diesel::RunQueryDsl;
8589

8690
let conn: &mut AsyncConnectionWrapper<_> = &mut conn.into();
8791

8892
use diesel::dsl::exists;
8993

90-
let user_id = AuthCheck::only_cookie().check(&req, conn)?.user_id();
9194
let follow = follow_target(&crate_name, conn, user_id)?;
9295
let following =
9396
diesel::select(exists(follows::table.find(follow.id()))).get_result::<bool>(conn)?;

src/controllers/krate/owners.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -130,17 +130,17 @@ async fn modify_owners(
130130
));
131131
}
132132

133-
let conn = app.db_write().await?;
133+
let mut conn = app.db_write().await?;
134+
let auth = AuthCheck::default()
135+
.with_endpoint_scope(EndpointScope::ChangeOwners)
136+
.for_crate(&crate_name)
137+
.check(&parts, &mut conn)
138+
.await?;
134139
spawn_blocking(move || {
135140
use diesel::RunQueryDsl;
136141

137142
let conn: &mut AsyncConnectionWrapper<_> = &mut conn.into();
138143

139-
let auth = AuthCheck::default()
140-
.with_endpoint_scope(EndpointScope::ChangeOwners)
141-
.for_crate(&crate_name)
142-
.check(&parts, conn)?;
143-
144144
let user = auth.user();
145145

146146
// The set of emails to send out after invite processing is complete and

src/controllers/krate/publish.rs

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -81,17 +81,17 @@ pub async fn publish(app: AppState, req: BytesRequest) -> AppResult<Json<GoodCra
8181
request_log.add("crate_name", &*metadata.name);
8282
request_log.add("crate_version", &version_string);
8383

84-
let conn = app.db_write().await?;
85-
spawn_blocking(move || {
86-
use diesel::RunQueryDsl;
84+
let mut conn = app.db_write().await?;
8785

88-
let conn: &mut AsyncConnectionWrapper<_> = &mut conn.into();
86+
let (existing_crate, auth) = {
87+
use diesel_async::RunQueryDsl;
8988

9089
// this query should only be used for the endpoint scope calculation
9190
// since a race condition there would only cause `publish-new` instead of
9291
// `publish-update` to be used.
9392
let existing_crate: Option<Crate> = Crate::by_name(&metadata.name)
94-
.first::<Crate>(conn)
93+
.first::<Crate>(&mut conn)
94+
.await
9595
.optional()?;
9696

9797
let endpoint_scope = match existing_crate {
@@ -102,7 +102,15 @@ pub async fn publish(app: AppState, req: BytesRequest) -> AppResult<Json<GoodCra
102102
let auth = AuthCheck::default()
103103
.with_endpoint_scope(endpoint_scope)
104104
.for_crate(&metadata.name)
105-
.check(&req, conn)?;
105+
.check(&req, &mut conn)
106+
.await?;
107+
(existing_crate, auth)
108+
};
109+
110+
spawn_blocking(move || {
111+
use diesel::RunQueryDsl;
112+
113+
let conn: &mut AsyncConnectionWrapper<_> = &mut conn.into();
106114

107115
let api_token_id = auth.api_token_id();
108116
let user = auth.user();

src/controllers/krate/search.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@ use axum::Json;
66
use diesel::dsl::{exists, sql, InnerJoinQuerySource, LeftJoinQuerySource};
77
use diesel::sql_types::{Array, Bool, Text};
88
use diesel_async::async_connection_wrapper::AsyncConnectionWrapper;
9+
use diesel_async::AsyncPgConnection;
910
use diesel_full_text_search::*;
1011
use http::request::Parts;
1112
use serde_json::Value;
1213
use std::cell::OnceCell;
14+
use tokio::runtime::Handle;
1315

1416
use crate::app::AppState;
1517
use crate::controllers::helpers::Paginate;
@@ -22,7 +24,6 @@ use crate::controllers::helpers::pagination::{Page, Paginated, PaginationOptions
2224
use crate::models::krate::ALL_COLUMNS;
2325
use crate::sql::{array_agg, canon_crate_name, lower};
2426
use crate::tasks::spawn_blocking;
25-
use crate::util::diesel::Conn;
2627
use crate::util::RequestUtils;
2728

2829
/// Handles the `GET /crates` route.
@@ -303,12 +304,14 @@ impl<'a> FilterParams<'a> {
303304
.as_deref()
304305
}
305306

306-
fn authed_user_id(&self, req: &Parts, conn: &mut impl Conn) -> AppResult<i32> {
307+
fn authed_user_id(&self, req: &Parts, conn: &mut AsyncPgConnection) -> AppResult<i32> {
307308
if let Some(val) = self._auth_user_id.get() {
308309
return Ok(*val);
309310
}
310311

311-
let user_id = AuthCheck::default().check(req, conn)?.user_id();
312+
let user_id = Handle::current()
313+
.block_on(AuthCheck::default().check(req, conn))?
314+
.user_id();
312315

313316
// This should not fail, because of the `get()` check above
314317
let _ = self._auth_user_id.set(user_id);
@@ -319,7 +322,7 @@ impl<'a> FilterParams<'a> {
319322
fn make_query(
320323
&'a self,
321324
req: &Parts,
322-
conn: &mut impl Conn,
325+
conn: &mut AsyncPgConnection,
323326
) -> AppResult<crates::BoxedQuery<'a, diesel::pg::Pg>> {
324327
let mut query = crates::table.into_boxed();
325328

src/controllers/token.rs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,13 @@ pub async fn list(
4141
Query(params): Query<GetParams>,
4242
req: Parts,
4343
) -> AppResult<Json<Value>> {
44-
let conn = app.db_read_prefer_primary().await?;
44+
let mut conn = app.db_read_prefer_primary().await?;
45+
let auth = AuthCheck::only_cookie().check(&req, &mut conn).await?;
4546
spawn_blocking(move || {
4647
use diesel::RunQueryDsl;
4748

4849
let conn: &mut AsyncConnectionWrapper<_> = &mut conn.into();
4950

50-
let auth = AuthCheck::only_cookie().check(&req, conn)?;
5151
let user = auth.user();
5252

5353
let tokens: Vec<ApiToken> = ApiToken::belonging_to(user)
@@ -92,13 +92,13 @@ pub async fn new(
9292
return Err(bad_request("name must have a value"));
9393
}
9494

95-
let conn = app.db_write().await?;
95+
let mut conn = app.db_write().await?;
96+
let auth = AuthCheck::default().check(&parts, &mut conn).await?;
9697
spawn_blocking(move || {
9798
use diesel::RunQueryDsl;
9899

99100
let conn: &mut AsyncConnectionWrapper<_> = &mut conn.into();
100101

101-
let auth = AuthCheck::default().check(&parts, conn)?;
102102
if auth.api_token_id().is_some() {
103103
return Err(bad_request(
104104
"cannot use an API token to create a new API token",
@@ -175,13 +175,13 @@ pub async fn new(
175175

176176
/// Handles the `GET /me/tokens/:id` route.
177177
pub async fn show(app: AppState, Path(id): Path<i32>, req: Parts) -> AppResult<Json<Value>> {
178-
let conn = app.db_write().await?;
178+
let mut conn = app.db_write().await?;
179+
let auth = AuthCheck::default().check(&req, &mut conn).await?;
179180
spawn_blocking(move || {
180181
use diesel::RunQueryDsl;
181182

182183
let conn: &mut AsyncConnectionWrapper<_> = &mut conn.into();
183184

184-
let auth = AuthCheck::default().check(&req, conn)?;
185185
let user = auth.user();
186186
let token = ApiToken::belonging_to(user)
187187
.find(id)
@@ -195,13 +195,13 @@ pub async fn show(app: AppState, Path(id): Path<i32>, req: Parts) -> AppResult<J
195195

196196
/// Handles the `DELETE /me/tokens/:id` route.
197197
pub async fn revoke(app: AppState, Path(id): Path<i32>, req: Parts) -> AppResult<Json<Value>> {
198-
let conn = app.db_write().await?;
198+
let mut conn = app.db_write().await?;
199+
let auth = AuthCheck::default().check(&req, &mut conn).await?;
199200
spawn_blocking(move || {
200201
use diesel::RunQueryDsl;
201202

202203
let conn: &mut AsyncConnectionWrapper<_> = &mut conn.into();
203204

204-
let auth = AuthCheck::default().check(&req, conn)?;
205205
let user = auth.user();
206206
diesel::update(ApiToken::belonging_to(user).find(id))
207207
.set(api_tokens::revoked.eq(true))
@@ -214,13 +214,13 @@ pub async fn revoke(app: AppState, Path(id): Path<i32>, req: Parts) -> AppResult
214214

215215
/// Handles the `DELETE /tokens/current` route.
216216
pub async fn revoke_current(app: AppState, req: Parts) -> AppResult<Response> {
217-
let conn = app.db_write().await?;
217+
let mut conn = app.db_write().await?;
218+
let auth = AuthCheck::default().check(&req, &mut conn).await?;
218219
spawn_blocking(move || {
219220
use diesel::RunQueryDsl;
220221

221222
let conn: &mut AsyncConnectionWrapper<_> = &mut conn.into();
222223

223-
let auth = AuthCheck::default().check(&req, conn)?;
224224
let api_token_id = auth
225225
.api_token_id()
226226
.ok_or_else(|| bad_request("token not provided"))?;

0 commit comments

Comments
 (0)