diff --git a/src/signer.rs b/src/signer.rs index 2d411c65..6ab036d2 100644 --- a/src/signer.rs +++ b/src/signer.rs @@ -133,7 +133,7 @@ impl Signer { F: FnOnce(&str) -> T, { if self.is_expired() { - self.renew()?; + self.renew_if_expired()?; } let signature = self.signature.read(); @@ -175,7 +175,9 @@ impl Signer { )) } - fn renew(&self) -> Result<(), Error> { + fn renew_if_expired(&self) -> Result<(), Error> { + let mut signature = self.signature.write(); + let issued_at = get_time(); #[cfg(feature = "tracing")] @@ -189,12 +191,12 @@ impl Signer { ); } - let mut signature = self.signature.write(); - - *signature = Signature { - key: Self::create_signature(&self.secret, &self.key_id, &self.team_id, issued_at)?, - issued_at, - }; + if issued_at - signature.issued_at >= self.expire_after_s.as_secs() as i64 { + *signature = Signature { + key: Self::create_signature(&self.secret, &self.key_id, &self.team_id, issued_at)?, + issued_at, + }; + } Ok(()) } @@ -248,6 +250,8 @@ fn get_time() -> i64 { #[cfg(test)] mod tests { + use std::{collections::HashSet, sync::Mutex, time::Instant}; + use super::*; const PRIVATE_KEY: &str = "-----BEGIN PRIVATE KEY----- @@ -293,4 +297,50 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ assert_ne!(sig1, sig2); } + + #[test] + fn test_signature_caching_in_multithreads() { + let signer = Signer::new( + PRIVATE_KEY.as_bytes(), + "89AFRD1X22", + "ASDFQWERTY", + Duration::from_secs(3), + ) + .unwrap(); + + let signer = Arc::new(signer); + + let created_sign: Arc>> = Arc::new(Mutex::new(HashSet::new())); + let mut threads = Vec::new(); + let now = Instant::now(); + + for _ in 0..100 { + let created_sign = created_sign.clone(); + let now = now.clone(); + let signer = signer.clone(); + threads.push(std::thread::spawn(move || { + let mut sig = String::new(); + loop { + let mut sig1 = String::new(); + signer.with_signature(|sig| sig1.push_str(sig)).unwrap(); + + if sig1 != sig { + sig = sig1.clone(); + let mut created_sign = created_sign.lock().unwrap(); + created_sign.insert(sig1); + } + + if now.elapsed() > Duration::from_secs(4) { + break; + } + } + })); + } + + for th in threads { + let _ = th.join(); + } + + assert_eq!(created_sign.lock().unwrap().len(), 2); + } }