Skip to content

Commit 3640ab0

Browse files
authored
Merge pull request #9972 from Turbo87/async-session
controllers/user/session: Remove `spawn_blocking()` usage
2 parents b47ff50 + 09bf303 commit 3640ab0

File tree

4 files changed

+155
-152
lines changed

4 files changed

+155
-152
lines changed

src/controllers/user/session.rs

Lines changed: 63 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,19 @@ use axum::extract::{FromRequestParts, Query};
22
use axum::Json;
33
use axum_extra::json;
44
use axum_extra::response::ErasedJson;
5-
use diesel_async::async_connection_wrapper::AsyncConnectionWrapper;
5+
use diesel::prelude::*;
6+
use diesel_async::{AsyncPgConnection, RunQueryDsl};
67
use http::request::Parts;
7-
use oauth2::reqwest::http_client;
8+
use oauth2::reqwest::async_http_client;
89
use oauth2::{AuthorizationCode, CsrfToken, Scope, TokenResponse};
9-
use tokio::runtime::Handle;
1010

1111
use crate::app::AppState;
1212
use crate::email::Emails;
1313
use crate::middleware::log_request::RequestLogExt;
1414
use crate::middleware::session::SessionExtension;
1515
use crate::models::{NewUser, User};
1616
use crate::schema::users;
17-
use crate::tasks::spawn_blocking;
18-
use crate::util::diesel::{is_read_only_error, Conn};
17+
use crate::util::diesel::is_read_only_error;
1918
use crate::util::errors::{bad_request, server_error, AppResult};
2019
use crate::views::EncodableMe;
2120
use crates_io_github::GithubUser;
@@ -89,76 +88,74 @@ pub async fn authorize(
8988
session: SessionExtension,
9089
req: Parts,
9190
) -> AppResult<Json<EncodableMe>> {
92-
let app_clone = app.clone();
93-
let request_log = req.request_log().clone();
94-
95-
let conn = app.db_write().await?;
96-
spawn_blocking(move || {
97-
let conn: &mut AsyncConnectionWrapper<_> = &mut conn.into();
98-
99-
// Make sure that the state we just got matches the session state that we
100-
// should have issued earlier.
101-
let session_state = session.remove("github_oauth_state").map(CsrfToken::new);
102-
if !session_state.is_some_and(|state| query.state.secret() == state.secret()) {
103-
return Err(bad_request("invalid state parameter"));
104-
}
91+
// Make sure that the state we just got matches the session state that we
92+
// should have issued earlier.
93+
let session_state = session.remove("github_oauth_state").map(CsrfToken::new);
94+
if !session_state.is_some_and(|state| query.state.secret() == state.secret()) {
95+
return Err(bad_request("invalid state parameter"));
96+
}
10597

106-
// Fetch the access token from GitHub using the code we just got
107-
let token = app
108-
.github_oauth
109-
.exchange_code(query.code)
110-
.request(http_client)
111-
.map_err(|err| {
112-
request_log.add("cause", err);
113-
server_error("Error obtaining token")
114-
})?;
98+
// Fetch the access token from GitHub using the code we just got
99+
let token = app
100+
.github_oauth
101+
.exchange_code(query.code)
102+
.request_async(async_http_client)
103+
.await
104+
.map_err(|err| {
105+
req.request_log().add("cause", err);
106+
server_error("Error obtaining token")
107+
})?;
115108

116-
let token = token.access_token();
109+
let token = token.access_token();
117110

118-
// Fetch the user info from GitHub using the access token we just got and create a user record
119-
let ghuser = Handle::current().block_on(app.github.current_user(token))?;
120-
let user = save_user_to_database(&ghuser, token.secret(), &app.emails, conn)?;
111+
// Fetch the user info from GitHub using the access token we just got and create a user record
112+
let ghuser = app.github.current_user(token).await?;
121113

122-
// Log in by setting a cookie and the middleware authentication
123-
session.insert("user_id".to_string(), user.id.to_string());
114+
let mut conn = app.db_write().await?;
115+
let user = save_user_to_database(&ghuser, token.secret(), &app.emails, &mut conn).await?;
124116

125-
Ok(())
126-
})
127-
.await?;
117+
// Log in by setting a cookie and the middleware authentication
118+
session.insert("user_id".to_string(), user.id.to_string());
128119

129-
super::me::me(app_clone, req).await
120+
super::me::me(app, req).await
130121
}
131122

