Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 77 additions & 19 deletions src/webserver/oidc.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
use std::collections::HashSet;
use std::future::ready;
use std::{future::Future, pin::Pin, str::FromStr, sync::Arc};
use std::{
future::Future,
pin::Pin,
str::FromStr,
sync::{Arc, RwLock},
time::Duration,
};

use crate::webserver::http_client::get_http_client_from_appdata;
use crate::{app_config::AppConfig, AppState};
Expand Down Expand Up @@ -132,7 +138,7 @@ fn get_app_host(config: &AppConfig) -> String {

pub struct OidcState {
pub config: Arc<OidcConfig>,
pub client: Arc<OidcClient>,
pub client: Arc<RwLock<OidcClient>>,
}

pub async fn initialize_oidc_state(
Expand All @@ -149,10 +155,61 @@ pub async fn initialize_oidc_state(
discover_provider_metadata(&http_client, oidc_cfg.issuer_url.clone()).await?;
let client = make_oidc_client(&oidc_cfg, provider_metadata)?;

Ok(Some(Arc::new(OidcState {
config: oidc_cfg,
client: Arc::new(client),
})))
let oidc_state = Arc::new(OidcState {
config: oidc_cfg.clone(),
client: Arc::new(RwLock::new(client)),
});

// Start background refresh task
let refresh_state = Arc::clone(&oidc_state);
let refresh_config = Arc::clone(app_config);
tokio::spawn(async move {
refresh_oidc_metadata_periodically(refresh_state, refresh_config).await;
});

Ok(Some(oidc_state))
}

/// Background task that refreshes OIDC provider metadata every 6 hours
async fn refresh_oidc_metadata_periodically(
oidc_state: Arc<OidcState>,
app_config: Arc<AppConfig>,
) {
let mut interval = tokio::time::interval(Duration::from_secs(6 * 60 * 60)); // 6 hours
interval.tick().await; // Skip first tick (already initialized)

loop {
interval.tick().await;

log::info!("Refreshing OIDC provider metadata");

match refresh_oidc_client(&oidc_state, &app_config).await {
Ok(()) => {
log::info!("Successfully refreshed OIDC provider metadata");
}
Err(e) => {
log::warn!("Failed to refresh OIDC provider metadata: {}", e);
// Continue with existing client
}
}
}
}

/// Refresh the OIDC client with new provider metadata
async fn refresh_oidc_client(
oidc_state: &Arc<OidcState>,
app_config: &AppConfig,
) -> anyhow::Result<()> {
let http_client = make_http_client(app_config)?;
let provider_metadata =
discover_provider_metadata(&http_client, oidc_state.config.issuer_url.clone()).await?;
let new_client = make_oidc_client(&oidc_state.config, provider_metadata)?;

// Replace the client atomically
let mut client_guard = oidc_state.client.write().unwrap();
*client_guard = new_client;

Ok(())
}

pub struct OidcMiddleware {
Expand Down Expand Up @@ -239,29 +296,31 @@ where

log::debug!("Redirecting to OIDC provider");

let response = build_auth_provider_redirect_response(
&self.oidc_state.client,
&self.oidc_state.config,
&request,
);
let client = self.oidc_state.client.read().unwrap();
let response =
build_auth_provider_redirect_response(&*client, &self.oidc_state.config, &request);
Box::pin(async move { Ok(request.into_response(response)) })
}

fn handle_oidc_callback(
&self,
request: ServiceRequest,
) -> LocalBoxFuture<Result<ServiceResponse<BoxBody>, Error>> {
let oidc_client = Arc::clone(&self.oidc_state.client);
let oidc_config = Arc::clone(&self.oidc_state.config);
let oidc_state = Arc::clone(&self.oidc_state);

Box::pin(async move {
let client = oidc_state.client.read().unwrap();
let query_string = request.query_string();
match process_oidc_callback(&oidc_client, &oidc_config, query_string, &request).await {
match process_oidc_callback(&*client, &oidc_state.config, query_string, &request).await
{
Ok(response) => Ok(request.into_response(response)),
Err(e) => {
log::error!("Failed to process OIDC callback with params {query_string}: {e}");
let resp =
build_auth_provider_redirect_response(&oidc_client, &oidc_config, &request);
let resp = build_auth_provider_redirect_response(
&*client,
&oidc_state.config,
&request,
);
Ok(request.into_response(resp))
}
}
Expand Down Expand Up @@ -296,9 +355,8 @@ where
fn call(&self, request: ServiceRequest) -> Self::Future {
log::trace!("Started OIDC middleware request handling");

let oidc_client = Arc::clone(&self.oidc_state.client);
let oidc_config = Arc::clone(&self.oidc_state.config);
match get_authenticated_user_info(&oidc_client, &oidc_config, &request) {
let client = self.oidc_state.client.read().unwrap();
match get_authenticated_user_info(&*client, &self.oidc_state.config, &request) {
Ok(Some(claims)) => {
if request.path() == SQLPAGE_REDIRECT_URI {
return handle_authenticated_oidc_callback(request);
Expand Down
Loading