11/*
2- * Copyright 2002-2024 the original author or authors.
2+ * Copyright 2002-2025 the original author or authors.
33 *
44 * Licensed under the Apache License, Version 2.0 (the "License");
55 * you may not use this file except in compliance with the License.
2020import com .nimbusds .jose .jwk .JWK ;
2121import com .nimbusds .jose .jwk .JWKMatcher ;
2222import com .nimbusds .jose .jwk .JWKSelector ;
23- import com .nimbusds .jose .jwk .source .JWKSetCacheRefreshEvaluator ;
24- import com .nimbusds .jose .jwk .source .URLBasedJWKSetSource ;
23+ import com .nimbusds .jose .jwk .source .JWKSetParseException ;
24+ import com .nimbusds .jose .jwk .source .JWKSetRetrievalException ;
2525import java .io .IOException ;
2626import java .net .MalformedURLException ;
2727import java .net .URL ;
3535import java .util .List ;
3636import java .util .Map ;
3737import java .util .Set ;
38+ import java .util .concurrent .locks .ReentrantLock ;
3839import java .util .function .Consumer ;
3940import java .util .function .Function ;
4041
4849import com .nimbusds .jose .proc .JWSVerificationKeySelector ;
4950import com .nimbusds .jose .proc .SecurityContext ;
5051import com .nimbusds .jose .proc .SingleKeyJWSKeySelector ;
51- import com .nimbusds .jose .util .Resource ;
52- import com .nimbusds .jose .util .ResourceRetriever ;
5352import com .nimbusds .jwt .JWT ;
5453import com .nimbusds .jwt .JWTClaimsSet ;
5554import com .nimbusds .jwt .JWTParser ;
6160import org .apache .commons .logging .LogFactory ;
6261
6362import org .springframework .cache .Cache ;
63+ import org .springframework .cache .support .NoOpCache ;
6464import org .springframework .core .convert .converter .Converter ;
6565import org .springframework .http .HttpHeaders ;
6666import org .springframework .http .HttpMethod ;
@@ -278,7 +278,7 @@ public static final class JwkSetUriJwtDecoderBuilder {
278278
279279 private RestOperations restOperations = new RestTemplate ();
280280
281- private Cache cache ;
281+ private Cache cache = new NoOpCache ( "default" ) ;
282282
283283 private Consumer <ConfigurableJWTProcessor <SecurityContext >> jwtProcessorCustomizer ;
284284
@@ -381,19 +381,13 @@ JWSKeySelector<SecurityContext> jwsKeySelector(JWKSource<SecurityContext> jwkSou
381381 return new JWSVerificationKeySelector <>(jwsAlgorithms , jwkSource );
382382 }
383383
384- JWKSource <SecurityContext > jwkSource (ResourceRetriever jwkSetRetriever , String jwkSetUri ) {
385- URLBasedJWKSetSource urlBasedJWKSetSource = new URLBasedJWKSetSource (toURL (jwkSetUri ), jwkSetRetriever );
386- if (this .cache == null ) {
387- return new SpringURLBasedJWKSource (urlBasedJWKSetSource );
388- }
389- SpringJWKSetCache jwkSetCache = new SpringJWKSetCache (jwkSetUri , this .cache );
390- return new SpringURLBasedJWKSource <>(urlBasedJWKSetSource , jwkSetCache );
384+ JWKSource <SecurityContext > jwkSource () {
385+ String jwkSetUri = this .jwkSetUri .apply (this .restOperations );
386+ return new SpringJWKSource <>(this .restOperations , this .cache , toURL (jwkSetUri ), jwkSetUri );
391387 }
392388
393389 JWTProcessor <SecurityContext > processor () {
394- ResourceRetriever jwkSetRetriever = new RestOperationsResourceRetriever (this .restOperations );
395- String jwkSetUri = this .jwkSetUri .apply (this .restOperations );
396- JWKSource <SecurityContext > jwkSource = jwkSource (jwkSetRetriever , jwkSetUri );
390+ JWKSource <SecurityContext > jwkSource = jwkSource ();
397391 ConfigurableJWTProcessor <SecurityContext > jwtProcessor = new DefaultJWTProcessor <>();
398392 jwtProcessor .setJWSKeySelector (jwsKeySelector (jwkSource ));
399393 // Spring Security validates the claim set independent from Nimbus
@@ -420,153 +414,130 @@ private static URL toURL(String url) {
420414 }
421415 }
422416
423- private static final class SpringURLBasedJWKSource <C extends SecurityContext > implements JWKSource <C > {
417+ private static final class SpringJWKSource <C extends SecurityContext > implements JWKSource <C > {
424418
425- private final URLBasedJWKSetSource urlBasedJWKSetSource ;
419+ private static final MediaType APPLICATION_JWK_SET_JSON = new MediaType ( "application" , "jwk-set+json" ) ;
426420
427- private final SpringJWKSetCache jwkSetCache ;
421+ private final ReentrantLock reentrantLock = new ReentrantLock () ;
428422
429- private SpringURLBasedJWKSource (URLBasedJWKSetSource urlBasedJWKSetSource ) {
430- this (urlBasedJWKSetSource , null );
431- }
423+ private final RestOperations restOperations ;
424+
425+ private final Cache cache ;
426+
427+ private final URL url ;
428+
429+ private final String jwkSetUri ;
432430
433- private SpringURLBasedJWKSource (URLBasedJWKSetSource urlBasedJWKSetSource , SpringJWKSetCache jwkSetCache ) {
434- this .urlBasedJWKSetSource = urlBasedJWKSetSource ;
435- this .jwkSetCache = jwkSetCache ;
431+ private SpringJWKSource (RestOperations restOperations , Cache cache , URL url , String jwkSetUri ) {
432+ Assert .notNull (restOperations , "restOperations cannot be null" );
433+ this .restOperations = restOperations ;
434+ this .cache = cache ;
435+ this .url = url ;
436+ this .jwkSetUri = jwkSetUri ;
436437 }
437438
439+
438440 @ Override
439441 public List <JWK > get (JWKSelector jwkSelector , SecurityContext context ) throws KeySourceException {
440- if (this .jwkSetCache != null ) {
441- JWKSet jwkSet = this .jwkSetCache .get ();
442- if (this .jwkSetCache .requiresRefresh () || jwkSet == null ) {
443- synchronized (this ) {
444- jwkSet = fetchJWKSet ();
445- this .jwkSetCache .put (jwkSet );
446- }
447- }
448- List <JWK > matches = jwkSelector .select (jwkSet );
449- if (!matches .isEmpty ()) {
450- return matches ;
451- }
452- String soughtKeyID = getFirstSpecifiedKeyID (jwkSelector .getMatcher ());
453- if (soughtKeyID == null ) {
454- return Collections .emptyList ();
455- }
456- if (jwkSet .getKeyByKeyId (soughtKeyID ) != null ) {
457- return Collections .emptyList ();
458- }
459- synchronized (this ) {
460- if (jwkSet == this .jwkSetCache .get ()) {
461- jwkSet = fetchJWKSet ();
462- this .jwkSetCache .put (jwkSet );
463- } else {
464- jwkSet = this .jwkSetCache .get ();
442+ String cachedJwkSet = this .cache .get (this .jwkSetUri , String .class );
443+ JWKSet jwkSet = null ;
444+ if (cachedJwkSet != null ) {
445+ jwkSet = parse (cachedJwkSet );
446+ }
447+ if (jwkSet == null ) {
448+ if (reentrantLock .tryLock ()) {
449+ try {
450+ String cachedJwkSetAfterLock = this .cache .get (this .jwkSetUri , String .class );
451+ if (cachedJwkSetAfterLock != null ) {
452+ jwkSet = parse (cachedJwkSetAfterLock );
453+ }
454+ if (jwkSet == null ) {
455+ try {
456+ jwkSet = fetchJWKSet ();
457+ } catch (IOException e ) {
458+ throw new JWKSetRetrievalException ("Couldn't retrieve JWK set from URL: " + e .getMessage (), e );
459+ }
460+ }
461+ } finally {
462+ reentrantLock .unlock ();
465463 }
466464 }
467- if (jwkSet == null ) {
468- return Collections .emptyList ();
469- }
470- return jwkSelector .select (jwkSet );
471465 }
472- return jwkSelector .select (fetchJWKSet ());
473- }
474-
475- private JWKSet fetchJWKSet () throws KeySourceException {
476- return this .urlBasedJWKSetSource .getJWKSet (JWKSetCacheRefreshEvaluator .noRefresh (),
477- System .currentTimeMillis (), null );
478- }
479-
480- private String getFirstSpecifiedKeyID (JWKMatcher jwkMatcher ) {
481- Set <String > keyIDs = jwkMatcher .getKeyIDs ();
482-
483- if (keyIDs == null || keyIDs .isEmpty ()) {
484- return null ;
466+ List <JWK > matches = jwkSelector .select (jwkSet );
467+ if (!matches .isEmpty ()) {
468+ return matches ;
485469 }
486-
487- for (String id : keyIDs ) {
488- if (id != null ) {
489- return id ;
490- }
470+ String soughtKeyID = getFirstSpecifiedKeyID (jwkSelector .getMatcher ());
471+ if (soughtKeyID == null ) {
472+ return Collections .emptyList ();
473+ }
474+ if (jwkSet .getKeyByKeyId (soughtKeyID ) != null ) {
475+ return Collections .emptyList ();
491476 }
492- return null ;
493- }
494- }
495-
496- private static final class SpringJWKSetCache {
497-
498- private final String jwkSetUri ;
499-
500- private final Cache cache ;
501-
502- private JWKSet jwkSet ;
503-
504- SpringJWKSetCache (String jwkSetUri , Cache cache ) {
505- this .jwkSetUri = jwkSetUri ;
506- this .cache = cache ;
507- this .updateJwkSetFromCache ();
508- }
509477
510- private void updateJwkSetFromCache () {
511- String cachedJwkSet = this .cache .get (this .jwkSetUri , String .class );
512- if (cachedJwkSet != null ) {
478+ if (reentrantLock .tryLock ()) {
513479 try {
514- this .jwkSet = JWKSet .parse (cachedJwkSet );
515- }
516- catch (ParseException ignored ) {
517- // Ignore invalid cache value
480+ String jwkSetUri = this .cache .get (this .jwkSetUri , String .class );
481+ JWKSet cacheJwkSet = parse (jwkSetUri );
482+ if (jwkSetUri != null && cacheJwkSet .toString ().equals (jwkSet .toString ())) {
483+ try {
484+ jwkSet = fetchJWKSet ();
485+ } catch (IOException e ) {
486+ throw new JWKSetRetrievalException ("Couldn't retrieve JWK set from URL: " + e .getMessage (), e );
487+ }
488+ } else if (jwkSetUri != null ) {
489+ jwkSet = parse (jwkSetUri );
490+ }
491+ } finally {
492+ reentrantLock .unlock ();
518493 }
519494 }
495+ if (jwkSet == null ) {
496+ return Collections .emptyList ();
497+ }
498+ return jwkSelector .select (jwkSet );
520499 }
521500
522- // Note: Only called from inside a synchronized block in SpringURLBasedJWKSource.
523- public void put (JWKSet jwkSet ) {
524- this .jwkSet = jwkSet ;
525- this .cache .put (this .jwkSetUri , jwkSet .toString (false ));
526- }
527-
528- public JWKSet get () {
529- return (!requiresRefresh ()) ? this .jwkSet : null ;
530- }
531-
532- public boolean requiresRefresh () {
533- return this .cache .get (this .jwkSetUri ) == null ;
534- }
535-
536- }
537-
538- private static class RestOperationsResourceRetriever implements ResourceRetriever {
539-
540- private static final MediaType APPLICATION_JWK_SET_JSON = new MediaType ("application" , "jwk-set+json" );
541-
542- private final RestOperations restOperations ;
543-
544- RestOperationsResourceRetriever (RestOperations restOperations ) {
545- Assert .notNull (restOperations , "restOperations cannot be null" );
546- this .restOperations = restOperations ;
547- }
548-
549- @ Override
550- public Resource retrieveResource (URL url ) throws IOException {
501+ private JWKSet fetchJWKSet () throws IOException , KeySourceException {
551502 HttpHeaders headers = new HttpHeaders ();
552503 headers .setAccept (Arrays .asList (MediaType .APPLICATION_JSON , APPLICATION_JWK_SET_JSON ));
553- ResponseEntity <String > response = getResponse (url , headers );
504+ ResponseEntity <String > response = getResponse (headers );
554505 if (response .getStatusCode ().value () != 200 ) {
555506 throw new IOException (response .toString ());
556507 }
557- return new Resource (response .getBody (), "UTF-8" );
508+ try {
509+ String jwkSet = response .getBody ();
510+ this .cache .put (this .jwkSetUri , jwkSet );
511+ return JWKSet .parse (jwkSet );
512+ } catch (ParseException e ) {
513+ throw new JWKSetParseException ("Unable to parse JWK set" , e );
514+ }
558515 }
559516
560- private ResponseEntity <String > getResponse (URL url , HttpHeaders headers ) throws IOException {
517+ private ResponseEntity <String > getResponse (HttpHeaders headers ) throws IOException {
561518 try {
562- RequestEntity <Void > request = new RequestEntity <>(headers , HttpMethod .GET , url .toURI ());
519+ RequestEntity <Void > request = new RequestEntity <>(headers , HttpMethod .GET , this . url .toURI ());
563520 return this .restOperations .exchange (request , String .class );
564- }
565- catch (Exception ex ) {
521+ } catch (Exception ex ) {
566522 throw new IOException (ex );
567523 }
568524 }
569525
526+ private JWKSet parse (String cachedJwkSet ) {
527+ JWKSet jwkSet = null ;
528+ try {
529+ jwkSet = JWKSet .parse (cachedJwkSet );
530+ } catch (ParseException ignored ) {
531+ // Ignore invalid cache value
532+ }
533+ return jwkSet ;
534+ }
535+
536+ private String getFirstSpecifiedKeyID (JWKMatcher jwkMatcher ) {
537+ Set <String > keyIDs = jwkMatcher .getKeyIDs ();
538+ return (keyIDs == null || keyIDs .isEmpty ()) ?
539+ null : keyIDs .stream ().filter (id -> id != null ).findFirst ().orElse (null );
540+ }
570541 }
571542
572543 }
0 commit comments