@@ -19,8 +19,8 @@ use mas_keystore::{Encrypter, Keystore};
19
19
use mas_matrix:: HomeserverConnection ;
20
20
use mas_policy:: { Policy , PolicyFactory } ;
21
21
use mas_router:: UrlBuilder ;
22
- use mas_storage:: { BoxClock , BoxRepository , BoxRng , SystemClock } ;
23
- use mas_storage_pg:: PgRepository ;
22
+ use mas_storage:: { BoxClock , BoxRepository , BoxRepositoryFactory , BoxRng , SystemClock , RepositoryFactory } ;
23
+ use mas_storage_pg:: PgRepositoryFactory ;
24
24
use mas_templates:: Templates ;
25
25
use opentelemetry:: { KeyValue , metrics:: Histogram } ;
26
26
use rand:: SeedableRng ;
@@ -31,7 +31,7 @@ use crate::telemetry::METER;
31
31
32
32
#[ derive( Clone ) ]
33
33
pub struct AppState {
34
- pub pool : PgPool ,
34
+ pub repository_factory : PgRepositoryFactory ,
35
35
pub templates : Templates ,
36
36
pub key_store : Keystore ,
37
37
pub cookie_manager : CookieManager ,
@@ -53,7 +53,7 @@ pub struct AppState {
53
53
impl AppState {
54
54
/// Init the metrics for the app state.
55
55
pub fn init_metrics ( & mut self ) {
56
- let pool = self . pool . clone ( ) ;
56
+ let pool = self . repository_factory . pool ( ) ;
57
57
METER
58
58
. i64_observable_up_down_counter ( "db.connections.usage" )
59
59
. with_description ( "The number of connections that are currently in `state` described by the state attribute." )
@@ -66,7 +66,7 @@ impl AppState {
66
66
} )
67
67
. build ( ) ;
68
68
69
- let pool = self . pool . clone ( ) ;
69
+ let pool = self . repository_factory . pool ( ) ;
70
70
METER
71
71
. i64_observable_up_down_counter ( "db.connections.max" )
72
72
. with_description ( "The maximum number of open connections allowed." )
@@ -88,14 +88,14 @@ impl AppState {
88
88
89
89
/// Init the metadata cache in the background
90
90
pub fn init_metadata_cache ( & self ) {
91
- let pool = self . pool . clone ( ) ;
91
+ let factory = self . repository_factory . clone ( ) ;
92
92
let metadata_cache = self . metadata_cache . clone ( ) ;
93
93
let http_client = self . http_client . clone ( ) ;
94
94
95
95
tokio:: spawn (
96
96
LogContext :: new ( "metadata-cache-warmup" )
97
97
. run ( async move || {
98
- let conn = match pool . acquire ( ) . await {
98
+ let mut repo = match factory . create ( ) . await {
99
99
Ok ( conn) => conn,
100
100
Err ( e) => {
101
101
tracing:: error!(
@@ -106,8 +106,6 @@ impl AppState {
106
106
}
107
107
} ;
108
108
109
- let mut repo = PgRepository :: from_conn ( conn) ;
110
-
111
109
if let Err ( e) = metadata_cache
112
110
. warm_up_and_run (
113
111
& http_client,
@@ -127,9 +125,17 @@ impl AppState {
127
125
}
128
126
}
129
127
128
+ // XXX(quenting): we only use this for the healthcheck endpoint, checking the db
129
+ // should be part of the repository
130
130
impl FromRef < AppState > for PgPool {
131
131
fn from_ref ( input : & AppState ) -> Self {
132
- input. pool . clone ( )
132
+ input. repository_factory . pool ( )
133
+ }
134
+ }
135
+
136
+ impl FromRef < AppState > for BoxRepositoryFactory {
137
+ fn from_ref ( input : & AppState ) -> Self {
138
+ input. repository_factory . clone ( ) . boxed ( )
133
139
}
134
140
}
135
141
@@ -359,14 +365,14 @@ impl FromRequestParts<AppState> for RequesterFingerprint {
359
365
}
360
366
361
367
impl FromRequestParts < AppState > for BoxRepository {
362
- type Rejection = ErrorWrapper < mas_storage_pg :: DatabaseError > ;
368
+ type Rejection = ErrorWrapper < mas_storage :: RepositoryError > ;
363
369
364
370
async fn from_request_parts (
365
371
_parts : & mut axum:: http:: request:: Parts ,
366
372
state : & AppState ,
367
373
) -> Result < Self , Self :: Rejection > {
368
374
let start = Instant :: now ( ) ;
369
- let repo = PgRepository :: from_pool ( & state. pool ) . await ?;
375
+ let repo = state. repository_factory . create ( ) . await ?;
370
376
371
377
// Measure the time it took to create the connection
372
378
let duration = start. elapsed ( ) ;
@@ -376,6 +382,6 @@ impl FromRequestParts<AppState> for BoxRepository {
376
382
histogram. record ( duration_ms, & [ ] ) ;
377
383
}
378
384
379
- Ok ( repo. boxed ( ) )
385
+ Ok ( repo)
380
386
}
381
387
}
0 commit comments