Skip to content

Commit e057ff5

Browse files
committed
Remove Deprecated Usages of RemoteJWKSet
1 parent 036f6f2 commit e057ff5

File tree

1 file changed

+59
-16
lines changed

1 file changed

+59
-16
lines changed

oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java

Lines changed: 59 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,12 @@
1616

1717
package org.springframework.security.oauth2.jwt;
1818

19+
import com.nimbusds.jose.KeySourceException;
20+
import com.nimbusds.jose.jwk.JWK;
21+
import com.nimbusds.jose.jwk.JWKSelector;
22+
import com.nimbusds.jose.jwk.source.JWKSetCacheRefreshEvaluator;
23+
import com.nimbusds.jose.jwk.source.JWKSetUnavailableException;
24+
import com.nimbusds.jose.jwk.source.URLBasedJWKSetSource;
1925
import java.io.IOException;
2026
import java.net.MalformedURLException;
2127
import java.net.URL;
@@ -26,6 +32,7 @@
2632
import java.util.Collections;
2733
import java.util.HashSet;
2834
import java.util.LinkedHashMap;
35+
import java.util.List;
2936
import java.util.Map;
3037
import java.util.Set;
3138
import java.util.function.Consumer;
@@ -35,11 +42,8 @@
3542

3643
import com.nimbusds.jose.JOSEException;
3744
import com.nimbusds.jose.JWSAlgorithm;
38-
import com.nimbusds.jose.RemoteKeySourceException;
3945
import com.nimbusds.jose.jwk.JWKSet;
40-
import com.nimbusds.jose.jwk.source.JWKSetCache;
4146
import com.nimbusds.jose.jwk.source.JWKSource;
42-
import com.nimbusds.jose.jwk.source.RemoteJWKSet;
4347
import com.nimbusds.jose.proc.JWSKeySelector;
4448
import com.nimbusds.jose.proc.JWSVerificationKeySelector;
4549
import com.nimbusds.jose.proc.SecurityContext;
@@ -165,7 +169,7 @@ private Jwt createJwt(String token, JWT parsedJwt) {
165169
.build();
166170
// @formatter:on
167171
}
168-
catch (RemoteKeySourceException ex) {
172+
catch (JWKSetUnavailableException ex) {
169173
this.logger.trace("Failed to retrieve JWK set", ex);
170174
if (ex.getCause() instanceof ParseException) {
171175
throw new JwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, "Malformed Jwk set"), ex);
@@ -377,11 +381,12 @@ JWSKeySelector<SecurityContext> jwsKeySelector(JWKSource<SecurityContext> jwkSou
377381
}
378382

379383
JWKSource<SecurityContext> jwkSource(ResourceRetriever jwkSetRetriever, String jwkSetUri) {
380-
if (this.cache == null) {
381-
return new RemoteJWKSet<>(toURL(jwkSetUri), jwkSetRetriever);
384+
URLBasedJWKSetSource urlBasedJWKSetSource = new URLBasedJWKSetSource(toURL(jwkSetUri), jwkSetRetriever);
385+
if(this.cache == null) {
386+
return new SpringURLBasedJWKSource(urlBasedJWKSetSource);
382387
}
383-
JWKSetCache jwkSetCache = new SpringJWKSetCache(jwkSetUri, this.cache);
384-
return new RemoteJWKSet<>(toURL(jwkSetUri), jwkSetRetriever, jwkSetCache);
388+
SpringJWKSetCache springJWKSetCache = new SpringJWKSetCache(jwkSetUri, this.cache);
389+
return new SpringURLBasedJWKSource<>(urlBasedJWKSetSource, springJWKSetCache);
385390
}
386391

387392
JWTProcessor<SecurityContext> processor() {
@@ -414,7 +419,50 @@ private static URL toURL(String url) {
414419
}
415420
}
416421

417-
private static final class SpringJWKSetCache implements JWKSetCache {
422+
private static final class SpringURLBasedJWKSource<C extends SecurityContext> implements JWKSource<C> {
423+
424+
private final URLBasedJWKSetSource urlBasedJWKSetSource;
425+
426+
private SpringJWKSetCache springJWKSetCache;
427+
428+
private SpringURLBasedJWKSource(URLBasedJWKSetSource urlBasedJWKSetSource) {
429+
this.urlBasedJWKSetSource = urlBasedJWKSetSource;
430+
}
431+
432+
private SpringURLBasedJWKSource(URLBasedJWKSetSource urlBasedJWKSetSource,
433+
SpringJWKSetCache springJWKSetCache) {
434+
this.urlBasedJWKSetSource = urlBasedJWKSetSource;
435+
this.springJWKSetCache = springJWKSetCache;
436+
}
437+
438+
@Override
439+
public List<JWK> get(JWKSelector jwkSelector, SecurityContext context) throws KeySourceException {
440+
if (this.springJWKSetCache != null) {
441+
synchronized (this) {
442+
JWKSet jwkSet = this.springJWKSetCache.get();
443+
if (this.springJWKSetCache.requiresRefresh() || jwkSet == null) {
444+
jwkSet = fetchJWKSet(context);
445+
this.springJWKSetCache.put(jwkSet);
446+
}
447+
List<JWK> jwks = jwkSelector.select(jwkSet);
448+
if(!jwks.isEmpty()) {
449+
return jwks;
450+
}
451+
jwkSet = fetchJWKSet(context);
452+
this.springJWKSetCache.put(jwkSet);
453+
return jwkSelector.select(jwkSet);
454+
}
455+
}
456+
return jwkSelector.select(fetchJWKSet(context));
457+
}
458+
459+
private JWKSet fetchJWKSet(SecurityContext context) throws KeySourceException {
460+
return urlBasedJWKSetSource.getJWKSet(JWKSetCacheRefreshEvaluator.noRefresh(),
461+
System.currentTimeMillis(), context);
462+
}
463+
}
464+
465+
private static final class SpringJWKSetCache {
418466

419467
private final String jwkSetUri;
420468

@@ -433,27 +481,22 @@ private void updateJwkSetFromCache() {
433481
if (cachedJwkSet != null) {
434482
try {
435483
this.jwkSet = JWKSet.parse(cachedJwkSet);
436-
}
437-
catch (ParseException ignored) {
484+
} catch (ParseException ignored) {
438485
// Ignore invalid cache value
439486
}
440487
}
441488
}
442489

443-
// Note: Only called from inside a synchronized block in RemoteJWKSet.
444-
@Override
490+
// Note: Only called from inside a synchronized block in SpringURLBasedJWKSource.
445491
public void put(JWKSet jwkSet) {
446492
this.jwkSet = jwkSet;
447493
this.cache.put(this.jwkSetUri, jwkSet.toString(false));
448494
}
449495

450-
@Override
451496
public JWKSet get() {
452497
return (!requiresRefresh()) ? this.jwkSet : null;
453-
454498
}
455499

456-
@Override
457500
public boolean requiresRefresh() {
458501
return this.cache.get(this.jwkSetUri) == null;
459502
}

0 commit comments

Comments
 (0)