11/*
2- * Copyright 2002-2023 the original author or authors.
2+ * Copyright 2002-2024 the original author or authors.
33 *
44 * Licensed under the Apache License, Version 2.0 (the "License");
55 * you may not use this file except in compliance with the License.
1616
1717package org .springframework .security .oauth2 .jwt ;
1818
19+ import com .nimbusds .jose .KeySourceException ;
20+ import com .nimbusds .jose .jwk .JWK ;
21+ import com .nimbusds .jose .jwk .JWKSelector ;
22+ import com .nimbusds .jose .jwk .source .JWKSetCacheRefreshEvaluator ;
23+ import com .nimbusds .jose .jwk .source .URLBasedJWKSetSource ;
1924import java .io .IOException ;
2025import java .net .MalformedURLException ;
2126import java .net .URL ;
2631import java .util .Collections ;
2732import java .util .HashSet ;
2833import java .util .LinkedHashMap ;
34+ import java .util .List ;
2935import java .util .Map ;
3036import java .util .Set ;
3137import java .util .function .Consumer ;
3541
3642import com .nimbusds .jose .JOSEException ;
3743import com .nimbusds .jose .JWSAlgorithm ;
38- import com .nimbusds .jose .RemoteKeySourceException ;
3944import com .nimbusds .jose .jwk .JWKSet ;
40- import com .nimbusds .jose .jwk .source .JWKSetCache ;
4145import com .nimbusds .jose .jwk .source .JWKSource ;
42- import com .nimbusds .jose .jwk .source .RemoteJWKSet ;
4346import com .nimbusds .jose .proc .JWSKeySelector ;
4447import com .nimbusds .jose .proc .JWSVerificationKeySelector ;
4548import com .nimbusds .jose .proc .SecurityContext ;
8083 * @author Josh Cummings
8184 * @author Joe Grandja
8285 * @author Mykyta Bezverkhyi
86+ * @author Daeho Kwon
8387 * @since 5.2
8488 */
8589public final class NimbusJwtDecoder implements JwtDecoder {
@@ -165,7 +169,7 @@ private Jwt createJwt(String token, JWT parsedJwt) {
165169 .build ();
166170 // @formatter:on
167171 }
168- catch (RemoteKeySourceException ex ) {
172+ catch (KeySourceException ex ) {
169173 this .logger .trace ("Failed to retrieve JWK set" , ex );
170174 if (ex .getCause () instanceof ParseException ) {
171175 throw new JwtException (String .format (DECODING_ERROR_MESSAGE_TEMPLATE , "Malformed Jwk set" ), ex );
@@ -377,11 +381,12 @@ JWSKeySelector<SecurityContext> jwsKeySelector(JWKSource<SecurityContext> jwkSou
377381 }
378382
379383 JWKSource <SecurityContext > jwkSource (ResourceRetriever jwkSetRetriever , String jwkSetUri ) {
380- if (this .cache == null ) {
381- return new RemoteJWKSet <>(toURL (jwkSetUri ), jwkSetRetriever );
384+ URLBasedJWKSetSource urlBasedJWKSetSource = new URLBasedJWKSetSource (toURL (jwkSetUri ), jwkSetRetriever );
385+ if (this .cache == null ) {
386+ return new SpringURLBasedJWKSource (urlBasedJWKSetSource );
382387 }
383- JWKSetCache jwkSetCache = new SpringJWKSetCache (jwkSetUri , this .cache );
384- return new RemoteJWKSet <>(toURL ( jwkSetUri ), jwkSetRetriever , jwkSetCache );
388+ SpringJWKSetCache jwkSetCache = new SpringJWKSetCache (jwkSetUri , this .cache );
389+ return new SpringURLBasedJWKSource <>(urlBasedJWKSetSource , jwkSetCache );
385390 }
386391
387392 JWTProcessor <SecurityContext > processor () {
@@ -414,7 +419,49 @@ private static URL toURL(String url) {
414419 }
415420 }
416421
417- private static final class SpringJWKSetCache implements JWKSetCache {
422+ private static final class SpringURLBasedJWKSource <C extends SecurityContext > implements JWKSource <C > {
423+
424+ private final URLBasedJWKSetSource urlBasedJWKSetSource ;
425+
426+ private final SpringJWKSetCache jwkSetCache ;
427+
428+ private SpringURLBasedJWKSource (URLBasedJWKSetSource urlBasedJWKSetSource ) {
429+ this (urlBasedJWKSetSource , null );
430+ }
431+
432+ private SpringURLBasedJWKSource (URLBasedJWKSetSource urlBasedJWKSetSource , SpringJWKSetCache jwkSetCache ) {
433+ this .urlBasedJWKSetSource = urlBasedJWKSetSource ;
434+ this .jwkSetCache = jwkSetCache ;
435+ }
436+
437+ @ Override
438+ public List <JWK > get (JWKSelector jwkSelector , SecurityContext context ) throws KeySourceException {
439+ if (this .jwkSetCache != null ) {
440+ synchronized (this ) {
441+ JWKSet jwkSet = this .jwkSetCache .get ();
442+ if (this .jwkSetCache .requiresRefresh () || jwkSet == null ) {
443+ jwkSet = fetchJWKSet (context );
444+ this .jwkSetCache .put (jwkSet );
445+ }
446+ List <JWK > jwks = jwkSelector .select (jwkSet );
447+ if (!jwks .isEmpty ()) {
448+ return jwks ;
449+ }
450+ jwkSet = fetchJWKSet (context );
451+ this .jwkSetCache .put (jwkSet );
452+ return jwkSelector .select (jwkSet );
453+ }
454+ }
455+ return jwkSelector .select (fetchJWKSet (context ));
456+ }
457+
458+ private JWKSet fetchJWKSet (SecurityContext context ) throws KeySourceException {
459+ return this .urlBasedJWKSetSource .getJWKSet (JWKSetCacheRefreshEvaluator .noRefresh (),
460+ System .currentTimeMillis (), context );
461+ }
462+ }
463+
464+ private static final class SpringJWKSetCache {
418465
419466 private final String jwkSetUri ;
420467
@@ -440,20 +487,16 @@ private void updateJwkSetFromCache() {
440487 }
441488 }
442489
443- // Note: Only called from inside a synchronized block in RemoteJWKSet.
444- @ Override
490+ // Note: Only called from inside a synchronized block in SpringURLBasedJWKSource.
445491 public void put (JWKSet jwkSet ) {
446492 this .jwkSet = jwkSet ;
447493 this .cache .put (this .jwkSetUri , jwkSet .toString (false ));
448494 }
449495
450- @ Override
451496 public JWKSet get () {
452497 return (!requiresRefresh ()) ? this .jwkSet : null ;
453-
454498 }
455499
456- @ Override
457500 public boolean requiresRefresh () {
458501 return this .cache .get (this .jwkSetUri ) == null ;
459502 }
0 commit comments