132-
fn save_user_to_database(
123+
async fn save_user_to_database(
133124
user: &GithubUser,
134125
access_token: &str,
135126
emails: &Emails,
136-
conn: &mut impl Conn,
127+
conn: &mut AsyncPgConnection,
137128
) -> AppResult<User> {
138-
use diesel::prelude::*;
139-
140-
NewUser::new(
129+
let new_user = NewUser::new(
141130
user.id,
142131
&user.login,
143132
user.name.as_deref(),
144133
user.avatar_url.as_deref(),
145134
access_token,
146-
)
147-
.create_or_update(user.email.as_deref(), emails, conn)
148-
.or_else(|e| {
149-
// If we're in read only mode, we can't update their details
150-
// just look for an existing user
151-
if is_read_only_error(&e) {
152-
users::table
153-
.filter(users::gh_id.eq(user.id))
154-
.first(conn)
155-
.optional()?
156-
.ok_or(e)
157-
} else {
158-
Err(e)
135+
);
136+
137+
match new_user
138+
.create_or_update(user.email.as_deref(), emails, conn)
139+
.await
140+
{
141+
Ok(user) => Ok(user),
142+
Err(error) if is_read_only_error(&error) => {
143+
// If we're in read only mode, we can't update their details
144+
// just look for an existing user
145+
find_user_by_gh_id(conn, user.id)
146+
.await?
147+
.ok_or_else(|| error.into())
159148
}
160-
})
161-
.map_err(Into::into)
149+
Err(error) => Err(error.into()),
150+
}
151+
}
152+
153+
async fn find_user_by_gh_id(conn: &mut AsyncPgConnection, gh_id: i32) -> QueryResult<Option<User>> {
154+
users::table
155+
.filter(users::gh_id.eq(gh_id))
156+
.first(conn)
157+
.await
158+
.optional()
162159
}
163160

164161
/// Handles the `DELETE /api/private/session` route.
@@ -170,20 +167,24 @@ pub async fn logout(session: SessionExtension) -> Json<bool> {
170167
#[cfg(test)]
171168
mod tests {
172169
use super::*;
173-
use crate::test_util::test_db_connection;
170+
use crates_io_test_db::TestDatabase;
171+
use diesel_async::AsyncConnection;
174172

175-
#[test]
176-
fn gh_user_with_invalid_email_doesnt_fail() {
173+
#[tokio::test]
174+
async fn gh_user_with_invalid_email_doesnt_fail() {
177175
let emails = Emails::new_in_memory();
178-
let (_test_db, conn) = &mut test_db_connection();
176+
177+
let test_db = TestDatabase::new();
178+
let mut conn = AsyncPgConnection::establish(test_db.url()).await.unwrap();
179+
179180
let gh_user = GithubUser {
180181
email: Some("String.Format(\"{0}.{1}@live.com\", FirstName, LastName)".into()),
181182
name: Some("My Name".into()),
182183
login: "github_user".into(),
183184
id: -1,
184185
avatar_url: None,
185186
};
186-
let result = save_user_to_database(&gh_user, "arbitrary_token", &emails, conn);
187+
let result = save_user_to_database(&gh_user, "arbitrary_token", &emails, &mut conn).await;
187188

188189
assert!(
189190
result.is_ok(),

src/models/user.rs

Lines changed: 55 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use chrono::NaiveDateTime;
2-
use diesel_async::AsyncPgConnection;
2+
use diesel_async::scoped_futures::ScopedFutureExt;
3+
use diesel_async::{AsyncConnection, AsyncPgConnection};
34
use secrecy::SecretString;
45

56
use crate::app::App;
@@ -171,66 +172,72 @@ impl<'a> NewUser<'a> {
171172
}
172173

173174
/// Inserts the user into the database, or updates an existing one.
174-
pub fn create_or_update(
175+
pub async fn create_or_update(
175176
&self,
176177
email: Option<&'a str>,
177178
emails: &Emails,
178-
conn: &mut impl Conn,
179+
conn: &mut AsyncPgConnection,
179180
) -> QueryResult<User> {
180181
use diesel::dsl::sql;
181182
use diesel::insert_into;
182183
use diesel::pg::upsert::excluded;
183184
use diesel::sql_types::Integer;
184-
use diesel::RunQueryDsl;
185+
use diesel_async::RunQueryDsl;
185186

186187
conn.transaction(|conn| {
187-
let user: User = insert_into(users::table)
188-
.values(self)
189-
// We need the `WHERE gh_id > 0` condition here because `gh_id` set
190-
// to `-1` indicates that we were unable to find a GitHub ID for
191-
// the associated GitHub login at the time that we backfilled
192-
// GitHub IDs. Therefore, there are multiple records in production
193-
// that have a `gh_id` of `-1` so we need to exclude those when
194-
// considering uniqueness of `gh_id` values. The `> 0` condition isn't
195-
// necessary for most fields in the database to be used as a conflict
196-
// target :)
197-
.on_conflict(sql::<Integer>("(gh_id) WHERE gh_id > 0"))
198-
.do_update()
199-
.set((
200-
users::gh_login.eq(excluded(users::gh_login)),
201-
users::name.eq(excluded(users::name)),
202-
users::gh_avatar.eq(excluded(users::gh_avatar)),
203-
users::gh_access_token.eq(excluded(users::gh_access_token)),
204-
))
205-
.get_result(conn)?;
206-
207-
// To send the user an account verification email
208-
if let Some(user_email) = email {
209-
let new_email = NewEmail {
210-
user_id: user.id,
211-
email: user_email,
212-
};
213-
214-
let token = insert_into(emails::table)
215-
.values(&new_email)
216-
.on_conflict_do_nothing()
217-
.returning(emails::token)
218-
.get_result::<String>(conn)
219-
.optional()?
220-
.map(SecretString::from);
221-
222-
if let Some(token) = token {
223-
// Swallows any error. Some users might insert an invalid email address here.
224-
let email = UserConfirmEmail {
225-
user_name: &user.gh_login,
226-
domain: &emails.domain,
227-
token,
188+
async move {
189+
let user: User = insert_into(users::table)
190+
.values(self)
191+
// We need the `WHERE gh_id > 0` condition here because `gh_id` set
192+
// to `-1` indicates that we were unable to find a GitHub ID for
193+
// the associated GitHub login at the time that we backfilled
194+
// GitHub IDs. Therefore, there are multiple records in production
195+
// that have a `gh_id` of `-1` so we need to exclude those when
196+
// considering uniqueness of `gh_id` values. The `> 0` condition isn't
197+
// necessary for most fields in the database to be used as a conflict
198+
// target :)
199+
.on_conflict(sql::<Integer>("(gh_id) WHERE gh_id > 0"))
200+
.do_update()
201+
.set((
202+
users::gh_login.eq(excluded(users::gh_login)),
203+
users::name.eq(excluded(users::name)),
204+
users::gh_avatar.eq(excluded(users::gh_avatar)),
205+
users::gh_access_token.eq(excluded(users::gh_access_token)),
206+
))
207+
.get_result(conn)
208+
.await?;
209+
210+
// To send the user an account verification email
211+
if let Some(user_email) = email {
212+
let new_email = NewEmail {
213+
user_id: user.id,
214+
email: user_email,
228215
};
229-
let _ = emails.send(user_email, email);
216+
217+
let token = insert_into(emails::table)
218+
.values(&new_email)
219+
.on_conflict_do_nothing()
220+
.returning(emails::token)
221+
.get_result::<String>(conn)
222+
.await
223+
.optional()?
224+
.map(SecretString::from);
225+
226+
if let Some(token) = token {
227+
// Swallows any error. Some users might insert an invalid email address here.
228+
let email = UserConfirmEmail {
229+
user_name: &user.gh_login,
230+
domain: &emails.domain,
231+
token,
232+
};
233+
let _ = emails.async_send(user_email, email).await;
234+
}
230235
}
231-
}
232236

233-
Ok(user)
237+
Ok(user)
238+
}
239+
.scope_boxed()
234240
})
241+
.await
235242
}
236243
}

0 commit comments

Comments
 (0)