Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 34 additions & 18 deletions crates/crates_io_trustpub/src/keystore/impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use super::OidcKeyStore;
use super::load_jwks::load_jwks;
use async_trait::async_trait;
use jsonwebtoken::DecodingKey;
use jsonwebtoken::jwk::JwkSet;
use std::collections::HashMap;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
Expand All @@ -24,6 +25,37 @@ struct Cache {
last_update: Option<Instant>,
}

impl Cache {
/// Returns true if the cache was updated within the minimum refresh interval.
fn recently_updated(&self) -> bool {
const MIN_AGE_BEFORE_REFRESH: Duration = Duration::from_secs(60);

self.last_update
.is_some_and(|last_update| last_update.elapsed() < MIN_AGE_BEFORE_REFRESH)
}

/// Updates the key cache with a new JWK Set, replacing all existing keys.
///
/// This method clears the current cache and populates it with decoding keys
/// from the provided JWK Set. Keys without a key ID are skipped with a warning.
/// The cache's last update timestamp is set to the current time.
fn update(&mut self, jwks: &JwkSet) -> anyhow::Result<()> {
self.keys.clear();
for key in &jwks.keys {
if let Some(key_id) = &key.common.key_id {
let decoding_key = DecodingKey::from_jwk(key)?;
self.keys.insert(key_id.clone(), decoding_key);
} else {
warn!("OIDC key without a key ID found, skipping.");
}
}

self.last_update = Some(Instant::now());

Ok(())
}
}

impl RealOidcKeyStore {
/// Creates a new instance of [`RealOidcKeyStore`].
pub fn new(issuer_uri: String) -> Self {
Expand All @@ -44,8 +76,6 @@ impl RealOidcKeyStore {
#[async_trait]
impl OidcKeyStore for RealOidcKeyStore {
async fn get_oidc_key(&self, key_id: &str) -> anyhow::Result<Option<DecodingKey>> {
const MIN_AGE_BEFORE_REFRESH: Duration = Duration::from_secs(60);

// First, try to get the key with just a read lock.
let cache = self.cache.read().await;
if let Some(key) = cache.keys.get(key_id) {
Expand All @@ -56,10 +86,7 @@ impl OidcKeyStore for RealOidcKeyStore {
drop(cache);

let mut cache = self.cache.write().await;
if cache
.last_update
.is_some_and(|last_update| last_update.elapsed() < MIN_AGE_BEFORE_REFRESH)
{
if cache.recently_updated() {
// If we're in a cooldown from a previous refresh, return
// whatever is in the cache, which will probably be None
// given the previous check under the read lock.
Expand All @@ -68,18 +95,7 @@ impl OidcKeyStore for RealOidcKeyStore {

// Load the keys from the OIDC provider.
let jwks = load_jwks(&self.client, &self.issuer_uri).await?;

cache.keys.clear();
for key in jwks.keys {
if let Some(key_id) = &key.common.key_id {
let decoding_key = DecodingKey::from_jwk(&key)?;
cache.keys.insert(key_id.clone(), decoding_key);
} else {
warn!("OIDC key without a key ID found, skipping.");
}
}

cache.last_update = Some(Instant::now());
cache.update(&jwks)?;

Ok(cache.keys.get(key_id).cloned())
}
Expand Down