diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoder.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoder.java index 2de3e64a815..108344fd906 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoder.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoder.java @@ -16,28 +16,9 @@ package org.springframework.security.oauth2.jwt; -import java.net.URI; -import java.net.URL; -import java.time.Instant; -import java.util.ArrayList; -import java.util.Date; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; - -import com.nimbusds.jose.JOSEException; -import com.nimbusds.jose.JOSEObjectType; -import com.nimbusds.jose.JWSAlgorithm; -import com.nimbusds.jose.JWSHeader; -import com.nimbusds.jose.JWSSigner; +import com.nimbusds.jose.*; import com.nimbusds.jose.crypto.factories.DefaultJWSSignerFactory; -import com.nimbusds.jose.jwk.JWK; -import com.nimbusds.jose.jwk.JWKMatcher; -import com.nimbusds.jose.jwk.JWKSelector; -import com.nimbusds.jose.jwk.KeyType; -import com.nimbusds.jose.jwk.KeyUse; +import com.nimbusds.jose.jwk.*; import com.nimbusds.jose.jwk.source.JWKSource; import com.nimbusds.jose.proc.SecurityContext; import com.nimbusds.jose.produce.JWSSignerFactory; @@ -45,12 +26,18 @@ import com.nimbusds.jose.util.Base64URL; import com.nimbusds.jwt.JWTClaimsSet; import com.nimbusds.jwt.SignedJWT; - +import org.springframework.core.convert.converter.Converter; import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; +import java.net.URI; +import java.net.URL; +import java.time.Instant; +import java.util.*; +import java.util.concurrent.ConcurrentHashMap; + /** * An implementation of a {@link JwtEncoder} that encodes a JSON Web Token (JWT) using the * JSON Web Signature (JWS) Compact Serialization format. The private/secret key used for @@ -61,7 +48,6 @@ * NOTE: This implementation uses the Nimbus JOSE + JWT SDK. * * @author Joe Grandja - * @since 5.6 * @see JwtEncoder * @see com.nimbusds.jose.jwk.source.JWKSource * @see com.nimbusds.jose.jwk.JWK @@ -73,6 +59,7 @@ * Compact Serialization * @see Nimbus * JOSE + JWT SDK + * @since 5.6 */ public final class NimbusJwtEncoder implements JwtEncoder { @@ -84,6 +71,8 @@ public final class NimbusJwtEncoder implements JwtEncoder { private final Map jwsSigners = new ConcurrentHashMap<>(); + private Converter, JWK> jwkSelector; + private final JWKSource jwkSource; /** @@ -95,6 +84,10 @@ public NimbusJwtEncoder(JWKSource jwkSource) { this.jwkSource = jwkSource; } + public void setJwkSelector(Converter, JWK> jwkSelector) { + this.jwkSelector = jwkSelector; + } + @Override public Jwt encode(JwtEncoderParameters parameters) throws JwtEncodingException { Assert.notNull(parameters, "parameters cannot be null"); @@ -123,16 +116,17 @@ private JWK selectJwk(JwsHeader headers) { throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, "Failed to select a JWK signing key -> " + ex.getMessage()), ex); } - - if (jwks.size() > 1) { - throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, - "Found multiple JWK signing keys for algorithm '" + headers.getAlgorithm().getName() + "'")); - } - if (jwks.isEmpty()) { throw new JwtEncodingException( String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, "Failed to select a JWK signing key")); } + if (null != this.jwkSelector) { + return this.jwkSelector.convert(jwks); + } + if (jwks.size() > 1) { + throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, + "Found multiple JWK signing keys for algorithm '" + headers.getAlgorithm().getName() + "'")); + } return jwks.get(0); } diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoderTests.java index e9825f0a359..985b6a827f1 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoderTests.java @@ -107,20 +107,6 @@ public void encodeWhenJwkSelectFailedThenThrowJwtEncodingException() throws Exce .withMessageContaining("Failed to select a JWK signing key -> key source error"); } - @Test - public void encodeWhenJwkMultipleSelectedThenThrowJwtEncodingException() throws Exception { - RSAKey rsaJwk = TestJwks.DEFAULT_RSA_JWK; - this.jwkList.add(rsaJwk); - this.jwkList.add(rsaJwk); - - JwsHeader jwsHeader = JwsHeader.with(SignatureAlgorithm.RS256).build(); - JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); - - assertThatExceptionOfType(JwtEncodingException.class) - .isThrownBy(() -> this.jwtEncoder.encode(JwtEncoderParameters.from(jwsHeader, jwtClaimsSet))) - .withMessageContaining("Found multiple JWK signing keys for algorithm 'RS256'"); - } - @Test public void encodeWhenJwkSelectEmptyThenThrowJwtEncodingException() { JwsHeader jwsHeader = JwsHeader.with(SignatureAlgorithm.RS256).build();