diff --git a/crates/crates_io_trustpub/src/keystore/impl.rs b/crates/crates_io_trustpub/src/keystore/impl.rs index 8335ddd4d55..504909b4289 100644 --- a/crates/crates_io_trustpub/src/keystore/impl.rs +++ b/crates/crates_io_trustpub/src/keystore/impl.rs @@ -2,6 +2,7 @@ use super::OidcKeyStore; use super::load_jwks::load_jwks; use async_trait::async_trait; use jsonwebtoken::DecodingKey; +use jsonwebtoken::jwk::JwkSet; use std::collections::HashMap; use std::time::{Duration, Instant}; use tokio::sync::RwLock; @@ -24,6 +25,37 @@ struct Cache { last_update: Option, } +impl Cache { + /// Returns true if the cache was updated within the minimum refresh interval. + fn recently_updated(&self) -> bool { + const MIN_AGE_BEFORE_REFRESH: Duration = Duration::from_secs(60); + + self.last_update + .is_some_and(|last_update| last_update.elapsed() < MIN_AGE_BEFORE_REFRESH) + } + + /// Updates the key cache with a new JWK Set, replacing all existing keys. + /// + /// This method clears the current cache and populates it with decoding keys + /// from the provided JWK Set. Keys without a key ID are skipped with a warning. + /// The cache's last update timestamp is set to the current time. + fn update(&mut self, jwks: &JwkSet) -> anyhow::Result<()> { + self.keys.clear(); + for key in &jwks.keys { + if let Some(key_id) = &key.common.key_id { + let decoding_key = DecodingKey::from_jwk(key)?; + self.keys.insert(key_id.clone(), decoding_key); + } else { + warn!("OIDC key without a key ID found, skipping."); + } + } + + self.last_update = Some(Instant::now()); + + Ok(()) + } +} + impl RealOidcKeyStore { /// Creates a new instance of [`RealOidcKeyStore`]. pub fn new(issuer_uri: String) -> Self { @@ -44,8 +76,6 @@ impl RealOidcKeyStore { #[async_trait] impl OidcKeyStore for RealOidcKeyStore { async fn get_oidc_key(&self, key_id: &str) -> anyhow::Result> { - const MIN_AGE_BEFORE_REFRESH: Duration = Duration::from_secs(60); - // First, try to get the key with just a read lock. let cache = self.cache.read().await; if let Some(key) = cache.keys.get(key_id) { @@ -56,10 +86,7 @@ impl OidcKeyStore for RealOidcKeyStore { drop(cache); let mut cache = self.cache.write().await; - if cache - .last_update - .is_some_and(|last_update| last_update.elapsed() < MIN_AGE_BEFORE_REFRESH) - { + if cache.recently_updated() { // If we're in a cooldown from a previous refresh, return // whatever is in the cache, which will probably be None // given the previous check under the read lock. @@ -68,18 +95,7 @@ impl OidcKeyStore for RealOidcKeyStore { // Load the keys from the OIDC provider. let jwks = load_jwks(&self.client, &self.issuer_uri).await?; - - cache.keys.clear(); - for key in jwks.keys { - if let Some(key_id) = &key.common.key_id { - let decoding_key = DecodingKey::from_jwk(&key)?; - cache.keys.insert(key_id.clone(), decoding_key); - } else { - warn!("OIDC key without a key ID found, skipping."); - } - } - - cache.last_update = Some(Instant::now()); + cache.update(&jwks)?; Ok(cache.keys.get(key_id).cloned()) }