Skip to content

Commit 67f663f

Browse files
committed
Polish Nimbus JWK Source Implementation
Issue gh-16251
1 parent 6d16691 commit 67f663f

File tree

2 files changed

+46
-116
lines changed

2 files changed

+46
-116
lines changed

oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java

Lines changed: 45 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -16,23 +16,14 @@
1616

1717
package org.springframework.security.oauth2.jwt;
1818

19-
import com.nimbusds.jose.KeySourceException;
20-
import com.nimbusds.jose.jwk.JWK;
21-
import com.nimbusds.jose.jwk.JWKMatcher;
22-
import com.nimbusds.jose.jwk.JWKSelector;
23-
import com.nimbusds.jose.jwk.source.JWKSetParseException;
24-
import com.nimbusds.jose.jwk.source.JWKSetRetrievalException;
25-
import java.io.IOException;
26-
import java.net.MalformedURLException;
27-
import java.net.URL;
19+
import java.net.URI;
2820
import java.security.interfaces.RSAPublicKey;
2921
import java.text.ParseException;
3022
import java.util.Arrays;
3123
import java.util.Collection;
3224
import java.util.Collections;
3325
import java.util.HashSet;
3426
import java.util.LinkedHashMap;
35-
import java.util.List;
3627
import java.util.Map;
3728
import java.util.Set;
3829
import java.util.concurrent.locks.ReentrantLock;
@@ -43,7 +34,12 @@
4334

