diff --git a/src/webserver/oidc.rs b/src/webserver/oidc.rs index aaec925e..ebf1e28f 100644 --- a/src/webserver/oidc.rs +++ b/src/webserver/oidc.rs @@ -1,6 +1,12 @@ use std::collections::HashSet; use std::future::ready; -use std::{future::Future, pin::Pin, str::FromStr, sync::Arc}; +use std::{ + future::Future, + pin::Pin, + str::FromStr, + sync::Arc, + time::{Duration, Instant}, +}; use crate::webserver::http_client::get_http_client_from_appdata; use crate::{app_config::AppConfig, AppState}; @@ -21,6 +27,7 @@ use openidconnect::{ TokenResponse, }; use serde::{Deserialize, Serialize}; +use tokio::sync::RwLock; use super::http_client::make_http_client; @@ -44,6 +51,71 @@ type OidcToken = openidconnect::IdToken< pub type OidcClaims = openidconnect::IdTokenClaims; +// Cache configuration based on industry best practices +const PROVIDER_METADATA_CACHE_DURATION: Duration = Duration::from_secs(24 * 60 * 60); // 24 hours +const MIN_REFRESH_INTERVAL: Duration = Duration::from_secs(5 * 60); // 5 minutes (rate limiting) + +#[derive(Clone, Debug)] +struct CachedProvider { + client: OidcClient, + metadata: openidconnect::core::CoreProviderMetadata, + cached_at: Instant, + last_refresh_attempt: Instant, +} + +impl CachedProvider { + fn new(client: OidcClient, metadata: openidconnect::core::CoreProviderMetadata) -> Self { + let now = Instant::now(); + Self { + client, + metadata, + cached_at: now, + last_refresh_attempt: now, + } + } + + fn is_stale(&self) -> bool { + self.cached_at.elapsed() > PROVIDER_METADATA_CACHE_DURATION + } + + fn can_refresh(&self) -> bool { + self.last_refresh_attempt.elapsed() > MIN_REFRESH_INTERVAL + } + + fn update(&mut self, client: OidcClient, metadata: openidconnect::core::CoreProviderMetadata) { + self.client = client; + self.metadata = metadata; + self.cached_at = Instant::now(); + } + + fn mark_refresh_attempt(&mut self) { + self.last_refresh_attempt = Instant::now(); + } +} + +/// Given an audience, verify if it is trusted. The `client_id` is always trusted, independently of this function. +#[derive(Clone, Debug)] +pub struct AudienceVerifier(Option>); + +impl AudienceVerifier { + /// JWT audiences (aud claim) are always required to contain the `client_id`, but they can also contain additional audiences. + /// By default we allow any additional audience. + /// The user can restrict the allowed additional audiences by providing a list of trusted audiences. + fn new(additional_trusted_audiences: Option>) -> Self { + AudienceVerifier(additional_trusted_audiences.map(HashSet::from_iter)) + } + + /// Returns a function that given an audience, verifies if it is trusted. + fn as_fn(&self) -> impl Fn(&Audience) -> bool + '_ { + move |aud: &Audience| -> bool { + let Some(trusted_set) = &self.0 else { + return true; + }; + trusted_set.contains(aud.as_str()) + } + } +} + #[derive(Clone, Debug)] pub struct OidcConfig { pub issuer_url: IssuerUrl, @@ -132,7 +204,64 @@ fn get_app_host(config: &AppConfig) -> String { pub struct OidcState { pub config: Arc, - pub client: Arc, + cached_provider: Arc>, +} + +impl OidcState { + /// Get the current OIDC client without attempting refresh + pub async fn get_client(&self) -> OidcClient { + let cache = self.cached_provider.read().await; + if cache.is_stale() { + log::warn!( + "OIDC provider metadata is stale (age: {:?}). Consider using get_client_with_refresh() for automatic refresh.", + cache.age() + ); + } + cache.client.clone() + } + + /// Get the current OIDC client, refreshing if stale and possible + pub async fn get_client_with_refresh(&self, http_client: &awc::Client) -> OidcClient { + // Try to refresh if cache is stale and we haven't tried recently + { + let cache = self.cached_provider.read().await; + if cache.is_stale() && cache.can_refresh() { + // Release read lock before attempting refresh + drop(cache); + if let Err(e) = self.refresh_provider(http_client).await { + log::warn!("Failed to refresh OIDC provider: {}", e); + } + } + } + + self.cached_provider.read().await.client.clone() + } + + /// Refresh provider metadata and client from the OIDC provider + async fn refresh_provider(&self, http_client: &awc::Client) -> anyhow::Result<()> { + let mut cache = self.cached_provider.write().await; + + // Double-check we can refresh (another thread might have just done it) + if !cache.can_refresh() { + return Ok(()); + } + + cache.mark_refresh_attempt(); + + log::debug!( + "Refreshing OIDC provider metadata for {}", + self.config.issuer_url + ); + + let new_metadata = + discover_provider_metadata(http_client, self.config.issuer_url.clone()).await?; + let new_client = make_oidc_client(&self.config, new_metadata.clone())?; + + cache.update(new_client, new_metadata); + + log::debug!("Successfully refreshed OIDC provider"); + Ok(()) + } } pub async fn initialize_oidc_state( @@ -145,14 +274,18 @@ pub async fn initialize_oidc_state( }; let http_client = make_http_client(app_config)?; + + // Initial metadata discovery let provider_metadata = discover_provider_metadata(&http_client, oidc_cfg.issuer_url.clone()).await?; - let client = make_oidc_client(&oidc_cfg, provider_metadata)?; + let client = make_oidc_client(&oidc_cfg, provider_metadata.clone())?; - Ok(Some(Arc::new(OidcState { + let oidc_state = Arc::new(OidcState { config: oidc_cfg, - client: Arc::new(client), - }))) + cached_provider: Arc::new(RwLock::new(CachedProvider::new(client, provider_metadata))), + }); + + Ok(Some(oidc_state)) } pub struct OidcMiddleware { @@ -218,54 +351,78 @@ where oidc_state, } } +} - fn handle_unauthenticated_request( - &self, - request: ServiceRequest, - ) -> LocalBoxFuture, Error>> { - log::debug!("Handling unauthenticated request to {}", request.path()); - if request.path() == SQLPAGE_REDIRECT_URI { - log::debug!("The request is the OIDC callback"); - return self.handle_oidc_callback(request); - } - - if self.oidc_state.config.is_public_path(request.path()) { - log::debug!( - "The request path {} is not in a public path, skipping OIDC authentication", - request.path() - ); - return Box::pin(self.service.call(request)); - } +async fn handle_unauthenticated_request( + oidc_state: Arc, + request: ServiceRequest, + service: S, +) -> Result, Error> +where + S: Service, Error = Error>, +{ + log::debug!("Handling unauthenticated request to {}", request.path()); - log::debug!("Redirecting to OIDC provider"); + if request.path() == SQLPAGE_REDIRECT_URI { + log::debug!("The request is the OIDC callback"); + return handle_oidc_callback(oidc_state, request).await; + } - let response = build_auth_provider_redirect_response( - &self.oidc_state.client, - &self.oidc_state.config, - &request, + if oidc_state.config.is_public_path(request.path()) { + log::debug!( + "The request path {} is public, skipping OIDC authentication", + request.path() ); - Box::pin(async move { Ok(request.into_response(response)) }) + return service.call(request).await; } - fn handle_oidc_callback( - &self, - request: ServiceRequest, - ) -> LocalBoxFuture, Error>> { - let oidc_client = Arc::clone(&self.oidc_state.client); - let oidc_config = Arc::clone(&self.oidc_state.config); + log::debug!("Redirecting to OIDC provider"); + + // Get HTTP client from app data for potential cache refresh + let http_client = match get_http_client_from_appdata(&request) { + Ok(client) => client, + Err(e) => { + log::error!("Failed to get HTTP client from app data: {}", e); + // Fall back to cached client without refresh + let client = oidc_state.get_client().await; + let response = + build_auth_provider_redirect_response(&client, &oidc_state.config, &request); + return Ok(request.into_response(response)); + } + }; - Box::pin(async move { - let query_string = request.query_string(); - match process_oidc_callback(&oidc_client, &oidc_config, query_string, &request).await { - Ok(response) => Ok(request.into_response(response)), - Err(e) => { - log::error!("Failed to process OIDC callback with params {query_string}: {e}"); - let resp = - build_auth_provider_redirect_response(&oidc_client, &oidc_config, &request); - Ok(request.into_response(resp)) - } - } - }) + let client = oidc_state.get_client_with_refresh(http_client).await; + let response = build_auth_provider_redirect_response(&client, &oidc_state.config, &request); + Ok(request.into_response(response)) +} + +async fn handle_oidc_callback( + oidc_state: Arc, + request: ServiceRequest, +) -> Result, Error> { + // Get HTTP client from app data for potential cache refresh + let http_client = match get_http_client_from_appdata(&request) { + Ok(client) => client, + Err(e) => { + log::error!("Failed to get HTTP client from app data: {}", e); + // Fall back to cached client without refresh + let oidc_client = oidc_state.get_client().await; + let resp = + build_auth_provider_redirect_response(&oidc_client, &oidc_state.config, &request); + return Ok(request.into_response(resp)); + } + }; + + let oidc_client = oidc_state.get_client_with_refresh(http_client).await; + let query_string = request.query_string(); + match process_oidc_callback(&oidc_client, &oidc_state.config, query_string, &request).await { + Ok(response) => Ok(request.into_response(response)), + Err(e) => { + log::error!("Failed to process OIDC callback with params {query_string}: {e}"); + let resp = + build_auth_provider_redirect_response(&oidc_client, &oidc_state.config, &request); + Ok(request.into_response(resp)) + } } } @@ -284,7 +441,7 @@ fn handle_authenticated_oidc_callback( impl Service for OidcService where - S: Service, Error = Error>, + S: Service, Error = Error> + Clone, S::Future: 'static, { type Response = ServiceResponse; @@ -296,35 +453,77 @@ where fn call(&self, request: ServiceRequest) -> Self::Future { log::trace!("Started OIDC middleware request handling"); - let oidc_client = Arc::clone(&self.oidc_state.client); - let oidc_config = Arc::clone(&self.oidc_state.config); - match get_authenticated_user_info(&oidc_client, &oidc_config, &request) { - Ok(Some(claims)) => { - if request.path() == SQLPAGE_REDIRECT_URI { - return handle_authenticated_oidc_callback(request); + let oidc_state = Arc::clone(&self.oidc_state); + let service = self.service.clone(); + Box::pin(async move { + // Get HTTP client from app data for potential cache refresh + let http_client = match get_http_client_from_appdata(&request) { + Ok(client) => client, + Err(e) => { + log::error!("Failed to get HTTP client from app data: {}", e); + // Fall back to cached client without refresh + let oidc_client = oidc_state.get_client().await; + match get_authenticated_user_info(&oidc_client, &oidc_state.config, &request) { + Ok(Some(claims)) => { + if request.path() == SQLPAGE_REDIRECT_URI { + return handle_authenticated_oidc_callback(request); + } + log::trace!( + "Storing authenticated user info in request extensions: {claims:?}" + ); + request.extensions_mut().insert(claims); + let future = service.call(request); + return future.await; + } + Ok(None) => { + log::trace!("No authenticated user found"); + return handle_unauthenticated_request(oidc_state, request, service) + .await; + } + Err(e) => { + log::debug!( + "{:?}", + e.context( + "An auth cookie is present but could not be verified. \ + Redirecting to OIDC provider to re-authenticate." + ) + ); + return handle_unauthenticated_request(oidc_state, request, service) + .await; + } + } + } + }; + + let oidc_client = oidc_state.get_client_with_refresh(http_client).await; + match get_authenticated_user_info(&oidc_client, &oidc_state.config, &request) { + Ok(Some(claims)) => { + if request.path() == SQLPAGE_REDIRECT_URI { + return handle_authenticated_oidc_callback(request); + } + log::trace!( + "Storing authenticated user info in request extensions: {claims:?}" + ); + request.extensions_mut().insert(claims); + let future = service.call(request); + let response = future.await?; + Ok(response) + } + Ok(None) => { + log::trace!("No authenticated user found"); + handle_unauthenticated_request(oidc_state, request, service).await + } + Err(e) => { + log::debug!( + "{:?}", + e.context( + "An auth cookie is present but could not be verified. \ + Redirecting to OIDC provider to re-authenticate." + ) + ); + handle_unauthenticated_request(oidc_state, request, service).await } - log::trace!("Storing authenticated user info in request extensions: {claims:?}"); - request.extensions_mut().insert(claims); - } - Ok(None) => { - log::trace!("No authenticated user found"); - return self.handle_unauthenticated_request(request); - } - Err(e) => { - log::debug!( - "{:?}", - e.context( - "An auth cookie is present but could not be verified. \ - Redirecting to OIDC provider to re-authenticate." - ) - ); - return self.handle_unauthenticated_request(request); } - } - let future = self.service.call(request); - Box::pin(async move { - let response = future.await?; - Ok(response) }) } } @@ -705,29 +904,6 @@ fn get_state_from_cookie(request: &ServiceRequest) -> anyhow::Result>); - -impl AudienceVerifier { - /// JWT audiences (aud claim) are always required to contain the `client_id`, but they can also contain additional audiences. - /// By default we allow any additional audience. - /// The user can restrict the allowed additional audiences by providing a list of trusted audiences. - fn new(additional_trusted_audiences: Option>) -> Self { - AudienceVerifier(additional_trusted_audiences.map(HashSet::from_iter)) - } - - /// Returns a function that given an audience, verifies if it is trusted. - fn as_fn(&self) -> impl Fn(&Audience) -> bool + '_ { - move |aud: &Audience| -> bool { - let Some(trusted_set) = &self.0 else { - return true; - }; - trusted_set.contains(aud.as_str()) - } - } -} - /// Validate that a redirect URL is safe to use (prevents open redirect attacks) fn validate_redirect_url(url: String) -> String { if url.starts_with('/') && !url.starts_with("//") {