Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@

package org.springframework.security.oauth2.jwt;

import com.nimbusds.jose.KeySourceException;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.JWKSelector;
import com.nimbusds.jose.jwk.source.JWKSetCacheRefreshEvaluator;
import com.nimbusds.jose.jwk.source.JWKSetUnavailableException;
import com.nimbusds.jose.jwk.source.URLBasedJWKSetSource;
import java.io.IOException;
import java.net.MalformedURLException;
import java.net.URL;
Expand All @@ -26,6 +32,7 @@
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Consumer;
Expand All @@ -35,11 +42,8 @@

import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.RemoteKeySourceException;
import com.nimbusds.jose.jwk.JWKSet;
import com.nimbusds.jose.jwk.source.JWKSetCache;
import com.nimbusds.jose.jwk.source.JWKSource;
import com.nimbusds.jose.jwk.source.RemoteJWKSet;
import com.nimbusds.jose.proc.JWSKeySelector;
import com.nimbusds.jose.proc.JWSVerificationKeySelector;
import com.nimbusds.jose.proc.SecurityContext;
Expand Down Expand Up @@ -165,7 +169,7 @@ private Jwt createJwt(String token, JWT parsedJwt) {
.build();
// @formatter:on
}
catch (RemoteKeySourceException ex) {
catch (JWKSetUnavailableException ex) {
this.logger.trace("Failed to retrieve JWK set", ex);
if (ex.getCause() instanceof ParseException) {
throw new JwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, "Malformed Jwk set"), ex);
Expand Down Expand Up @@ -377,11 +381,12 @@ JWSKeySelector<SecurityContext> jwsKeySelector(JWKSource<SecurityContext> jwkSou
}

JWKSource<SecurityContext> jwkSource(ResourceRetriever jwkSetRetriever, String jwkSetUri) {
if (this.cache == null) {
return new RemoteJWKSet<>(toURL(jwkSetUri), jwkSetRetriever);
URLBasedJWKSetSource urlBasedJWKSetSource = new URLBasedJWKSetSource(toURL(jwkSetUri), jwkSetRetriever);
if(this.cache == null) {
return new SpringURLBasedJWKSource(urlBasedJWKSetSource);
}
JWKSetCache jwkSetCache = new SpringJWKSetCache(jwkSetUri, this.cache);
return new RemoteJWKSet<>(toURL(jwkSetUri), jwkSetRetriever, jwkSetCache);
SpringJWKSetCache springJWKSetCache = new SpringJWKSetCache(jwkSetUri, this.cache);
return new SpringURLBasedJWKSource<>(urlBasedJWKSetSource, springJWKSetCache);
}

JWTProcessor<SecurityContext> processor() {
Expand Down Expand Up @@ -414,7 +419,50 @@ private static URL toURL(String url) {
}
}

private static final class SpringJWKSetCache implements JWKSetCache {
private static final class SpringURLBasedJWKSource<C extends SecurityContext> implements JWKSource<C> {

private final URLBasedJWKSetSource urlBasedJWKSetSource;

private SpringJWKSetCache springJWKSetCache;

private SpringURLBasedJWKSource(URLBasedJWKSetSource urlBasedJWKSetSource) {
this.urlBasedJWKSetSource = urlBasedJWKSetSource;
}

private SpringURLBasedJWKSource(URLBasedJWKSetSource urlBasedJWKSetSource,
SpringJWKSetCache springJWKSetCache) {
this.urlBasedJWKSetSource = urlBasedJWKSetSource;
this.springJWKSetCache = springJWKSetCache;
}

@Override
public List<JWK> get(JWKSelector jwkSelector, SecurityContext context) throws KeySourceException {
if (this.springJWKSetCache != null) {
synchronized (this) {
JWKSet jwkSet = this.springJWKSetCache.get();
if (this.springJWKSetCache.requiresRefresh() || jwkSet == null) {
jwkSet = fetchJWKSet(context);
this.springJWKSetCache.put(jwkSet);
}
List<JWK> jwks = jwkSelector.select(jwkSet);
if(!jwks.isEmpty()) {
return jwks;
}
jwkSet = fetchJWKSet(context);
this.springJWKSetCache.put(jwkSet);
return jwkSelector.select(jwkSet);
}
}
return jwkSelector.select(fetchJWKSet(context));
}

private JWKSet fetchJWKSet(SecurityContext context) throws KeySourceException {
return urlBasedJWKSetSource.getJWKSet(JWKSetCacheRefreshEvaluator.noRefresh(),
System.currentTimeMillis(), context);
}
}

private static final class SpringJWKSetCache {

private final String jwkSetUri;

Expand All @@ -433,27 +481,22 @@ private void updateJwkSetFromCache() {
if (cachedJwkSet != null) {
try {
this.jwkSet = JWKSet.parse(cachedJwkSet);
}
catch (ParseException ignored) {
} catch (ParseException ignored) {
// Ignore invalid cache value
}
}
}

// Note: Only called from inside a synchronized block in RemoteJWKSet.
@Override
// Note: Only called from inside a synchronized block in SpringURLBasedJWKSource.
public void put(JWKSet jwkSet) {
this.jwkSet = jwkSet;
this.cache.put(this.jwkSetUri, jwkSet.toString(false));
}

@Override
public JWKSet get() {
return (!requiresRefresh()) ? this.jwkSet : null;

}

@Override
public boolean requiresRefresh() {
return this.cache.get(this.jwkSetUri) == null;
}
Expand Down