11/*
2- * Copyright 2002-2023 the original author or authors.
2+ * Copyright 2002-2025 the original author or authors.
33 *
44 * Licensed under the Apache License, Version 2.0 (the "License");
55 * you may not use this file except in compliance with the License.
1616
1717package org .springframework .security .oauth2 .jwt ;
1818
19+ import com .nimbusds .jose .KeySourceException ;
20+ import com .nimbusds .jose .jwk .JWK ;
21+ import com .nimbusds .jose .jwk .JWKMatcher ;
22+ import com .nimbusds .jose .jwk .JWKSelector ;
23+ import com .nimbusds .jose .jwk .source .JWKSetParseException ;
24+ import com .nimbusds .jose .jwk .source .JWKSetRetrievalException ;
1925import java .io .IOException ;
2026import java .net .MalformedURLException ;
2127import java .net .URL ;
2632import java .util .Collections ;
2733import java .util .HashSet ;
2834import java .util .LinkedHashMap ;
35+ import java .util .List ;
2936import java .util .Map ;
3037import java .util .Set ;
38+ import java .util .concurrent .locks .ReentrantLock ;
3139import java .util .function .Consumer ;
3240import java .util .function .Function ;
3341
3442import javax .crypto .SecretKey ;
3543
3644import com .nimbusds .jose .JOSEException ;
3745import com .nimbusds .jose .JWSAlgorithm ;
38- import com .nimbusds .jose .RemoteKeySourceException ;
3946import com .nimbusds .jose .jwk .JWKSet ;
40- import com .nimbusds .jose .jwk .source .JWKSetCache ;
4147import com .nimbusds .jose .jwk .source .JWKSource ;
42- import com .nimbusds .jose .jwk .source .RemoteJWKSet ;
4348import com .nimbusds .jose .proc .JWSKeySelector ;
4449import com .nimbusds .jose .proc .JWSVerificationKeySelector ;
4550import com .nimbusds .jose .proc .SecurityContext ;
4651import com .nimbusds .jose .proc .SingleKeyJWSKeySelector ;
47- import com .nimbusds .jose .util .Resource ;
48- import com .nimbusds .jose .util .ResourceRetriever ;
4952import com .nimbusds .jwt .JWT ;
5053import com .nimbusds .jwt .JWTClaimsSet ;
5154import com .nimbusds .jwt .JWTParser ;
5760import org .apache .commons .logging .LogFactory ;
5861
5962import org .springframework .cache .Cache ;
63+ import org .springframework .cache .support .NoOpCache ;
6064import org .springframework .core .convert .converter .Converter ;
6165import org .springframework .http .HttpHeaders ;
6266import org .springframework .http .HttpMethod ;
8084 * @author Josh Cummings
8185 * @author Joe Grandja
8286 * @author Mykyta Bezverkhyi
87+ * @author Daeho Kwon
8388 * @since 5.2
8489 */
8590public final class NimbusJwtDecoder implements JwtDecoder {
@@ -165,7 +170,7 @@ private Jwt createJwt(String token, JWT parsedJwt) {
165170 .build ();
166171 // @formatter:on
167172 }
168- catch (RemoteKeySourceException ex ) {
173+ catch (KeySourceException ex ) {
169174 this .logger .trace ("Failed to retrieve JWK set" , ex );
170175 if (ex .getCause () instanceof ParseException ) {
171176 throw new JwtException (String .format (DECODING_ERROR_MESSAGE_TEMPLATE , "Malformed Jwk set" ), ex );
@@ -273,7 +278,7 @@ public static final class JwkSetUriJwtDecoderBuilder {
273278
274279 private RestOperations restOperations = new RestTemplate ();
275280
276- private Cache cache ;
281+ private Cache cache = new NoOpCache ( "default" ) ;
277282
278283 private Consumer <ConfigurableJWTProcessor <SecurityContext >> jwtProcessorCustomizer ;
279284
@@ -376,18 +381,13 @@ JWSKeySelector<SecurityContext> jwsKeySelector(JWKSource<SecurityContext> jwkSou
376381 return new JWSVerificationKeySelector <>(jwsAlgorithms , jwkSource );
377382 }
378383
379- JWKSource <SecurityContext > jwkSource (ResourceRetriever jwkSetRetriever , String jwkSetUri ) {
380- if (this .cache == null ) {
381- return new RemoteJWKSet <>(toURL (jwkSetUri ), jwkSetRetriever );
382- }
383- JWKSetCache jwkSetCache = new SpringJWKSetCache (jwkSetUri , this .cache );
384- return new RemoteJWKSet <>(toURL (jwkSetUri ), jwkSetRetriever , jwkSetCache );
384+ JWKSource <SecurityContext > jwkSource () {
385+ String jwkSetUri = this .jwkSetUri .apply (this .restOperations );
386+ return new SpringJWKSource <>(this .restOperations , this .cache , toURL (jwkSetUri ), jwkSetUri );
385387 }
386388
387389 JWTProcessor <SecurityContext > processor () {
388- ResourceRetriever jwkSetRetriever = new RestOperationsResourceRetriever (this .restOperations );
389- String jwkSetUri = this .jwkSetUri .apply (this .restOperations );
390- JWKSource <SecurityContext > jwkSource = jwkSource (jwkSetRetriever , jwkSetUri );
390+ JWKSource <SecurityContext > jwkSource = jwkSource ();
391391 ConfigurableJWTProcessor <SecurityContext > jwtProcessor = new DefaultJWTProcessor <>();
392392 jwtProcessor .setJWSKeySelector (jwsKeySelector (jwkSource ));
393393 // Spring Security validates the claim set independent from Nimbus
@@ -414,84 +414,130 @@ private static URL toURL(String url) {
414414 }
415415 }
416416
417- private static final class SpringJWKSetCache implements JWKSetCache {
417+ private static final class SpringJWKSource < C extends SecurityContext > implements JWKSource < C > {
418418
419- private final String jwkSetUri ;
419+ private static final MediaType APPLICATION_JWK_SET_JSON = new MediaType ("application" , "jwk-set+json" );
420+
421+ private final ReentrantLock reentrantLock = new ReentrantLock ();
422+
423+ private final RestOperations restOperations ;
420424
421425 private final Cache cache ;
422426
423- private JWKSet jwkSet ;
427+ private final URL url ;
424428
425- SpringJWKSetCache (String jwkSetUri , Cache cache ) {
426- this .jwkSetUri = jwkSetUri ;
429+ private final String jwkSetUri ;
430+
431+ private SpringJWKSource (RestOperations restOperations , Cache cache , URL url , String jwkSetUri ) {
432+ Assert .notNull (restOperations , "restOperations cannot be null" );
433+ this .restOperations = restOperations ;
427434 this .cache = cache ;
428- this .updateJwkSetFromCache ();
435+ this .url = url ;
436+ this .jwkSetUri = jwkSetUri ;
429437 }
430438
431- private void updateJwkSetFromCache () {
439+
440+ @ Override
441+ public List <JWK > get (JWKSelector jwkSelector , SecurityContext context ) throws KeySourceException {
432442 String cachedJwkSet = this .cache .get (this .jwkSetUri , String .class );
443+ JWKSet jwkSet = null ;
433444 if (cachedJwkSet != null ) {
434- try {
435- this .jwkSet = JWKSet .parse (cachedJwkSet );
436- }
437- catch (ParseException ignored ) {
438- // Ignore invalid cache value
445+ jwkSet = parse (cachedJwkSet );
446+ }
447+ if (jwkSet == null ) {
448+ if (reentrantLock .tryLock ()) {
449+ try {
450+ String cachedJwkSetAfterLock = this .cache .get (this .jwkSetUri , String .class );
451+ if (cachedJwkSetAfterLock != null ) {
452+ jwkSet = parse (cachedJwkSetAfterLock );
453+ }
454+ if (jwkSet == null ) {
455+ try {
456+ jwkSet = fetchJWKSet ();
457+ } catch (IOException e ) {
458+ throw new JWKSetRetrievalException ("Couldn't retrieve JWK set from URL: " + e .getMessage (), e );
459+ }
460+ }
461+ } finally {
462+ reentrantLock .unlock ();
463+ }
439464 }
440465 }
441- }
442-
443- // Note: Only called from inside a synchronized block in RemoteJWKSet.
444- @ Override
445- public void put (JWKSet jwkSet ) {
446- this .jwkSet = jwkSet ;
447- this .cache .put (this .jwkSetUri , jwkSet .toString (false ));
448- }
449-
450- @ Override
451- public JWKSet get () {
452- return (!requiresRefresh ()) ? this .jwkSet : null ;
453-
454- }
455-
456- @ Override
457- public boolean requiresRefresh () {
458- return this .cache .get (this .jwkSetUri ) == null ;
459- }
460-
461- }
462-
463- private static class RestOperationsResourceRetriever implements ResourceRetriever {
464-
465- private static final MediaType APPLICATION_JWK_SET_JSON = new MediaType ("application" , "jwk-set+json" );
466-
467- private final RestOperations restOperations ;
466+ List <JWK > matches = jwkSelector .select (jwkSet );
467+ if (!matches .isEmpty ()) {
468+ return matches ;
469+ }
470+ String soughtKeyID = getFirstSpecifiedKeyID (jwkSelector .getMatcher ());
471+ if (soughtKeyID == null ) {
472+ return Collections .emptyList ();
473+ }
474+ if (jwkSet .getKeyByKeyId (soughtKeyID ) != null ) {
475+ return Collections .emptyList ();
476+ }
468477
469- RestOperationsResourceRetriever (RestOperations restOperations ) {
470- Assert .notNull (restOperations , "restOperations cannot be null" );
471- this .restOperations = restOperations ;
478+ if (reentrantLock .tryLock ()) {
479+ try {
480+ String jwkSetUri = this .cache .get (this .jwkSetUri , String .class );
481+ JWKSet cacheJwkSet = parse (jwkSetUri );
482+ if (jwkSetUri != null && cacheJwkSet .toString ().equals (jwkSet .toString ())) {
483+ try {
484+ jwkSet = fetchJWKSet ();
485+ } catch (IOException e ) {
486+ throw new JWKSetRetrievalException ("Couldn't retrieve JWK set from URL: " + e .getMessage (), e );
487+ }
488+ } else if (jwkSetUri != null ) {
489+ jwkSet = parse (jwkSetUri );
490+ }
491+ } finally {
492+ reentrantLock .unlock ();
493+ }
494+ }
495+ if (jwkSet == null ) {
496+ return Collections .emptyList ();
497+ }
498+ return jwkSelector .select (jwkSet );
472499 }
473500
474- @ Override
475- public Resource retrieveResource (URL url ) throws IOException {
501+ private JWKSet fetchJWKSet () throws IOException , KeySourceException {
476502 HttpHeaders headers = new HttpHeaders ();
477503 headers .setAccept (Arrays .asList (MediaType .APPLICATION_JSON , APPLICATION_JWK_SET_JSON ));
478- ResponseEntity <String > response = getResponse (url , headers );
504+ ResponseEntity <String > response = getResponse (headers );
479505 if (response .getStatusCode ().value () != 200 ) {
480506 throw new IOException (response .toString ());
481507 }
482- return new Resource (response .getBody (), "UTF-8" );
508+ try {
509+ String jwkSet = response .getBody ();
510+ this .cache .put (this .jwkSetUri , jwkSet );
511+ return JWKSet .parse (jwkSet );
512+ } catch (ParseException e ) {
513+ throw new JWKSetParseException ("Unable to parse JWK set" , e );
514+ }
483515 }
484516
485- private ResponseEntity <String > getResponse (URL url , HttpHeaders headers ) throws IOException {
517+ private ResponseEntity <String > getResponse (HttpHeaders headers ) throws IOException {
486518 try {
487- RequestEntity <Void > request = new RequestEntity <>(headers , HttpMethod .GET , url .toURI ());
519+ RequestEntity <Void > request = new RequestEntity <>(headers , HttpMethod .GET , this . url .toURI ());
488520 return this .restOperations .exchange (request , String .class );
489- }
490- catch (Exception ex ) {
521+ } catch (Exception ex ) {
491522 throw new IOException (ex );
492523 }
493524 }
494525
526+ private JWKSet parse (String cachedJwkSet ) {
527+ JWKSet jwkSet = null ;
528+ try {
529+ jwkSet = JWKSet .parse (cachedJwkSet );
530+ } catch (ParseException ignored ) {
531+ // Ignore invalid cache value
532+ }
533+ return jwkSet ;
534+ }
535+
536+ private String getFirstSpecifiedKeyID (JWKMatcher jwkMatcher ) {
537+ Set <String > keyIDs = jwkMatcher .getKeyIDs ();
538+ return (keyIDs == null || keyIDs .isEmpty ()) ?
539+ null : keyIDs .stream ().filter (id -> id != null ).findFirst ().orElse (null );
540+ }
495541 }
496542
497543 }
0 commit comments