Skip to content

Commit 03bad37

Browse files
committed
Introduce a RepositoryFactory
1 parent ad66524 commit 03bad37

File tree

6 files changed

+80
-18
lines changed

6 files changed

+80
-18
lines changed

crates/cli/src/app_state.rs

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ use mas_keystore::{Encrypter, Keystore};
1919
use mas_matrix::HomeserverConnection;
2020
use mas_policy::{Policy, PolicyFactory};
2121
use mas_router::UrlBuilder;
22-
use mas_storage::{BoxClock, BoxRepository, BoxRng, SystemClock};
23-
use mas_storage_pg::PgRepository;
22+
use mas_storage::{BoxClock, BoxRepository, BoxRepositoryFactory, BoxRng, SystemClock, RepositoryFactory};
23+
use mas_storage_pg::PgRepositoryFactory;
2424
use mas_templates::Templates;
2525
use opentelemetry::{KeyValue, metrics::Histogram};
2626
use rand::SeedableRng;
@@ -31,7 +31,7 @@ use crate::telemetry::METER;
3131

3232
#[derive(Clone)]
3333
pub struct AppState {
34-
pub pool: PgPool,
34+
pub repository_factory: PgRepositoryFactory,
3535
pub templates: Templates,
3636
pub key_store: Keystore,
3737
pub cookie_manager: CookieManager,
@@ -53,7 +53,7 @@ pub struct AppState {
5353
impl AppState {
5454
/// Init the metrics for the app state.
5555
pub fn init_metrics(&mut self) {
56-
let pool = self.pool.clone();
56+
let pool = self.repository_factory.pool();
5757
METER
5858
.i64_observable_up_down_counter("db.connections.usage")
5959
.with_description("The number of connections that are currently in `state` described by the state attribute.")
@@ -66,7 +66,7 @@ impl AppState {
6666
})
6767
.build();
6868

69-
let pool = self.pool.clone();
69+
let pool = self.repository_factory.pool();
7070
METER
7171
.i64_observable_up_down_counter("db.connections.max")
7272
.with_description("The maximum number of open connections allowed.")
@@ -88,14 +88,14 @@ impl AppState {
8888

8989
/// Init the metadata cache in the background
9090
pub fn init_metadata_cache(&self) {
91-
let pool = self.pool.clone();
91+
let factory = self.repository_factory.clone();
9292
let metadata_cache = self.metadata_cache.clone();
9393
let http_client = self.http_client.clone();
9494

9595
tokio::spawn(
9696
LogContext::new("metadata-cache-warmup")
9797
.run(async move || {
98-
let conn = match pool.acquire().await {
98+
let mut repo = match factory.create().await {
9999
Ok(conn) => conn,
100100
Err(e) => {
101101
tracing::error!(
@@ -106,8 +106,6 @@ impl AppState {
106106
}
107107
};
108108

109-
let mut repo = PgRepository::from_conn(conn);
110-
111109
if let Err(e) = metadata_cache
112110
.warm_up_and_run(
113111
&http_client,
@@ -127,9 +125,17 @@ impl AppState {
127125
}
128126
}
129127

128+
// XXX(quenting): we only use this for the healthcheck endpoint, checking the db
129+
// should be part of the repository
130130
impl FromRef<AppState> for PgPool {
131131
fn from_ref(input: &AppState) -> Self {
132-
input.pool.clone()
132+
input.repository_factory.pool()
133+
}
134+
}
135+
136+
impl FromRef<AppState> for BoxRepositoryFactory {
137+
fn from_ref(input: &AppState) -> Self {
138+
input.repository_factory.clone().boxed()
133139
}
134140
}
135141

@@ -359,14 +365,14 @@ impl FromRequestParts<AppState> for RequesterFingerprint {
359365
}
360366

361367
impl FromRequestParts<AppState> for BoxRepository {
362-
type Rejection = ErrorWrapper<mas_storage_pg::DatabaseError>;
368+
type Rejection = ErrorWrapper<mas_storage::RepositoryError>;
363369

364370
async fn from_request_parts(
365371
_parts: &mut axum::http::request::Parts,
366372
state: &AppState,
367373
) -> Result<Self, Self::Rejection> {
368374
let start = Instant::now();
369-
let repo = PgRepository::from_pool(&state.pool).await?;
375+
let repo = state.repository_factory.create().await?;
370376

371377
// Measure the time it took to create the connection
372378
let duration = start.elapsed();
@@ -376,6 +382,6 @@ impl FromRequestParts<AppState> for BoxRepository {
376382
histogram.record(duration_ms, &[]);
377383
}
378384

379-
Ok(repo.boxed())
385+
Ok(repo)
380386
}
381387
}

crates/cli/src/commands/server.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ use mas_handlers::{ActivityTracker, CookieManager, Limiter, MetadataCache};
1818
use mas_listener::server::Server;
1919
use mas_router::UrlBuilder;
2020
use mas_storage::SystemClock;
21-
use mas_storage_pg::MIGRATOR;
21+
use mas_storage_pg::{PgRepositoryFactory, MIGRATOR};
2222
use sqlx::migrate::Migrate;
2323
use tracing::{Instrument, info, info_span, warn};
2424

@@ -226,7 +226,7 @@ impl Options {
226226

227227
let state = {
228228
let mut s = AppState {
229-
pool,
229+
repository_factory: PgRepositoryFactory::new(pool),
230230
templates,
231231
key_store,
232232
cookie_manager,

crates/storage-pg/src/lib.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,11 @@ pub(crate) mod repository;
178178
pub(crate) mod tracing;
179179

180180
pub(crate) use self::errors::DatabaseInconsistencyError;
181-
pub use self::{errors::DatabaseError, repository::PgRepository, tracing::ExecuteExt};
181+
pub use self::{
182+
errors::DatabaseError,
183+
repository::{PgRepository, PgRepositoryFactory},
184+
tracing::ExecuteExt,
185+
};
182186

183187
/// Embedded migrations, allowing them to run on startup
184188
pub static MIGRATOR: Migrator = {

crates/storage-pg/src/repository.rs

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@
66

77
use std::ops::{Deref, DerefMut};
88

9+
use async_trait::async_trait;
910
use futures_util::{FutureExt, TryFutureExt, future::BoxFuture};
1011
use mas_storage::{
11-
BoxRepository, MapErr, Repository, RepositoryAccess, RepositoryError, RepositoryTransaction,
12+
BoxRepository, BoxRepositoryFactory, MapErr, Repository, RepositoryAccess, RepositoryError,
13+
RepositoryFactory, RepositoryTransaction,
1214
app_session::AppSessionRepository,
1315
compat::{
1416
CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository,
@@ -57,6 +59,43 @@ use crate::{
5759
},
5860
};
5961

62+
/// An implementation of the [`RepositoryFactory`] trait backed by a PostgreSQL
63+
/// connection pool.
64+
#[derive(Clone)]
65+
pub struct PgRepositoryFactory {
66+
pool: PgPool,
67+
}
68+
69+
impl PgRepositoryFactory {
70+
/// Create a new [`PgRepositoryFactory`] from a PostgreSQL connection pool.
71+
#[must_use]
72+
pub fn new(pool: PgPool) -> Self {
73+
Self { pool }
74+
}
75+
76+
/// Box the factory
77+
#[must_use]
78+
pub fn boxed(self) -> BoxRepositoryFactory {
79+
Box::new(self)
80+
}
81+
82+
/// Get the underlying PostgreSQL connection pool
83+
#[must_use]
84+
pub fn pool(&self) -> PgPool {
85+
self.pool.clone()
86+
}
87+
}
88+
89+
#[async_trait]
90+
impl RepositoryFactory for PgRepositoryFactory {
91+
async fn create(&self) -> Result<BoxRepository, RepositoryError> {
92+
Ok(PgRepository::from_pool(&self.pool)
93+
.await
94+
.map_err(RepositoryError::from_error)?
95+
.boxed())
96+
}
97+
}
98+
6099
/// An implementation of the [`Repository`] trait backed by a PostgreSQL
61100
/// transaction.
62101
pub struct PgRepository<C = Transaction<'static, Postgres>> {

crates/storage/src/lib.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,8 @@ pub use self::{
128128
clock::{Clock, SystemClock},
129129
pagination::{Page, Pagination},
130130
repository::{
131-
BoxRepository, Repository, RepositoryAccess, RepositoryError, RepositoryTransaction,
131+
BoxRepository, BoxRepositoryFactory, Repository, RepositoryAccess, RepositoryError,
132+
RepositoryFactory, RepositoryTransaction,
132133
},
133134
utils::{BoxClock, BoxRng, MapErr},
134135
};

crates/storage/src/repository.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
// SPDX-License-Identifier: AGPL-3.0-only
55
// Please see LICENSE in the repository root for full details.
66

7+
use async_trait::async_trait;
78
use futures_util::future::BoxFuture;
89
use thiserror::Error;
910

@@ -29,6 +30,17 @@ use crate::{
2930
},
3031
};
3132

33+
/// A [`RepositoryFactory`] is a factory that can create a [`BoxRepository`]
34+
// XXX(quenting): this could be generic over the repository type, but it's annoying to make it dyn-safe
35+
#[async_trait]
36+
pub trait RepositoryFactory {
37+
/// Create a new [`BoxRepository`]
38+
async fn create(&self) -> Result<BoxRepository, RepositoryError>;
39+
}
40+
41+
/// A type-erased [`RepositoryFactory`]
42+
pub type BoxRepositoryFactory = Box<dyn RepositoryFactory + Send + Sync + 'static>;
43+
3244
/// A [`Repository`] helps interacting with the underlying storage backend.
3345
pub trait Repository<E>:
3446
RepositoryAccess<Error = E> + RepositoryTransaction<Error = E> + Send

0 commit comments

Comments
 (0)