@@ -34,8 +34,11 @@ use mas_keystore::{Encrypter, JsonWebKey, JsonWebKeySet, Keystore, PrivateKey};
34
34
use mas_matrix:: { HomeserverConnection , MockHomeserverConnection } ;
35
35
use mas_policy:: { InstantiateError , Policy , PolicyFactory } ;
36
36
use mas_router:: { SimpleRoute , UrlBuilder } ;
37
- use mas_storage:: { BoxClock , BoxRepository , BoxRng , clock:: MockClock } ;
38
- use mas_storage_pg:: { DatabaseError , PgRepository } ;
37
+ use mas_storage:: {
38
+ BoxClock , BoxRepository , BoxRepositoryFactory , BoxRng , RepositoryError , RepositoryFactory ,
39
+ clock:: MockClock ,
40
+ } ;
41
+ use mas_storage_pg:: PgRepositoryFactory ;
39
42
use mas_templates:: { SiteConfigExt , Templates } ;
40
43
use oauth2_types:: { registration:: ClientRegistrationResponse , requests:: AccessTokenResponse } ;
41
44
use rand:: SeedableRng ;
@@ -92,7 +95,7 @@ pub(crate) async fn policy_factory(
92
95
93
96
#[ derive( Clone ) ]
94
97
pub ( crate ) struct TestState {
95
- pub pool : PgPool ,
98
+ pub repository_factory : PgRepositoryFactory ,
96
99
pub templates : Templates ,
97
100
pub key_store : Keystore ,
98
101
pub cookie_manager : CookieManager ,
@@ -209,7 +212,7 @@ impl TestState {
209
212
let limiter = Limiter :: new ( & RateLimitingConfig :: default ( ) ) . unwrap ( ) ;
210
213
211
214
let graphql_state = TestGraphQLState {
212
- pool : pool. clone ( ) ,
215
+ repository_factory : PgRepositoryFactory :: new ( pool. clone ( ) ) . boxed ( ) ,
213
216
policy_factory : Arc :: clone ( & policy_factory) ,
214
217
homeserver_connection : Arc :: clone ( & homeserver_connection) ,
215
218
site_config : site_config. clone ( ) ,
@@ -224,14 +227,14 @@ impl TestState {
224
227
let graphql_schema = graphql:: schema_builder ( ) . data ( state) . finish ( ) ;
225
228
226
229
let activity_tracker = ActivityTracker :: new (
227
- pool. clone ( ) ,
230
+ PgRepositoryFactory :: new ( pool. clone ( ) ) . boxed ( ) ,
228
231
std:: time:: Duration :: from_secs ( 60 ) ,
229
232
& task_tracker,
230
233
shutdown_token. child_token ( ) ,
231
234
) ;
232
235
233
236
Ok ( Self {
234
- pool,
237
+ repository_factory : PgRepositoryFactory :: new ( pool) ,
235
238
templates,
236
239
key_store,
237
240
cookie_manager,
@@ -256,7 +259,7 @@ impl TestState {
256
259
/// Reset the test utils to a fresh state, with the same configuration.
257
260
pub async fn reset ( self ) -> Self {
258
261
let site_config = self . site_config . clone ( ) ;
259
- let pool = self . pool . clone ( ) ;
262
+ let pool = self . repository_factory . pool ( ) ;
260
263
let task_tracker = self . task_tracker . clone ( ) ;
261
264
262
265
// This should trigger the cancellation drop guard
@@ -351,9 +354,8 @@ impl TestState {
351
354
access_token
352
355
}
353
356
354
- pub async fn repository ( & self ) -> Result < BoxRepository , DatabaseError > {
355
- let repo = PgRepository :: from_pool ( & self . pool ) . await ?;
356
- Ok ( repo. boxed ( ) )
357
+ pub async fn repository ( & self ) -> Result < BoxRepository , RepositoryError > {
358
+ self . repository_factory . create ( ) . await
357
359
}
358
360
359
361
/// Returns a new random number generator.
@@ -393,7 +395,7 @@ impl TestState {
393
395
}
394
396
395
397
struct TestGraphQLState {
396
- pool : PgPool ,
398
+ repository_factory : BoxRepositoryFactory ,
397
399
homeserver_connection : Arc < MockHomeserverConnection > ,
398
400
site_config : SiteConfig ,
399
401
policy_factory : Arc < PolicyFactory > ,
@@ -407,11 +409,7 @@ struct TestGraphQLState {
407
409
#[ async_trait:: async_trait]
408
410
impl graphql:: State for TestGraphQLState {
409
411
async fn repository ( & self ) -> Result < BoxRepository , mas_storage:: RepositoryError > {
410
- let repo = PgRepository :: from_pool ( & self . pool )
411
- . await
412
- . map_err ( mas_storage:: RepositoryError :: from_error) ?;
413
-
414
- Ok ( repo. boxed ( ) )
412
+ self . repository_factory . create ( ) . await
415
413
}
416
414
417
415
async fn policy ( & self ) -> Result < Policy , InstantiateError > {
@@ -451,7 +449,7 @@ impl graphql::State for TestGraphQLState {
451
449
452
450
impl FromRef < TestState > for PgPool {
453
451
fn from_ref ( input : & TestState ) -> Self {
454
- input. pool . clone ( )
452
+ input. repository_factory . pool ( )
455
453
}
456
454
}
457
455
@@ -598,14 +596,14 @@ impl FromRequestParts<TestState> for BoxRng {
598
596
}
599
597
600
598
impl FromRequestParts < TestState > for BoxRepository {
601
- type Rejection = ErrorWrapper < mas_storage_pg :: DatabaseError > ;
599
+ type Rejection = ErrorWrapper < RepositoryError > ;
602
600
603
601
async fn from_request_parts (
604
602
_parts : & mut axum:: http:: request:: Parts ,
605
603
state : & TestState ,
606
604
) -> Result < Self , Self :: Rejection > {
607
- let repo = PgRepository :: from_pool ( & state. pool ) . await ?;
608
- Ok ( repo. boxed ( ) )
605
+ let repo = state. repository_factory . create ( ) . await ?;
606
+ Ok ( repo)
609
607
}
610
608
}
611
609
0 commit comments