4435
import com.nimbusds.jose.JOSEException;
4536
import com.nimbusds.jose.JWSAlgorithm;
37+
import com.nimbusds.jose.KeySourceException;
38+
import com.nimbusds.jose.RemoteKeySourceException;
4639
import com.nimbusds.jose.jwk.JWKSet;
40+
import com.nimbusds.jose.jwk.source.JWKSetBasedJWKSource;
41+
import com.nimbusds.jose.jwk.source.JWKSetCacheRefreshEvaluator;
42+
import com.nimbusds.jose.jwk.source.JWKSetSource;
4743
import com.nimbusds.jose.jwk.source.JWKSource;
4844
import com.nimbusds.jose.proc.JWSKeySelector;
4945
import com.nimbusds.jose.proc.JWSVerificationKeySelector;
@@ -170,7 +166,7 @@ private Jwt createJwt(String token, JWT parsedJwt) {
170166
.build();
171167
// @formatter:on
172168
}
173-
catch (KeySourceException ex) {
169+
catch (RemoteKeySourceException ex) {
174170
this.logger.trace("Failed to retrieve JWK set", ex);
175171
if (ex.getCause() instanceof ParseException) {
176172
throw new JwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, "Malformed Jwk set"), ex);
@@ -383,7 +379,7 @@ JWSKeySelector<SecurityContext> jwsKeySelector(JWKSource<SecurityContext> jwkSou
383379

384380
JWKSource<SecurityContext> jwkSource() {
385381
String jwkSetUri = this.jwkSetUri.apply(this.restOperations);
386-
return new SpringJWKSource<>(this.restOperations, this.cache, toURL(jwkSetUri), jwkSetUri);
382+
return new JWKSetBasedJWKSource<>(new SpringJWKSource<>(this.restOperations, this.cache, jwkSetUri));
387383
}
388384

389385
JWTProcessor<SecurityContext> processor() {
@@ -405,16 +401,7 @@ public NimbusJwtDecoder build() {
405401
return new NimbusJwtDecoder(processor());
406402
}
407403

408-
private static URL toURL(String url) {
409-
try {
410-
return new URL(url);
411-
}
412-
catch (MalformedURLException ex) {
413-
throw new IllegalArgumentException("Invalid JWK Set URL \"" + url + "\" : " + ex.getMessage(), ex);
414-
}
415-
}
416-
417-
private static final class SpringJWKSource<C extends SecurityContext> implements JWKSource<C> {
404+
private static final class SpringJWKSource<C extends SecurityContext> implements JWKSetSource<C> {
418405

419406
private static final MediaType APPLICATION_JWK_SET_JSON = new MediaType("application", "jwk-set+json");
420407

@@ -424,120 +411,63 @@ private static final class SpringJWKSource<C extends SecurityContext> implements
424411

425412
private final Cache cache;
426413

427-
private final URL url;
428-
429414
private final String jwkSetUri;
430415

431-
private SpringJWKSource(RestOperations restOperations, Cache cache, URL url, String jwkSetUri) {
416+
private JWKSet jwkSet;
417+
418+
private SpringJWKSource(RestOperations restOperations, Cache cache, String jwkSetUri) {
432419
Assert.notNull(restOperations, "restOperations cannot be null");
433420
this.restOperations = restOperations;
434421
this.cache = cache;
435-
this.url = url;
436422
this.jwkSetUri = jwkSetUri;
437-
}
438-
439-
440-
@Override
441-
public List<JWK> get(JWKSelector jwkSelector, SecurityContext context) throws KeySourceException {
442-
String cachedJwkSet = this.cache.get(this.jwkSetUri, String.class);
443-
JWKSet jwkSet = null;
444-
if (cachedJwkSet != null) {
445-
jwkSet = parse(cachedJwkSet);
446-
}
447-
if (jwkSet == null) {
448-
if(reentrantLock.tryLock()) {
449-
try {
450-
String cachedJwkSetAfterLock = this.cache.get(this.jwkSetUri, String.class);
451-
if (cachedJwkSetAfterLock != null) {
452-
jwkSet = parse(cachedJwkSetAfterLock);
453-
}
454-
if(jwkSet == null) {
455-
try {
456-
jwkSet = fetchJWKSet();
457-
} catch (IOException e) {
458-
throw new JWKSetRetrievalException("Couldn't retrieve JWK set from URL: " + e.getMessage(), e);
459-
}
460-
}
461-
} finally {
462-
reentrantLock.unlock();
463-
}
464-
}
465-
}
466-
List<JWK> matches = jwkSelector.select(jwkSet);
467-
if(!matches.isEmpty()) {
468-
return matches;
469-
}
470-
String soughtKeyID = getFirstSpecifiedKeyID(jwkSelector.getMatcher());
471-
if (soughtKeyID == null) {
472-
return Collections.emptyList();
473-
}
474-
if (jwkSet.getKeyByKeyId(soughtKeyID) != null) {
475-
return Collections.emptyList();
476-
}
477-
478-
if(reentrantLock.tryLock()) {
423+
String jwks = this.cache.get(this.jwkSetUri, String.class);
424+
if (jwks != null) {
479425
try {
480-
String jwkSetUri = this.cache.get(this.jwkSetUri, String.class);
481-
JWKSet cacheJwkSet = parse(jwkSetUri);
482-
if(jwkSetUri != null && cacheJwkSet.toString().equals(jwkSet.toString())) {
483-
try {
484-
jwkSet = fetchJWKSet();
485-
} catch (IOException e) {
486-
throw new JWKSetRetrievalException("Couldn't retrieve JWK set from URL: " + e.getMessage(), e);
487-
}
488-
} else if (jwkSetUri != null) {
489-
jwkSet = parse(jwkSetUri);
490-
}
491-
} finally {
492-
reentrantLock.unlock();
426+
this.jwkSet = JWKSet.parse(jwks);
427+
}
428+
catch (ParseException ignored) {
429+
// Ignore invalid cache value
493430
}
494431
}
495-
if(jwkSet == null) {
496-
return Collections.emptyList();
497-
}
498-
return jwkSelector.select(jwkSet);
499432
}
500433

501-
private JWKSet fetchJWKSet() throws IOException, KeySourceException {
434+
private String fetchJwks() throws Exception {
502435
HttpHeaders headers = new HttpHeaders();
503436
headers.setAccept(Arrays.asList(MediaType.APPLICATION_JSON, APPLICATION_JWK_SET_JSON));
504-
ResponseEntity<String> response = getResponse(headers);
505-
if (response.getStatusCode().value() != 200) {
506-
throw new IOException(response.toString());
507-
}
508-
try {
509-
String jwkSet = response.getBody();
510-
this.cache.put(this.jwkSetUri, jwkSet);
511-
return JWKSet.parse(jwkSet);
512-
} catch (ParseException e) {
513-
throw new JWKSetParseException("Unable to parse JWK set", e);
514-
}
437+
RequestEntity<Void> request = new RequestEntity<>(headers, HttpMethod.GET, URI.create(this.jwkSetUri));
438+
ResponseEntity<String> response = this.restOperations.exchange(request, String.class);
439+
String jwks = response.getBody();
440+
this.jwkSet = JWKSet.parse(jwks);
441+
return jwks;
515442
}
516443

517-
private ResponseEntity<String> getResponse(HttpHeaders headers) throws IOException {
444+
@Override
445+
public JWKSet getJWKSet(JWKSetCacheRefreshEvaluator refreshEvaluator, long currentTime, C context)
446+
throws KeySourceException {
518447
try {
519-
RequestEntity<Void> request = new RequestEntity<>(headers, HttpMethod.GET, this.url.toURI());
520-
return this.restOperations.exchange(request, String.class);
521-
} catch (Exception ex) {
522-
throw new IOException(ex);
448+
this.reentrantLock.lock();
449+
if (refreshEvaluator.requiresRefresh(this.jwkSet)) {
450+
this.cache.invalidate();
451+
}
452+
this.cache.get(this.jwkSetUri, this::fetchJwks);
453+
return this.jwkSet;
523454
}
524-
}
525-
526-
private JWKSet parse(String cachedJwkSet) {
527-
JWKSet jwkSet = null;
528-
try {
529-
jwkSet = JWKSet.parse(cachedJwkSet);
530-
} catch (ParseException ignored) {
531-
// Ignore invalid cache value
455+
catch (Cache.ValueRetrievalException ex) {
456+
if (ex.getCause() instanceof RemoteKeySourceException keys) {
457+
throw keys;
458+
}
459+
throw new RemoteKeySourceException(ex.getCause().getMessage(), ex.getCause());
460+
}
461+
finally {
462+
this.reentrantLock.unlock();
532463
}
533-
return jwkSet;
534464
}
535465

536-
private String getFirstSpecifiedKeyID(JWKMatcher jwkMatcher) {
537-
Set<String> keyIDs = jwkMatcher.getKeyIDs();
538-
return (keyIDs == null || keyIDs.isEmpty()) ?
539-
null : keyIDs.stream().filter(id -> id != null).findFirst().orElse(null);
466+
@Override
467+
public void close() {
468+
540469
}
470+
541471
}
542472

543473
}

oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2025 the original author or authors.
2+
* Copyright 2002-2023 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.

0 commit comments

Comments
 (0)