Skip to content

Commit 8d7be72

Browse files
committed
Don't hold db conns when creating a device on the compat login API
1 parent 345f6f2 commit 8d7be72

File tree

3 files changed

+65
-33
lines changed

3 files changed

+65
-33
lines changed

crates/handlers/src/compat/login.rs

Lines changed: 57 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,12 @@ use mas_axum_utils::record_error;
1414
use mas_data_model::{CompatSession, CompatSsoLoginState, Device, SiteConfig, TokenType, User};
1515
use mas_matrix::HomeserverConnection;
1616
use mas_storage::{
17-
BoxClock, BoxRepository, BoxRng, Clock, RepositoryAccess,
17+
BoxClock, BoxRepository, BoxRepositoryFactory, BoxRng, Clock, RepositoryAccess,
1818
compat::{
1919
CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository,
2020
CompatSsoLoginRepository,
2121
},
22+
queue::{QueueJobRepositoryExt as _, SyncDevicesJob},
2223
user::{UserPasswordRepository, UserRepository},
2324
};
2425
use opentelemetry::{Key, KeyValue, metrics::Counter};
@@ -268,7 +269,7 @@ pub(crate) async fn post(
268269
mut rng: BoxRng,
269270
clock: BoxClock,
270271
State(password_manager): State<PasswordManager>,
271-
mut repo: BoxRepository,
272+
State(repository_factory): State<BoxRepositoryFactory>,
272273
activity_tracker: BoundActivityTracker,
273274
State(homeserver): State<Arc<dyn HomeserverConnection>>,
274275
State(site_config): State<SiteConfig>,
@@ -279,6 +280,7 @@ pub(crate) async fn post(
279280
) -> Result<impl IntoResponse, RouteError> {
280281
let user_agent = user_agent.map(|ua| ua.as_str().to_owned());
281282
let login_type = input.credentials.login_type();
283+
let mut repo = repository_factory.create().await?;
282284
let (mut session, user) = match (password_manager.is_enabled(), input.credentials) {
283285
(
284286
true,
@@ -301,15 +303,17 @@ pub(crate) async fn post(
301303
}
302304
};
303305

306+
// Try getting the localpart out of the MXID
307+
let username = homeserver.localpart(&user).unwrap_or(&user);
308+
304309
user_password_login(
305310
&mut rng,
306311
&clock,
307312
&password_manager,
308313
&limiter,
309314
requester,
310315
&mut repo,
311-
&homeserver,
312-
user,
316+
username,
313317
password,
314318
input.device_id, // TODO check for validity
315319
input.initial_device_display_name,
@@ -322,7 +326,6 @@ pub(crate) async fn post(
322326
&mut rng,
323327
&clock,
324328
&mut repo,
325-
&homeserver,
326329
&token,
327330
input.device_id,
328331
input.initial_device_display_name,
@@ -368,12 +371,53 @@ pub(crate) async fn post(
368371
None
369372
};
370373

374+
// Ideally, we'd keep the lock whilst we actually create the device, but we
375+
// really want to stop holding the transaction while we talk to the
376+
// homeserver.
377+
//
378+
// In practice, this is fine, because:
379+
// - the session exists after we committed the transaction, so a sync job won't
380+
// try to delete it
381+
// - we've acquired a lock on the user before creating the session, meaning
382+
// we've made sure that sync jobs finished before we create the new session
383+
// - we're in the read-committed isolation level, which means the sync will see
384+
// what we've committed and won't try to delete the session once we release
385+
// the lock
371386
repo.save().await?;
372387

373388
activity_tracker
374389
.record_compat_session(&clock, &session)
375390
.await;
376391

392+
// This session is guaranteed to have a device on it, since both methods create a
393+
// device
394+
let Some(device) = &session.device else {
395+
unreachable!()
396+
};
397+
398+
// Now we can create the device on the homeserver, without holding the
399+
// transaction
400+
if let Err(err) = homeserver
401+
.create_device(&user_id, device.as_str(), session.human_name.as_deref())
402+
.await
403+
{
404+
// Something went wrong, let's end this session and schedule a device sync
405+
let mut repo = repository_factory.create().await?;
406+
let session = repo.compat_session().finish(&clock, session).await?;
407+
408+
repo.queue_job()
409+
.schedule_job(
410+
&mut rng,
411+
&clock,
412+
SyncDevicesJob::new_for_id(session.user_id),
413+
)
414+
.await?;
415+
416+
repo.save().await?;
417+
418+
return Err(RouteError::ProvisionDeviceFailed(err));
419+
}
420+
377421
LOGIN_COUNTER.add(
378422
1,
379423
&[
@@ -395,7 +439,6 @@ async fn token_login(
395439
rng: &mut (dyn RngCore + Send),
396440
clock: &dyn Clock,
397441
repo: &mut BoxRepository,
398-
homeserver: &dyn HomeserverConnection,
399442
token: &str,
400443
requested_device_id: Option<String>,
401444
initial_device_display_name: Option<String>,
@@ -461,7 +504,8 @@ async fn token_login(
461504
return Err(RouteError::InvalidLoginToken);
462505
}
463506

464-
// Lock the user sync to make sure we don't get into a race condition
507+
// We're about to create a device, let's explicitly acquire a lock, so that
508+
// any concurrent sync will read after we've committed
465509
repo.user()
466510
.acquire_lock_for_sync(&browser_session.user)
467511
.await?;
@@ -471,20 +515,14 @@ async fn token_login(
471515
} else {
472516
Device::generate(rng)
473517
};
474-
let mxid = homeserver.mxid(&browser_session.user.username);
475-
homeserver
476-
.create_device(
477-
&mxid,
478-
device.as_str(),
479-
initial_device_display_name.as_deref(),
480-
)
481-
.await
482-
.map_err(RouteError::ProvisionDeviceFailed)?;
483518

484519
repo.app_session()
485520
.finish_sessions_to_replace_device(clock, &browser_session.user, &device)
486521
.await?;
487522

523+
// We first create the session in the database, commit the transaction, then
524+
// create it on the homeserver, scheduling a device sync job afterwards to
525+
// make sure we don't end up in an inconsistent state.
488526
let compat_session = repo
489527
.compat_session()
490528
.add(
@@ -512,15 +550,11 @@ async fn user_password_login(
512550
limiter: &Limiter,
513551
requester: RequesterFingerprint,
514552
repo: &mut BoxRepository,
515-
homeserver: &dyn HomeserverConnection,
516-
username: String,
553+
username: &str,
517554
password: String,
518555
requested_device_id: Option<String>,
519556
initial_device_display_name: Option<String>,
520557
) -> Result<(CompatSession, User), RouteError> {
521-
// Try getting the localpart out of the MXID
522-
let username = homeserver.localpart(&username).unwrap_or(&username);
523-
524558
// Find the user
525559
let user = repo
526560
.user()
@@ -566,25 +600,16 @@ async fn user_password_login(
566600
.await?;
567601
}
568602

569-
// Lock the user sync to make sure we don't get into a race condition
603+
// We're about to create a device, let's explicitly acquire a lock, so that
604+
// any concurrent sync will read after we've committed
570605
repo.user().acquire_lock_for_sync(&user).await?;
571606

572-
let mxid = homeserver.mxid(&user.username);
573-
574607
// Now that the user credentials have been verified, start a new compat session
575608
let device = if let Some(requested_device_id) = requested_device_id {
576609
Device::from(requested_device_id)
577610
} else {
578611
Device::generate(&mut rng)
579612
};
580-
homeserver
581-
.create_device(
582-
&mxid,
583-
device.as_str(),
584-
initial_device_display_name.as_deref(),
585-
)
586-
.await
587-
.map_err(RouteError::ProvisionDeviceFailed)?;
588613

589614
repo.app_session()
590615
.finish_sessions_to_replace_device(clock, &user, &device)

crates/handlers/src/lib.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ use mas_keystore::{Encrypter, Keystore};
4242
use mas_matrix::HomeserverConnection;
4343
use mas_policy::Policy;
4444
use mas_router::{Route, UrlBuilder};
45-
use mas_storage::{BoxClock, BoxRepository, BoxRng};
45+
use mas_storage::{BoxClock, BoxRepository, BoxRepositoryFactory, BoxRng};
4646
use mas_templates::{ErrorContext, NotFoundContext, TemplateContext, Templates};
4747
use opentelemetry::metrics::Meter;
4848
use sqlx::PgPool;
@@ -265,6 +265,7 @@ where
265265
Arc<dyn HomeserverConnection>: FromRef<S>,
266266
PasswordManager: FromRef<S>,
267267
Limiter: FromRef<S>,
268+
BoxRepositoryFactory: FromRef<S>,
268269
BoundActivityTracker: FromRequestParts<S>,
269270
RequesterFingerprint: FromRequestParts<S>,
270271
BoxRepository: FromRequestParts<S>,

crates/handlers/src/test_utils.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,12 @@ impl FromRef<TestState> for PgPool {
453453
}
454454
}
455455

456+
impl FromRef<TestState> for BoxRepositoryFactory {
457+
fn from_ref(input: &TestState) -> Self {
458+
input.repository_factory.clone().boxed()
459+
}
460+
}
461+
456462
impl FromRef<TestState> for graphql::Schema {
457463
fn from_ref(input: &TestState) -> Self {
458464
input.graphql_schema.clone()

0 commit comments

Comments
 (0)