4
4
// SPDX-License-Identifier: AGPL-3.0-only
5
5
// Please see LICENSE in the repository root for full details.
6
6
7
- use std:: { convert:: Infallible , net:: IpAddr , sync:: Arc , time :: Instant } ;
7
+ use std:: { convert:: Infallible , net:: IpAddr , sync:: Arc } ;
8
8
9
9
use axum:: extract:: { FromRef , FromRequestParts } ;
10
10
use ipnetwork:: IpNetwork ;
@@ -19,10 +19,12 @@ use mas_keystore::{Encrypter, Keystore};
19
19
use mas_matrix:: HomeserverConnection ;
20
20
use mas_policy:: { Policy , PolicyFactory } ;
21
21
use mas_router:: UrlBuilder ;
22
- use mas_storage:: { BoxClock , BoxRepository , BoxRng , SystemClock } ;
23
- use mas_storage_pg:: PgRepository ;
22
+ use mas_storage:: {
23
+ BoxClock , BoxRepository , BoxRepositoryFactory , BoxRng , RepositoryFactory , SystemClock ,
24
+ } ;
25
+ use mas_storage_pg:: PgRepositoryFactory ;
24
26
use mas_templates:: Templates ;
25
- use opentelemetry:: { KeyValue , metrics :: Histogram } ;
27
+ use opentelemetry:: KeyValue ;
26
28
use rand:: SeedableRng ;
27
29
use sqlx:: PgPool ;
28
30
use tracing:: Instrument ;
@@ -31,7 +33,7 @@ use crate::telemetry::METER;
31
33
32
34
#[ derive( Clone ) ]
33
35
pub struct AppState {
34
- pub pool : PgPool ,
36
+ pub repository_factory : PgRepositoryFactory ,
35
37
pub templates : Templates ,
36
38
pub key_store : Keystore ,
37
39
pub cookie_manager : CookieManager ,
@@ -47,13 +49,12 @@ pub struct AppState {
47
49
pub activity_tracker : ActivityTracker ,
48
50
pub trusted_proxies : Vec < IpNetwork > ,
49
51
pub limiter : Limiter ,
50
- pub conn_acquisition_histogram : Option < Histogram < u64 > > ,
51
52
}
52
53
53
54
impl AppState {
54
55
/// Init the metrics for the app state.
55
56
pub fn init_metrics ( & mut self ) {
56
- let pool = self . pool . clone ( ) ;
57
+ let pool = self . repository_factory . pool ( ) ;
57
58
METER
58
59
. i64_observable_up_down_counter ( "db.connections.usage" )
59
60
. with_description ( "The number of connections that are currently in `state` described by the state attribute." )
@@ -66,7 +67,7 @@ impl AppState {
66
67
} )
67
68
. build ( ) ;
68
69
69
- let pool = self . pool . clone ( ) ;
70
+ let pool = self . repository_factory . pool ( ) ;
70
71
METER
71
72
. i64_observable_up_down_counter ( "db.connections.max" )
72
73
. with_description ( "The maximum number of open connections allowed." )
@@ -76,26 +77,18 @@ impl AppState {
76
77
instrument. observe ( i64:: from ( max_conn) , & [ ] ) ;
77
78
} )
78
79
. build ( ) ;
79
-
80
- // Track the connection acquisition time
81
- let histogram = METER
82
- . u64_histogram ( "db.client.connections.create_time" )
83
- . with_description ( "The time it took to create a new connection." )
84
- . with_unit ( "ms" )
85
- . build ( ) ;
86
- self . conn_acquisition_histogram = Some ( histogram) ;
87
80
}
88
81
89
82
/// Init the metadata cache in the background
90
83
pub fn init_metadata_cache ( & self ) {
91
- let pool = self . pool . clone ( ) ;
84
+ let factory = self . repository_factory . clone ( ) ;
92
85
let metadata_cache = self . metadata_cache . clone ( ) ;
93
86
let http_client = self . http_client . clone ( ) ;
94
87
95
88
tokio:: spawn (
96
89
LogContext :: new ( "metadata-cache-warmup" )
97
90
. run ( async move || {
98
- let conn = match pool . acquire ( ) . await {
91
+ let mut repo = match factory . create ( ) . await {
99
92
Ok ( conn) => conn,
100
93
Err ( e) => {
101
94
tracing:: error!(
@@ -106,8 +99,6 @@ impl AppState {
106
99
}
107
100
} ;
108
101
109
- let mut repo = PgRepository :: from_conn ( conn) ;
110
-
111
102
if let Err ( e) = metadata_cache
112
103
. warm_up_and_run (
113
104
& http_client,
@@ -127,9 +118,17 @@ impl AppState {
127
118
}
128
119
}
129
120
121
+ // XXX(quenting): we only use this for the healthcheck endpoint, checking the db
122
+ // should be part of the repository
130
123
impl FromRef < AppState > for PgPool {
131
124
fn from_ref ( input : & AppState ) -> Self {
132
- input. pool . clone ( )
125
+ input. repository_factory . pool ( )
126
+ }
127
+ }
128
+
129
+ impl FromRef < AppState > for BoxRepositoryFactory {
130
+ fn from_ref ( input : & AppState ) -> Self {
131
+ input. repository_factory . clone ( ) . boxed ( )
133
132
}
134
133
}
135
134
@@ -359,23 +358,13 @@ impl FromRequestParts<AppState> for RequesterFingerprint {
359
358
}
360
359
361
360
impl FromRequestParts < AppState > for BoxRepository {
362
- type Rejection = ErrorWrapper < mas_storage_pg :: DatabaseError > ;
361
+ type Rejection = ErrorWrapper < mas_storage :: RepositoryError > ;
363
362
364
363
async fn from_request_parts (
365
364
_parts : & mut axum:: http:: request:: Parts ,
366
365
state : & AppState ,
367
366
) -> Result < Self , Self :: Rejection > {
368
- let start = Instant :: now ( ) ;
369
- let repo = PgRepository :: from_pool ( & state. pool ) . await ?;
370
-
371
- // Measure the time it took to create the connection
372
- let duration = start. elapsed ( ) ;
373
- let duration_ms = duration. as_millis ( ) . try_into ( ) . unwrap_or ( u64:: MAX ) ;
374
-
375
- if let Some ( histogram) = & state. conn_acquisition_histogram {
376
- histogram. record ( duration_ms, & [ ] ) ;
377
- }
378
-
379
- Ok ( repo. boxed ( ) )
367
+ let repo = state. repository_factory . create ( ) . await ?;
368
+ Ok ( repo)
380
369
}
381
370
}
0 commit comments