Skip to content

Commit 1667de9

Browse files
committed
Refactor OIDC client refresh mechanism
Moves client refresh logic into a separate method and uses request-scoped HTTP client instead of storing AppConfig. This simplifies the state struct and improves the refresh mechanism's reliability.
1 parent 4578719 commit 1667de9

File tree

1 file changed

+29
-28
lines changed

1 file changed

+29
-28
lines changed

src/webserver/oidc.rs

Lines changed: 29 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,6 @@ pub struct ClientWithTime {
150150

151151
pub struct OidcState {
152152
pub config: OidcConfig,
153-
app_config: AppConfig,
154153
client: RwLock<ClientWithTime>,
155154
}
156155

@@ -161,45 +160,39 @@ impl OidcState {
161160

162161
Ok(Self {
163162
config: oidc_cfg,
164-
app_config,
165163
client: RwLock::new(ClientWithTime {
166164
client,
167165
last_update: Instant::now(),
168166
}),
169167
})
170168
}
171169

172-
async fn refresh(&self) {
173-
let Ok(http_client) = make_http_client(&self.app_config) else {
174-
log::error!("Failed to create HTTP client");
175-
return;
176-
};
177-
let mut write_guard = self.client.write().await;
178-
match build_oidc_client(&self.config, &http_client).await {
179-
Ok(client) => {
180-
*write_guard = ClientWithTime {
181-
client,
170+
async fn refresh(&self, service_request: &ServiceRequest) {
171+
match build_oidc_client_from_appdata(&self.config, service_request).await {
172+
Ok(http_client) => {
173+
*self.client.write().await = ClientWithTime {
174+
client: http_client,
182175
last_update: Instant::now(),
183-
};
184-
}
185-
Err(e) => {
186-
log::error!("Failed to refresh OIDC client: {e}");
176+
}
187177
}
178+
Err(e) => log::error!("Failed to refresh OIDC client: {e}"),
179+
}
180+
}
181+
182+
/// Refreshes the OIDC client from the provider metadata URL if it has expired.
183+
/// Most providers update their signing keys periodically.
184+
pub async fn refresh_if_expired(&self, service_request: &ServiceRequest) {
185+
if self.client.read().await.last_update.elapsed() > OIDC_CLIENT_REFRESH_INTERVAL {
186+
self.refresh(service_request).await;
188187
}
189188
}
190189

191190
/// Gets a reference to the oidc client, potentially generating a new one if needed
192191
pub async fn get_client(&self) -> RwLockReadGuard<'_, OidcClient> {
193-
{
194-
let client_lock = self.client.read().await;
195-
if client_lock.last_update.elapsed() < OIDC_CLIENT_REFRESH_INTERVAL {
196-
return RwLockReadGuard::map(client_lock, |ClientWithTime { client, .. }| client);
197-
}
198-
}
199-
log::debug!("OIDC client is older than {OIDC_CLIENT_REFRESH_INTERVAL:?}, refreshing...");
200-
self.refresh().await;
201-
let with_time = self.client.read().await;
202-
RwLockReadGuard::map(with_time, |ClientWithTime { client, .. }| client)
192+
RwLockReadGuard::map(
193+
self.client.read().await,
194+
|ClientWithTime { client, .. }| client,
195+
)
203196
}
204197

205198
/// Validate and decode the claims of an OIDC token, without refreshing the client.
@@ -208,8 +201,7 @@ impl OidcState {
208201
id_token: &OidcToken,
209202
state: Option<&OidcLoginState>,
210203
) -> anyhow::Result<OidcClaims> {
211-
// Do not refresh the client on every check
212-
let client = &self.client.read().await.client;
204+
let client = &self.get_client().await;
213205
let verifier = self.config.create_id_token_verifier(client);
214206
let nonce_verifier = |nonce: Option<&Nonce>| check_nonce(nonce, state);
215207
let claims: OidcClaims = id_token
@@ -234,6 +226,14 @@ pub async fn initialize_oidc_state(
234226
)))
235227
}
236228

229+
async fn build_oidc_client_from_appdata(
230+
cfg: &OidcConfig,
231+
req: &ServiceRequest,
232+
) -> anyhow::Result<OidcClient> {
233+
let http_client = get_http_client_from_appdata(req)?;
234+
build_oidc_client(cfg, http_client).await
235+
}
236+
237237
async fn build_oidc_client(
238238
oidc_cfg: &OidcConfig,
239239
http_client: &Client,
@@ -319,6 +319,7 @@ async fn handle_request(
319319
request: ServiceRequest,
320320
) -> actix_web::Result<MiddlewareResponse> {
321321
log::trace!("Started OIDC middleware request handling");
322+
oidc_state.refresh_if_expired(&request).await;
322323
let response = match get_authenticated_user_info(oidc_state, &request).await {
323324
Ok(Some(claims)) => {
324325
if request.path() != SQLPAGE_REDIRECT_URI {

0 commit comments

Comments
 (0)