Skip to content

Commit 345f6f2

Browse files
committed
Use the new RepositoryFactory everywhere
1 parent 626c9be commit 345f6f2

File tree

10 files changed

+64
-75
lines changed

10 files changed

+64
-75
lines changed

crates/cli/src/commands/debug.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use figment::Figment;
1111
use mas_config::{
1212
ConfigurationSection, ConfigurationSectionExt, DatabaseConfig, MatrixConfig, PolicyConfig,
1313
};
14+
use mas_storage_pg::PgRepositoryFactory;
1415
use tracing::{info, info_span};
1516

1617
use crate::util::{
@@ -48,7 +49,8 @@ impl Options {
4849
if with_dynamic_data {
4950
let database_config = DatabaseConfig::extract(figment)?;
5051
let pool = database_pool_from_config(&database_config).await?;
51-
load_policy_factory_dynamic_data(&policy_factory, &pool).await?;
52+
let repository_factory = PgRepositoryFactory::new(pool.clone());
53+
load_policy_factory_dynamic_data(&policy_factory, &repository_factory).await?;
5254
}
5355

5456
let _instance = policy_factory.instantiate().await?;

crates/cli/src/commands/server.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ impl Options {
134134

135135
load_policy_factory_dynamic_data_continuously(
136136
&policy_factory,
137-
&pool,
137+
PgRepositoryFactory::new(pool.clone()).boxed(),
138138
shutdown.soft_shutdown_token(),
139139
shutdown.task_tracker(),
140140
)
@@ -172,7 +172,7 @@ impl Options {
172172

173173
info!("Starting task worker");
174174
mas_tasks::init(
175-
&pool,
175+
PgRepositoryFactory::new(pool.clone()),
176176
&mailer,
177177
homeserver_connection.clone(),
178178
url_builder.clone(),
@@ -193,7 +193,7 @@ impl Options {
193193
// Initialize the activity tracker
194194
// Activity is flushed every minute
195195
let activity_tracker = ActivityTracker::new(
196-
pool.clone(),
196+
PgRepositoryFactory::new(pool.clone()).boxed(),
197197
Duration::from_secs(60),
198198
shutdown.task_tracker(),
199199
shutdown.soft_shutdown_token(),
@@ -215,7 +215,7 @@ impl Options {
215215
limiter.start();
216216

217217
let graphql_schema = mas_handlers::graphql_schema(
218-
&pool,
218+
PgRepositoryFactory::new(pool.clone()).boxed(),
219219
&policy_factory,
220220
homeserver_connection.clone(),
221221
site_config.clone(),

crates/cli/src/commands/worker.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use clap::Parser;
1010
use figment::Figment;
1111
use mas_config::{AppConfig, ConfigurationSection};
1212
use mas_router::UrlBuilder;
13+
use mas_storage_pg::PgRepositoryFactory;
1314
use tracing::{info, info_span};
1415

1516
use crate::{
@@ -63,7 +64,7 @@ impl Options {
6364

6465
info!("Starting task scheduler");
6566
mas_tasks::init(
66-
&pool,
67+
PgRepositoryFactory::new(pool.clone()),
6768
&mailer,
6869
conn,
6970
url_builder,

crates/cli/src/util.rs

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@ use mas_matrix::{HomeserverConnection, ReadOnlyHomeserverConnection};
2020
use mas_matrix_synapse::SynapseConnection;
2121
use mas_policy::PolicyFactory;
2222
use mas_router::UrlBuilder;
23-
use mas_storage::RepositoryAccess;
24-
use mas_storage_pg::PgRepository;
23+
use mas_storage::{BoxRepositoryFactory, RepositoryAccess, RepositoryFactory};
2524
use mas_templates::{SiteConfigExt, Templates};
2625
use sqlx::{
2726
ConnectOptions, Executor, PgConnection, PgPool,
@@ -400,14 +399,13 @@ pub async fn database_connection_from_config_with_options(
400399
// XXX: this could be put somewhere else?
401400
pub async fn load_policy_factory_dynamic_data_continuously(
402401
policy_factory: &Arc<PolicyFactory>,
403-
pool: &PgPool,
402+
repository_factory: BoxRepositoryFactory,
404403
cancellation_token: CancellationToken,
405404
task_tracker: &TaskTracker,
406405
) -> Result<(), anyhow::Error> {
407406
let policy_factory = policy_factory.clone();
408-
let pool = pool.clone();
409407

410-
load_policy_factory_dynamic_data(&policy_factory, &pool).await?;
408+
load_policy_factory_dynamic_data(&policy_factory, &*repository_factory).await?;
411409

412410
task_tracker.spawn(async move {
413411
let mut interval = tokio::time::interval(Duration::from_secs(60));
@@ -420,7 +418,9 @@ pub async fn load_policy_factory_dynamic_data_continuously(
420418
_ = interval.tick() => {}
421419
}
422420

423-
if let Err(err) = load_policy_factory_dynamic_data(&policy_factory, &pool).await {
421+
if let Err(err) =
422+
load_policy_factory_dynamic_data(&policy_factory, &*repository_factory).await
423+
{
424424
tracing::error!(
425425
error = ?err,
426426
"Failed to load policy factory dynamic data"
@@ -438,9 +438,10 @@ pub async fn load_policy_factory_dynamic_data_continuously(
438438
#[tracing::instrument(name = "policy.load_dynamic_data", skip_all)]
439439
pub async fn load_policy_factory_dynamic_data(
440440
policy_factory: &PolicyFactory,
441-
pool: &PgPool,
441+
repository_factory: &(dyn RepositoryFactory + Send + Sync),
442442
) -> Result<(), anyhow::Error> {
443-
let mut repo = PgRepository::from_pool(pool)
443+
let mut repo = repository_factory
444+
.create()
444445
.await
445446
.context("Failed to acquire database connection")?;
446447

crates/handlers/src/activity_tracker/mod.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,7 @@ use std::net::IpAddr;
1111

1212
use chrono::{DateTime, Utc};
1313
use mas_data_model::{BrowserSession, CompatSession, Session};
14-
use mas_storage::Clock;
15-
use sqlx::PgPool;
14+
use mas_storage::{BoxRepositoryFactory, Clock};
1615
use tokio_util::{sync::CancellationToken, task::TaskTracker};
1716
use ulid::Ulid;
1817

@@ -61,12 +60,12 @@ impl ActivityTracker {
6160
/// time, when the cancellation token is cancelled.
6261
#[must_use]
6362
pub fn new(
64-
pool: PgPool,
63+
repository_factory: BoxRepositoryFactory,
6564
flush_interval: std::time::Duration,
6665
task_tracker: &TaskTracker,
6766
cancellation_token: CancellationToken,
6867
) -> Self {
69-
let worker = Worker::new(pool);
68+
let worker = Worker::new(repository_factory);
7069
let (sender, receiver) = tokio::sync::mpsc::channel(MESSAGE_QUEUE_SIZE);
7170
let tracker = ActivityTracker { channel: sender };
7271

crates/handlers/src/activity_tracker/worker.rs

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,13 @@
77
use std::{collections::HashMap, net::IpAddr};
88

99
use chrono::{DateTime, Utc};
10-
use mas_storage::{RepositoryAccess, RepositoryError, user::BrowserSessionRepository};
10+
use mas_storage::{
11+
BoxRepositoryFactory, RepositoryAccess, RepositoryError, user::BrowserSessionRepository,
12+
};
1113
use opentelemetry::{
1214
Key, KeyValue,
1315
metrics::{Counter, Gauge, Histogram},
1416
};
15-
use sqlx::PgPool;
1617
use tokio_util::sync::CancellationToken;
1718
use ulid::Ulid;
1819

@@ -43,15 +44,15 @@ struct ActivityRecord {
4344

4445
/// Handles writing activity records to the database.
4546
pub struct Worker {
46-
pool: PgPool,
47+
repository_factory: BoxRepositoryFactory,
4748
pending_records: HashMap<(SessionKind, Ulid), ActivityRecord>,
4849
pending_records_gauge: Gauge<u64>,
4950
message_counter: Counter<u64>,
5051
flush_time_histogram: Histogram<u64>,
5152
}
5253

5354
impl Worker {
54-
pub(crate) fn new(pool: PgPool) -> Self {
55+
pub(crate) fn new(repository_factory: BoxRepositoryFactory) -> Self {
5556
let message_counter = METER
5657
.u64_counter("mas.activity_tracker.messages")
5758
.with_description("The number of messages received by the activity tracker")
@@ -89,7 +90,7 @@ impl Worker {
8990
pending_records_gauge.record(0, &[]);
9091

9192
Self {
92-
pool,
93+
repository_factory,
9394
pending_records: HashMap::with_capacity(MAX_PENDING_RECORDS),
9495
pending_records_gauge,
9596
message_counter,
@@ -218,11 +219,7 @@ impl Worker {
218219
#[tracing::instrument(name = "activity_tracker.flush", skip(self))]
219220
async fn try_flush(&mut self) -> Result<(), RepositoryError> {
220221
let pending_records = &self.pending_records;
221-
222-
let mut repo = mas_storage_pg::PgRepository::from_pool(&self.pool)
223-
.await
224-
.map_err(RepositoryError::from_error)?
225-
.boxed();
222+
let mut repo = self.repository_factory.create().await?;
226223

227224
let mut browser_sessions = Vec::new();
228225
let mut oauth2_sessions = Vec::new();

crates/handlers/src/graphql/mod.rs

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,12 @@ use mas_data_model::{BrowserSession, Session, SiteConfig, User};
3232
use mas_matrix::HomeserverConnection;
3333
use mas_policy::{InstantiateError, Policy, PolicyFactory};
3434
use mas_router::UrlBuilder;
35-
use mas_storage::{BoxClock, BoxRepository, BoxRng, Clock, RepositoryError, SystemClock};
36-
use mas_storage_pg::PgRepository;
35+
use mas_storage::{
36+
BoxClock, BoxRepository, BoxRepositoryFactory, BoxRng, Clock, RepositoryError, SystemClock,
37+
};
3738
use opentelemetry_semantic_conventions::trace::{GRAPHQL_DOCUMENT, GRAPHQL_OPERATION_NAME};
3839
use rand::{SeedableRng, thread_rng};
3940
use rand_chacha::ChaChaRng;
40-
use sqlx::PgPool;
4141
use state::has_session_ended;
4242
use tracing::{Instrument, info_span};
4343
use ulid::Ulid;
@@ -69,7 +69,7 @@ pub struct ExtraRouterParameters {
6969
}
7070

7171
struct GraphQLState {
72-
pool: PgPool,
72+
repository_factory: BoxRepositoryFactory,
7373
homeserver_connection: Arc<dyn HomeserverConnection>,
7474
policy_factory: Arc<PolicyFactory>,
7575
site_config: SiteConfig,
@@ -81,11 +81,7 @@ struct GraphQLState {
8181
#[async_trait::async_trait]
8282
impl state::State for GraphQLState {
8383
async fn repository(&self) -> Result<BoxRepository, RepositoryError> {
84-
let repo = PgRepository::from_pool(&self.pool)
85-
.await
86-
.map_err(RepositoryError::from_error)?;
87-
88-
Ok(repo.boxed())
84+
self.repository_factory.create().await
8985
}
9086

9187
async fn policy(&self) -> Result<Policy, InstantiateError> {
@@ -128,7 +124,7 @@ impl state::State for GraphQLState {
128124

129125
#[must_use]
130126
pub fn schema(
131-
pool: &PgPool,
127+
repository_factory: BoxRepositoryFactory,
132128
policy_factory: &Arc<PolicyFactory>,
133129
homeserver_connection: impl HomeserverConnection + 'static,
134130
site_config: SiteConfig,
@@ -137,7 +133,7 @@ pub fn schema(
137133
limiter: Limiter,
138134
) -> Schema {
139135
let state = GraphQLState {
140-
pool: pool.clone(),
136+
repository_factory,
141137
policy_factory: Arc::clone(policy_factory),
142138
homeserver_connection: Arc::new(homeserver_connection),
143139
site_config,

crates/handlers/src/test_utils.rs

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,11 @@ use mas_keystore::{Encrypter, JsonWebKey, JsonWebKeySet, Keystore, PrivateKey};
3434
use mas_matrix::{HomeserverConnection, MockHomeserverConnection};
3535
use mas_policy::{InstantiateError, Policy, PolicyFactory};
3636
use mas_router::{SimpleRoute, UrlBuilder};
37-
use mas_storage::{BoxClock, BoxRepository, BoxRng, clock::MockClock};
38-
use mas_storage_pg::{DatabaseError, PgRepository};
37+
use mas_storage::{
38+
BoxClock, BoxRepository, BoxRepositoryFactory, BoxRng, RepositoryError, RepositoryFactory,
39+
clock::MockClock,
40+
};
41+
use mas_storage_pg::PgRepositoryFactory;
3942
use mas_templates::{SiteConfigExt, Templates};
4043
use oauth2_types::{registration::ClientRegistrationResponse, requests::AccessTokenResponse};
4144
use rand::SeedableRng;
@@ -92,7 +95,7 @@ pub(crate) async fn policy_factory(
9295

9396
#[derive(Clone)]
9497
pub(crate) struct TestState {
95-
pub pool: PgPool,
98+
pub repository_factory: PgRepositoryFactory,
9699
pub templates: Templates,
97100
pub key_store: Keystore,
98101
pub cookie_manager: CookieManager,
@@ -209,7 +212,7 @@ impl TestState {
209212
let limiter = Limiter::new(&RateLimitingConfig::default()).unwrap();
210213

211214
let graphql_state = TestGraphQLState {
212-
pool: pool.clone(),
215+
repository_factory: PgRepositoryFactory::new(pool.clone()).boxed(),
213216
policy_factory: Arc::clone(&policy_factory),
214217
homeserver_connection: Arc::clone(&homeserver_connection),
215218
site_config: site_config.clone(),
@@ -224,14 +227,14 @@ impl TestState {
224227
let graphql_schema = graphql::schema_builder().data(state).finish();
225228

226229
let activity_tracker = ActivityTracker::new(
227-
pool.clone(),
230+
PgRepositoryFactory::new(pool.clone()).boxed(),
228231
std::time::Duration::from_secs(60),
229232
&task_tracker,
230233
shutdown_token.child_token(),
231234
);
232235

233236
Ok(Self {
234-
pool,
237+
repository_factory: PgRepositoryFactory::new(pool),
235238
templates,
236239
key_store,
237240
cookie_manager,
@@ -256,7 +259,7 @@ impl TestState {
256259
/// Reset the test utils to a fresh state, with the same configuration.
257260
pub async fn reset(self) -> Self {
258261
let site_config = self.site_config.clone();
259-
let pool = self.pool.clone();
262+
let pool = self.repository_factory.pool();
260263
let task_tracker = self.task_tracker.clone();
261264

262265
// This should trigger the cancellation drop guard
@@ -351,9 +354,8 @@ impl TestState {
351354
access_token
352355
}
353356

354-
pub async fn repository(&self) -> Result<BoxRepository, DatabaseError> {
355-
let repo = PgRepository::from_pool(&self.pool).await?;
356-
Ok(repo.boxed())
357+
pub async fn repository(&self) -> Result<BoxRepository, RepositoryError> {
358+
self.repository_factory.create().await
357359
}
358360

359361
/// Returns a new random number generator.
@@ -393,7 +395,7 @@ impl TestState {
393395
}
394396

395397
struct TestGraphQLState {
396-
pool: PgPool,
398+
repository_factory: BoxRepositoryFactory,
397399
homeserver_connection: Arc<MockHomeserverConnection>,
398400
site_config: SiteConfig,
399401
policy_factory: Arc<PolicyFactory>,
@@ -407,11 +409,7 @@ struct TestGraphQLState {
407409
#[async_trait::async_trait]
408410
impl graphql::State for TestGraphQLState {
409411
async fn repository(&self) -> Result<BoxRepository, mas_storage::RepositoryError> {
410-
let repo = PgRepository::from_pool(&self.pool)
411-
.await
412-
.map_err(mas_storage::RepositoryError::from_error)?;
413-
414-
Ok(repo.boxed())
412+
self.repository_factory.create().await
415413
}
416414

417415
async fn policy(&self) -> Result<Policy, InstantiateError> {
@@ -451,7 +449,7 @@ impl graphql::State for TestGraphQLState {
451449

452450
impl FromRef<TestState> for PgPool {
453451
fn from_ref(input: &TestState) -> Self {
454-
input.pool.clone()
452+
input.repository_factory.pool()
455453
}
456454
}
457455

@@ -598,14 +596,14 @@ impl FromRequestParts<TestState> for BoxRng {
598596
}
599597

600598
impl FromRequestParts<TestState> for BoxRepository {
601-
type Rejection = ErrorWrapper<mas_storage_pg::DatabaseError>;
599+
type Rejection = ErrorWrapper<RepositoryError>;
602600

603601
async fn from_request_parts(
604602
_parts: &mut axum::http::request::Parts,
605603
state: &TestState,
606604
) -> Result<Self, Self::Rejection> {
607-
let repo = PgRepository::from_pool(&state.pool).await?;
608-
Ok(repo.boxed())
605+
let repo = state.repository_factory.create().await?;
606+
Ok(repo)
609607
}
610608
}
611609

0 commit comments

Comments
 (0)