diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoder.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoder.java
index 2de3e64a815..108344fd906 100644
--- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoder.java
+++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoder.java
@@ -16,28 +16,9 @@
package org.springframework.security.oauth2.jwt;
-import java.net.URI;
-import java.net.URL;
-import java.time.Instant;
-import java.util.ArrayList;
-import java.util.Date;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
-import java.util.concurrent.ConcurrentHashMap;
-
-import com.nimbusds.jose.JOSEException;
-import com.nimbusds.jose.JOSEObjectType;
-import com.nimbusds.jose.JWSAlgorithm;
-import com.nimbusds.jose.JWSHeader;
-import com.nimbusds.jose.JWSSigner;
+import com.nimbusds.jose.*;
import com.nimbusds.jose.crypto.factories.DefaultJWSSignerFactory;
-import com.nimbusds.jose.jwk.JWK;
-import com.nimbusds.jose.jwk.JWKMatcher;
-import com.nimbusds.jose.jwk.JWKSelector;
-import com.nimbusds.jose.jwk.KeyType;
-import com.nimbusds.jose.jwk.KeyUse;
+import com.nimbusds.jose.jwk.*;
import com.nimbusds.jose.jwk.source.JWKSource;
import com.nimbusds.jose.proc.SecurityContext;
import com.nimbusds.jose.produce.JWSSignerFactory;
@@ -45,12 +26,18 @@
import com.nimbusds.jose.util.Base64URL;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.SignedJWT;
-
+import org.springframework.core.convert.converter.Converter;
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
+import java.net.URI;
+import java.net.URL;
+import java.time.Instant;
+import java.util.*;
+import java.util.concurrent.ConcurrentHashMap;
+
/**
* An implementation of a {@link JwtEncoder} that encodes a JSON Web Token (JWT) using the
* JSON Web Signature (JWS) Compact Serialization format. The private/secret key used for
@@ -61,7 +48,6 @@
* NOTE: This implementation uses the Nimbus JOSE + JWT SDK.
*
* @author Joe Grandja
- * @since 5.6
* @see JwtEncoder
* @see com.nimbusds.jose.jwk.source.JWKSource
* @see com.nimbusds.jose.jwk.JWK
@@ -73,6 +59,7 @@
* Compact Serialization
* @see Nimbus
* JOSE + JWT SDK
+ * @since 5.6
*/
public final class NimbusJwtEncoder implements JwtEncoder {
@@ -84,6 +71,8 @@ public final class NimbusJwtEncoder implements JwtEncoder {
private final Map jwsSigners = new ConcurrentHashMap<>();
+ private Converter, JWK> jwkSelector;
+
private final JWKSource jwkSource;
/**
@@ -95,6 +84,10 @@ public NimbusJwtEncoder(JWKSource jwkSource) {
this.jwkSource = jwkSource;
}
+ public void setJwkSelector(Converter, JWK> jwkSelector) {
+ this.jwkSelector = jwkSelector;
+ }
+
@Override
public Jwt encode(JwtEncoderParameters parameters) throws JwtEncodingException {
Assert.notNull(parameters, "parameters cannot be null");
@@ -123,16 +116,17 @@ private JWK selectJwk(JwsHeader headers) {
throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE,
"Failed to select a JWK signing key -> " + ex.getMessage()), ex);
}
-
- if (jwks.size() > 1) {
- throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE,
- "Found multiple JWK signing keys for algorithm '" + headers.getAlgorithm().getName() + "'"));
- }
-
if (jwks.isEmpty()) {
throw new JwtEncodingException(
String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, "Failed to select a JWK signing key"));
}
+ if (null != this.jwkSelector) {
+ return this.jwkSelector.convert(jwks);
+ }
+ if (jwks.size() > 1) {
+ throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE,
+ "Found multiple JWK signing keys for algorithm '" + headers.getAlgorithm().getName() + "'"));
+ }
return jwks.get(0);
}
diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoderTests.java
index e9825f0a359..985b6a827f1 100644
--- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoderTests.java
+++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoderTests.java
@@ -107,20 +107,6 @@ public void encodeWhenJwkSelectFailedThenThrowJwtEncodingException() throws Exce
.withMessageContaining("Failed to select a JWK signing key -> key source error");
}
- @Test
- public void encodeWhenJwkMultipleSelectedThenThrowJwtEncodingException() throws Exception {
- RSAKey rsaJwk = TestJwks.DEFAULT_RSA_JWK;
- this.jwkList.add(rsaJwk);
- this.jwkList.add(rsaJwk);
-
- JwsHeader jwsHeader = JwsHeader.with(SignatureAlgorithm.RS256).build();
- JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build();
-
- assertThatExceptionOfType(JwtEncodingException.class)
- .isThrownBy(() -> this.jwtEncoder.encode(JwtEncoderParameters.from(jwsHeader, jwtClaimsSet)))
- .withMessageContaining("Found multiple JWK signing keys for algorithm 'RS256'");
- }
-
@Test
public void encodeWhenJwkSelectEmptyThenThrowJwtEncodingException() {
JwsHeader jwsHeader = JwsHeader.with(SignatureAlgorithm.RS256).build();