1818
1919import com .nimbusds .jose .KeySourceException ;
2020import com .nimbusds .jose .jwk .JWK ;
21+ import com .nimbusds .jose .jwk .JWKMatcher ;
2122import com .nimbusds .jose .jwk .JWKSelector ;
2223import com .nimbusds .jose .jwk .source .JWKSetCacheRefreshEvaluator ;
2324import com .nimbusds .jose .jwk .source .URLBasedJWKSetSource ;
@@ -437,20 +438,32 @@ private SpringURLBasedJWKSource(URLBasedJWKSetSource urlBasedJWKSetSource, Sprin
437438 @ Override
438439 public List <JWK > get (JWKSelector jwkSelector , SecurityContext context ) throws KeySourceException {
439440 if (this .jwkSetCache != null ) {
440- synchronized ( this ) {
441- JWKSet jwkSet = this .jwkSetCache .get ();
442- if (this . jwkSetCache . requiresRefresh () || jwkSet == null ) {
441+ JWKSet jwkSet = this . jwkSetCache . get ();
442+ if ( this .jwkSetCache .requiresRefresh () || jwkSet == null ) {
443+ synchronized (this ) {
443444 jwkSet = fetchJWKSet (context );
444445 this .jwkSetCache .put (jwkSet );
445446 }
446- List <JWK > jwks = jwkSelector .select (jwkSet );
447- if (!jwks .isEmpty ()) {
448- return jwks ;
449- }
447+ }
448+ List <JWK > matches = jwkSelector .select (jwkSet );
449+ if (!matches .isEmpty ()) {
450+ return matches ;
451+ }
452+ String soughtKeyID = getFirstSpecifiedKeyID (jwkSelector .getMatcher ());
453+ if (soughtKeyID == null ) {
454+ return Collections .emptyList ();
455+ }
456+ if (jwkSet .getKeyByKeyId (soughtKeyID ) != null ) {
457+ return Collections .emptyList ();
458+ }
459+ synchronized (this ) {
450460 jwkSet = fetchJWKSet (context );
451461 this .jwkSetCache .put (jwkSet );
452- return jwkSelector .select (jwkSet );
453462 }
463+ if (jwkSet == null ) {
464+ return Collections .emptyList ();
465+ }
466+ return jwkSelector .select (jwkSet );
454467 }
455468 return jwkSelector .select (fetchJWKSet (context ));
456469 }
@@ -459,6 +472,21 @@ private JWKSet fetchJWKSet(SecurityContext context) throws KeySourceException {
459472 return this .urlBasedJWKSetSource .getJWKSet (JWKSetCacheRefreshEvaluator .noRefresh (),
460473 System .currentTimeMillis (), context );
461474 }
475+
476+ private static String getFirstSpecifiedKeyID (JWKMatcher jwkMatcher ) {
477+ Set <String > keyIDs = jwkMatcher .getKeyIDs ();
478+
479+ if (keyIDs == null || keyIDs .isEmpty ()) {
480+ return null ;
481+ }
482+
483+ for (String id : keyIDs ) {
484+ if (id != null ) {
485+ return id ;
486+ }
487+ }
488+ return null ;
489+ }
462490 }
463491
464492 private static final class SpringJWKSetCache {
0 commit comments