1616
1717package org .springframework .security .oauth2 .jwt ;
1818
19- import com .nimbusds .jose .KeySourceException ;
20- import com .nimbusds .jose .jwk .JWK ;
21- import com .nimbusds .jose .jwk .JWKMatcher ;
22- import com .nimbusds .jose .jwk .JWKSelector ;
23- import com .nimbusds .jose .jwk .source .JWKSetParseException ;
24- import com .nimbusds .jose .jwk .source .JWKSetRetrievalException ;
25- import java .io .IOException ;
26- import java .net .MalformedURLException ;
27- import java .net .URL ;
19+ import java .net .URI ;
2820import java .security .interfaces .RSAPublicKey ;
2921import java .text .ParseException ;
3022import java .util .Arrays ;
3123import java .util .Collection ;
3224import java .util .Collections ;
3325import java .util .HashSet ;
3426import java .util .LinkedHashMap ;
35- import java .util .List ;
3627import java .util .Map ;
3728import java .util .Set ;
3829import java .util .concurrent .locks .ReentrantLock ;
4334
4435import com .nimbusds .jose .JOSEException ;
4536import com .nimbusds .jose .JWSAlgorithm ;
37+ import com .nimbusds .jose .KeySourceException ;
38+ import com .nimbusds .jose .RemoteKeySourceException ;
4639import com .nimbusds .jose .jwk .JWKSet ;
40+ import com .nimbusds .jose .jwk .source .JWKSetBasedJWKSource ;
41+ import com .nimbusds .jose .jwk .source .JWKSetCacheRefreshEvaluator ;
42+ import com .nimbusds .jose .jwk .source .JWKSetSource ;
4743import com .nimbusds .jose .jwk .source .JWKSource ;
4844import com .nimbusds .jose .proc .JWSKeySelector ;
4945import com .nimbusds .jose .proc .JWSVerificationKeySelector ;
@@ -170,7 +166,7 @@ private Jwt createJwt(String token, JWT parsedJwt) {
170166 .build ();
171167 // @formatter:on
172168 }
173- catch (KeySourceException ex ) {
169+ catch (RemoteKeySourceException ex ) {
174170 this .logger .trace ("Failed to retrieve JWK set" , ex );
175171 if (ex .getCause () instanceof ParseException ) {
176172 throw new JwtException (String .format (DECODING_ERROR_MESSAGE_TEMPLATE , "Malformed Jwk set" ), ex );
@@ -383,7 +379,7 @@ JWSKeySelector<SecurityContext> jwsKeySelector(JWKSource<SecurityContext> jwkSou
383379
384380 JWKSource <SecurityContext > jwkSource () {
385381 String jwkSetUri = this .jwkSetUri .apply (this .restOperations );
386- return new SpringJWKSource <>(this .restOperations , this .cache , toURL ( jwkSetUri ), jwkSetUri );
382+ return new JWKSetBasedJWKSource <>( new SpringJWKSource <>(this .restOperations , this .cache , jwkSetUri ));
387383 }
388384
389385 JWTProcessor <SecurityContext > processor () {
@@ -405,16 +401,7 @@ public NimbusJwtDecoder build() {
405401 return new NimbusJwtDecoder (processor ());
406402 }
407403
408- private static URL toURL (String url ) {
409- try {
410- return new URL (url );
411- }
412- catch (MalformedURLException ex ) {
413- throw new IllegalArgumentException ("Invalid JWK Set URL \" " + url + "\" : " + ex .getMessage (), ex );
414- }
415- }
416-
417- private static final class SpringJWKSource <C extends SecurityContext > implements JWKSource <C > {
404+ private static final class SpringJWKSource <C extends SecurityContext > implements JWKSetSource <C > {
418405
419406 private static final MediaType APPLICATION_JWK_SET_JSON = new MediaType ("application" , "jwk-set+json" );
420407
@@ -424,120 +411,63 @@ private static final class SpringJWKSource<C extends SecurityContext> implements
424411
425412 private final Cache cache ;
426413
427- private final URL url ;
428-
429414 private final String jwkSetUri ;
430415
431- private SpringJWKSource (RestOperations restOperations , Cache cache , URL url , String jwkSetUri ) {
416+ private JWKSet jwkSet ;
417+
418+ private SpringJWKSource (RestOperations restOperations , Cache cache , String jwkSetUri ) {
432419 Assert .notNull (restOperations , "restOperations cannot be null" );
433420 this .restOperations = restOperations ;
434421 this .cache = cache ;
435- this .url = url ;
436422 this .jwkSetUri = jwkSetUri ;
437- }
438-
439-
440- @ Override
441- public List <JWK > get (JWKSelector jwkSelector , SecurityContext context ) throws KeySourceException {
442- String cachedJwkSet = this .cache .get (this .jwkSetUri , String .class );
443- JWKSet jwkSet = null ;
444- if (cachedJwkSet != null ) {
445- jwkSet = parse (cachedJwkSet );
446- }
447- if (jwkSet == null ) {
448- if (reentrantLock .tryLock ()) {
449- try {
450- String cachedJwkSetAfterLock = this .cache .get (this .jwkSetUri , String .class );
451- if (cachedJwkSetAfterLock != null ) {
452- jwkSet = parse (cachedJwkSetAfterLock );
453- }
454- if (jwkSet == null ) {
455- try {
456- jwkSet = fetchJWKSet ();
457- } catch (IOException e ) {
458- throw new JWKSetRetrievalException ("Couldn't retrieve JWK set from URL: " + e .getMessage (), e );
459- }
460- }
461- } finally {
462- reentrantLock .unlock ();
463- }
464- }
465- }
466- List <JWK > matches = jwkSelector .select (jwkSet );
467- if (!matches .isEmpty ()) {
468- return matches ;
469- }
470- String soughtKeyID = getFirstSpecifiedKeyID (jwkSelector .getMatcher ());
471- if (soughtKeyID == null ) {
472- return Collections .emptyList ();
473- }
474- if (jwkSet .getKeyByKeyId (soughtKeyID ) != null ) {
475- return Collections .emptyList ();
476- }
477-
478- if (reentrantLock .tryLock ()) {
423+ String jwks = this .cache .get (this .jwkSetUri , String .class );
424+ if (jwks != null ) {
479425 try {
480- String jwkSetUri = this .cache .get (this .jwkSetUri , String .class );
481- JWKSet cacheJwkSet = parse (jwkSetUri );
482- if (jwkSetUri != null && cacheJwkSet .toString ().equals (jwkSet .toString ())) {
483- try {
484- jwkSet = fetchJWKSet ();
485- } catch (IOException e ) {
486- throw new JWKSetRetrievalException ("Couldn't retrieve JWK set from URL: " + e .getMessage (), e );
487- }
488- } else if (jwkSetUri != null ) {
489- jwkSet = parse (jwkSetUri );
490- }
491- } finally {
492- reentrantLock .unlock ();
426+ this .jwkSet = JWKSet .parse (jwks );
427+ }
428+ catch (ParseException ignored ) {
429+ // Ignore invalid cache value
493430 }
494431 }
495- if (jwkSet == null ) {
496- return Collections .emptyList ();
497- }
498- return jwkSelector .select (jwkSet );
499432 }
500433
501- private JWKSet fetchJWKSet () throws IOException , KeySourceException {
434+ private String fetchJwks () throws Exception {
502435 HttpHeaders headers = new HttpHeaders ();
503436 headers .setAccept (Arrays .asList (MediaType .APPLICATION_JSON , APPLICATION_JWK_SET_JSON ));
504- ResponseEntity <String > response = getResponse (headers );
505- if (response .getStatusCode ().value () != 200 ) {
506- throw new IOException (response .toString ());
507- }
508- try {
509- String jwkSet = response .getBody ();
510- this .cache .put (this .jwkSetUri , jwkSet );
511- return JWKSet .parse (jwkSet );
512- } catch (ParseException e ) {
513- throw new JWKSetParseException ("Unable to parse JWK set" , e );
514- }
437+ RequestEntity <Void > request = new RequestEntity <>(headers , HttpMethod .GET , URI .create (this .jwkSetUri ));
438+ ResponseEntity <String > response = this .restOperations .exchange (request , String .class );
439+ String jwks = response .getBody ();
440+ this .jwkSet = JWKSet .parse (jwks );
441+ return jwks ;
515442 }
516443
517- private ResponseEntity <String > getResponse (HttpHeaders headers ) throws IOException {
444+ @ Override
445+ public JWKSet getJWKSet (JWKSetCacheRefreshEvaluator refreshEvaluator , long currentTime , C context )
446+ throws KeySourceException {
518447 try {
519- RequestEntity <Void > request = new RequestEntity <>(headers , HttpMethod .GET , this .url .toURI ());
520- return this .restOperations .exchange (request , String .class );
521- } catch (Exception ex ) {
522- throw new IOException (ex );
448+ this .reentrantLock .lock ();
449+ if (refreshEvaluator .requiresRefresh (this .jwkSet )) {
450+ this .cache .invalidate ();
451+ }
452+ this .cache .get (this .jwkSetUri , this ::fetchJwks );
453+ return this .jwkSet ;
523454 }
524- }
525-
526- private JWKSet parse ( String cachedJwkSet ) {
527- JWKSet jwkSet = null ;
528- try {
529- jwkSet = JWKSet . parse ( cachedJwkSet );
530- } catch ( ParseException ignored ) {
531- // Ignore invalid cache value
455+ catch ( Cache . ValueRetrievalException ex ) {
456+ if ( ex . getCause () instanceof RemoteKeySourceException keys ) {
457+ throw keys ;
458+ }
459+ throw new RemoteKeySourceException ( ex . getCause (). getMessage (), ex . getCause ());
460+ }
461+ finally {
462+ this . reentrantLock . unlock ();
532463 }
533- return jwkSet ;
534464 }
535465
536- private String getFirstSpecifiedKeyID (JWKMatcher jwkMatcher ) {
537- Set <String > keyIDs = jwkMatcher .getKeyIDs ();
538- return (keyIDs == null || keyIDs .isEmpty ()) ?
539- null : keyIDs .stream ().filter (id -> id != null ).findFirst ().orElse (null );
466+ @ Override
467+ public void close () {
468+
540469 }
470+
541471 }
542472
543473 }
0 commit comments