@@ -150,7 +150,6 @@ pub struct ClientWithTime {
150150
151151pub struct OidcState {
152152 pub config : OidcConfig ,
153- app_config : AppConfig ,
154153 client : RwLock < ClientWithTime > ,
155154}
156155
@@ -161,45 +160,39 @@ impl OidcState {
161160
162161 Ok ( Self {
163162 config : oidc_cfg,
164- app_config,
165163 client : RwLock :: new ( ClientWithTime {
166164 client,
167165 last_update : Instant :: now ( ) ,
168166 } ) ,
169167 } )
170168 }
171169
172- async fn refresh ( & self ) {
173- let Ok ( http_client) = make_http_client ( & self . app_config ) else {
174- log:: error!( "Failed to create HTTP client" ) ;
175- return ;
176- } ;
177- let mut write_guard = self . client . write ( ) . await ;
178- match build_oidc_client ( & self . config , & http_client) . await {
179- Ok ( client) => {
180- * write_guard = ClientWithTime {
181- client,
170+ async fn refresh ( & self , service_request : & ServiceRequest ) {
171+ match build_oidc_client_from_appdata ( & self . config , service_request) . await {
172+ Ok ( http_client) => {
173+ * self . client . write ( ) . await = ClientWithTime {
174+ client : http_client,
182175 last_update : Instant :: now ( ) ,
183- } ;
184- }
185- Err ( e) => {
186- log:: error!( "Failed to refresh OIDC client: {e}" ) ;
176+ }
187177 }
178+ Err ( e) => log:: error!( "Failed to refresh OIDC client: {e}" ) ,
179+ }
180+ }
181+
182+ /// Refreshes the OIDC client from the provider metadata URL if it has expired.
183+ /// Most providers update their signing keys periodically.
184+ pub async fn refresh_if_expired ( & self , service_request : & ServiceRequest ) {
185+ if self . client . read ( ) . await . last_update . elapsed ( ) > OIDC_CLIENT_REFRESH_INTERVAL {
186+ self . refresh ( service_request) . await ;
188187 }
189188 }
190189
191190 /// Gets a reference to the oidc client, potentially generating a new one if needed
192191 pub async fn get_client ( & self ) -> RwLockReadGuard < ' _ , OidcClient > {
193- {
194- let client_lock = self . client . read ( ) . await ;
195- if client_lock. last_update . elapsed ( ) < OIDC_CLIENT_REFRESH_INTERVAL {
196- return RwLockReadGuard :: map ( client_lock, |ClientWithTime { client, .. } | client) ;
197- }
198- }
199- log:: debug!( "OIDC client is older than {OIDC_CLIENT_REFRESH_INTERVAL:?}, refreshing..." ) ;
200- self . refresh ( ) . await ;
201- let with_time = self . client . read ( ) . await ;
202- RwLockReadGuard :: map ( with_time, |ClientWithTime { client, .. } | client)
192+ RwLockReadGuard :: map (
193+ self . client . read ( ) . await ,
194+ |ClientWithTime { client, .. } | client,
195+ )
203196 }
204197
205198 /// Validate and decode the claims of an OIDC token, without refreshing the client.
@@ -208,8 +201,7 @@ impl OidcState {
208201 id_token : & OidcToken ,
209202 state : Option < & OidcLoginState > ,
210203 ) -> anyhow:: Result < OidcClaims > {
211- // Do not refresh the client on every check
212- let client = & self . client . read ( ) . await . client ;
204+ let client = & self . get_client ( ) . await ;
213205 let verifier = self . config . create_id_token_verifier ( client) ;
214206 let nonce_verifier = |nonce : Option < & Nonce > | check_nonce ( nonce, state) ;
215207 let claims: OidcClaims = id_token
@@ -234,6 +226,14 @@ pub async fn initialize_oidc_state(
234226 ) ) )
235227}
236228
229+ async fn build_oidc_client_from_appdata (
230+ cfg : & OidcConfig ,
231+ req : & ServiceRequest ,
232+ ) -> anyhow:: Result < OidcClient > {
233+ let http_client = get_http_client_from_appdata ( req) ?;
234+ build_oidc_client ( cfg, http_client) . await
235+ }
236+
237237async fn build_oidc_client (
238238 oidc_cfg : & OidcConfig ,
239239 http_client : & Client ,
@@ -319,6 +319,7 @@ async fn handle_request(
319319 request : ServiceRequest ,
320320) -> actix_web:: Result < MiddlewareResponse > {
321321 log:: trace!( "Started OIDC middleware request handling" ) ;
322+ oidc_state. refresh_if_expired ( & request) . await ;
322323 let response = match get_authenticated_user_info ( oidc_state, & request) . await {
323324 Ok ( Some ( claims) ) => {
324325 if request. path ( ) != SQLPAGE_REDIRECT_URI {
0 commit comments