1616
1717package org .springframework .security .oauth2 .jwt ;
1818
19+ import com .nimbusds .jose .KeySourceException ;
20+ import com .nimbusds .jose .jwk .JWK ;
21+ import com .nimbusds .jose .jwk .JWKSelector ;
22+ import com .nimbusds .jose .jwk .source .JWKSetCacheRefreshEvaluator ;
23+ import com .nimbusds .jose .jwk .source .JWKSetUnavailableException ;
24+ import com .nimbusds .jose .jwk .source .URLBasedJWKSetSource ;
1925import java .io .IOException ;
2026import java .net .MalformedURLException ;
2127import java .net .URL ;
2632import java .util .Collections ;
2733import java .util .HashSet ;
2834import java .util .LinkedHashMap ;
35+ import java .util .List ;
2936import java .util .Map ;
3037import java .util .Set ;
3138import java .util .function .Consumer ;
3542
3643import com .nimbusds .jose .JOSEException ;
3744import com .nimbusds .jose .JWSAlgorithm ;
38- import com .nimbusds .jose .RemoteKeySourceException ;
3945import com .nimbusds .jose .jwk .JWKSet ;
40- import com .nimbusds .jose .jwk .source .JWKSetCache ;
4146import com .nimbusds .jose .jwk .source .JWKSource ;
42- import com .nimbusds .jose .jwk .source .RemoteJWKSet ;
4347import com .nimbusds .jose .proc .JWSKeySelector ;
4448import com .nimbusds .jose .proc .JWSVerificationKeySelector ;
4549import com .nimbusds .jose .proc .SecurityContext ;
@@ -165,7 +169,7 @@ private Jwt createJwt(String token, JWT parsedJwt) {
165169 .build ();
166170 // @formatter:on
167171 }
168- catch (RemoteKeySourceException ex ) {
172+ catch (JWKSetUnavailableException ex ) {
169173 this .logger .trace ("Failed to retrieve JWK set" , ex );
170174 if (ex .getCause () instanceof ParseException ) {
171175 throw new JwtException (String .format (DECODING_ERROR_MESSAGE_TEMPLATE , "Malformed Jwk set" ), ex );
@@ -377,11 +381,12 @@ JWSKeySelector<SecurityContext> jwsKeySelector(JWKSource<SecurityContext> jwkSou
377381 }
378382
379383 JWKSource <SecurityContext > jwkSource (ResourceRetriever jwkSetRetriever , String jwkSetUri ) {
380- if (this .cache == null ) {
381- return new RemoteJWKSet <>(toURL (jwkSetUri ), jwkSetRetriever );
384+ URLBasedJWKSetSource urlBasedJWKSetSource = new URLBasedJWKSetSource (toURL (jwkSetUri ), jwkSetRetriever );
385+ if (this .cache == null ) {
386+ return new SpringURLBasedJWKSource (urlBasedJWKSetSource );
382387 }
383- JWKSetCache jwkSetCache = new SpringJWKSetCache (jwkSetUri , this .cache );
384- return new RemoteJWKSet <>(toURL ( jwkSetUri ), jwkSetRetriever , jwkSetCache );
388+ SpringJWKSetCache springJWKSetCache = new SpringJWKSetCache (jwkSetUri , this .cache );
389+ return new SpringURLBasedJWKSource <>(urlBasedJWKSetSource , springJWKSetCache );
385390 }
386391
387392 JWTProcessor <SecurityContext > processor () {
@@ -414,7 +419,50 @@ private static URL toURL(String url) {
414419 }
415420 }
416421
417- private static final class SpringJWKSetCache implements JWKSetCache {
422+ private static final class SpringURLBasedJWKSource <C extends SecurityContext > implements JWKSource <C > {
423+
424+ private final URLBasedJWKSetSource urlBasedJWKSetSource ;
425+
426+ private SpringJWKSetCache springJWKSetCache ;
427+
428+ private SpringURLBasedJWKSource (URLBasedJWKSetSource urlBasedJWKSetSource ) {
429+ this .urlBasedJWKSetSource = urlBasedJWKSetSource ;
430+ }
431+
432+ private SpringURLBasedJWKSource (URLBasedJWKSetSource urlBasedJWKSetSource ,
433+ SpringJWKSetCache springJWKSetCache ) {
434+ this .urlBasedJWKSetSource = urlBasedJWKSetSource ;
435+ this .springJWKSetCache = springJWKSetCache ;
436+ }
437+
438+ @ Override
439+ public List <JWK > get (JWKSelector jwkSelector , SecurityContext context ) throws KeySourceException {
440+ if (this .springJWKSetCache != null ) {
441+ synchronized (this ) {
442+ JWKSet jwkSet = this .springJWKSetCache .get ();
443+ if (this .springJWKSetCache .requiresRefresh () || jwkSet == null ) {
444+ jwkSet = fetchJWKSet (context );
445+ this .springJWKSetCache .put (jwkSet );
446+ }
447+ List <JWK > jwks = jwkSelector .select (jwkSet );
448+ if (!jwks .isEmpty ()) {
449+ return jwks ;
450+ }
451+ jwkSet = fetchJWKSet (context );
452+ this .springJWKSetCache .put (jwkSet );
453+ return jwkSelector .select (jwkSet );
454+ }
455+ }
456+ return jwkSelector .select (fetchJWKSet (context ));
457+ }
458+
459+ private JWKSet fetchJWKSet (SecurityContext context ) throws KeySourceException {
460+ return urlBasedJWKSetSource .getJWKSet (JWKSetCacheRefreshEvaluator .noRefresh (),
461+ System .currentTimeMillis (), context );
462+ }
463+ }
464+
465+ private static final class SpringJWKSetCache {
418466
419467 private final String jwkSetUri ;
420468
@@ -433,27 +481,22 @@ private void updateJwkSetFromCache() {
433481 if (cachedJwkSet != null ) {
434482 try {
435483 this .jwkSet = JWKSet .parse (cachedJwkSet );
436- }
437- catch (ParseException ignored ) {
484+ } catch (ParseException ignored ) {
438485 // Ignore invalid cache value
439486 }
440487 }
441488 }
442489
443- // Note: Only called from inside a synchronized block in RemoteJWKSet.
444- @ Override
490+ // Note: Only called from inside a synchronized block in SpringURLBasedJWKSource.
445491 public void put (JWKSet jwkSet ) {
446492 this .jwkSet = jwkSet ;
447493 this .cache .put (this .jwkSetUri , jwkSet .toString (false ));
448494 }
449495
450- @ Override
451496 public JWKSet get () {
452497 return (!requiresRefresh ()) ? this .jwkSet : null ;
453-
454498 }
455499
456- @ Override
457500 public boolean requiresRefresh () {
458501 return this .cache .get (this .jwkSetUri ) == null ;
459502 }
0 commit comments