diff --git a/src/webserver/oidc.rs b/src/webserver/oidc.rs index aaec925e..bf5c2c82 100644 --- a/src/webserver/oidc.rs +++ b/src/webserver/oidc.rs @@ -1,6 +1,12 @@ use std::collections::HashSet; use std::future::ready; -use std::{future::Future, pin::Pin, str::FromStr, sync::Arc}; +use std::{ + future::Future, + pin::Pin, + str::FromStr, + sync::{Arc, RwLock}, + time::Duration, +}; use crate::webserver::http_client::get_http_client_from_appdata; use crate::{app_config::AppConfig, AppState}; @@ -132,7 +138,7 @@ fn get_app_host(config: &AppConfig) -> String { pub struct OidcState { pub config: Arc, - pub client: Arc, + pub client: Arc>, } pub async fn initialize_oidc_state( @@ -149,10 +155,61 @@ pub async fn initialize_oidc_state( discover_provider_metadata(&http_client, oidc_cfg.issuer_url.clone()).await?; let client = make_oidc_client(&oidc_cfg, provider_metadata)?; - Ok(Some(Arc::new(OidcState { - config: oidc_cfg, - client: Arc::new(client), - }))) + let oidc_state = Arc::new(OidcState { + config: oidc_cfg.clone(), + client: Arc::new(RwLock::new(client)), + }); + + // Start background refresh task + let refresh_state = Arc::clone(&oidc_state); + let refresh_config = Arc::clone(app_config); + tokio::spawn(async move { + refresh_oidc_metadata_periodically(refresh_state, refresh_config).await; + }); + + Ok(Some(oidc_state)) +} + +/// Background task that refreshes OIDC provider metadata every 6 hours +async fn refresh_oidc_metadata_periodically( + oidc_state: Arc, + app_config: Arc, +) { + let mut interval = tokio::time::interval(Duration::from_secs(6 * 60 * 60)); // 6 hours + interval.tick().await; // Skip first tick (already initialized) + + loop { + interval.tick().await; + + log::info!("Refreshing OIDC provider metadata"); + + match refresh_oidc_client(&oidc_state, &app_config).await { + Ok(()) => { + log::info!("Successfully refreshed OIDC provider metadata"); + } + Err(e) => { + log::warn!("Failed to refresh OIDC provider metadata: {}", e); + // Continue with existing client + } + } + } +} + +/// Refresh the OIDC client with new provider metadata +async fn refresh_oidc_client( + oidc_state: &Arc, + app_config: &AppConfig, +) -> anyhow::Result<()> { + let http_client = make_http_client(app_config)?; + let provider_metadata = + discover_provider_metadata(&http_client, oidc_state.config.issuer_url.clone()).await?; + let new_client = make_oidc_client(&oidc_state.config, provider_metadata)?; + + // Replace the client atomically + let mut client_guard = oidc_state.client.write().unwrap(); + *client_guard = new_client; + + Ok(()) } pub struct OidcMiddleware { @@ -239,11 +296,9 @@ where log::debug!("Redirecting to OIDC provider"); - let response = build_auth_provider_redirect_response( - &self.oidc_state.client, - &self.oidc_state.config, - &request, - ); + let client = self.oidc_state.client.read().unwrap(); + let response = + build_auth_provider_redirect_response(&*client, &self.oidc_state.config, &request); Box::pin(async move { Ok(request.into_response(response)) }) } @@ -251,17 +306,21 @@ where &self, request: ServiceRequest, ) -> LocalBoxFuture, Error>> { - let oidc_client = Arc::clone(&self.oidc_state.client); - let oidc_config = Arc::clone(&self.oidc_state.config); + let oidc_state = Arc::clone(&self.oidc_state); Box::pin(async move { + let client = oidc_state.client.read().unwrap(); let query_string = request.query_string(); - match process_oidc_callback(&oidc_client, &oidc_config, query_string, &request).await { + match process_oidc_callback(&*client, &oidc_state.config, query_string, &request).await + { Ok(response) => Ok(request.into_response(response)), Err(e) => { log::error!("Failed to process OIDC callback with params {query_string}: {e}"); - let resp = - build_auth_provider_redirect_response(&oidc_client, &oidc_config, &request); + let resp = build_auth_provider_redirect_response( + &*client, + &oidc_state.config, + &request, + ); Ok(request.into_response(resp)) } } @@ -296,9 +355,8 @@ where fn call(&self, request: ServiceRequest) -> Self::Future { log::trace!("Started OIDC middleware request handling"); - let oidc_client = Arc::clone(&self.oidc_state.client); - let oidc_config = Arc::clone(&self.oidc_state.config); - match get_authenticated_user_info(&oidc_client, &oidc_config, &request) { + let client = self.oidc_state.client.read().unwrap(); + match get_authenticated_user_info(&*client, &self.oidc_state.config, &request) { Ok(Some(claims)) => { if request.path() == SQLPAGE_REDIRECT_URI { return handle_authenticated_oidc_callback(request);