Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -16,41 +16,28 @@

package org.springframework.security.oauth2.jwt;

import java.net.URI;
import java.net.URL;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JOSEObjectType;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.JWSHeader;
import com.nimbusds.jose.JWSSigner;
import com.nimbusds.jose.*;
import com.nimbusds.jose.crypto.factories.DefaultJWSSignerFactory;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.JWKMatcher;
import com.nimbusds.jose.jwk.JWKSelector;
import com.nimbusds.jose.jwk.KeyType;
import com.nimbusds.jose.jwk.KeyUse;
import com.nimbusds.jose.jwk.*;
import com.nimbusds.jose.jwk.source.JWKSource;
import com.nimbusds.jose.proc.SecurityContext;
import com.nimbusds.jose.produce.JWSSignerFactory;
import com.nimbusds.jose.util.Base64;
import com.nimbusds.jose.util.Base64URL;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.SignedJWT;

import org.springframework.core.convert.converter.Converter;
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;

import java.net.URI;
import java.net.URL;
import java.time.Instant;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;

/**
* An implementation of a {@link JwtEncoder} that encodes a JSON Web Token (JWT) using the
* JSON Web Signature (JWS) Compact Serialization format. The private/secret key used for
Expand All @@ -61,7 +48,6 @@
* <b>NOTE:</b> This implementation uses the Nimbus JOSE + JWT SDK.
*
* @author Joe Grandja
* @since 5.6
* @see JwtEncoder
* @see com.nimbusds.jose.jwk.source.JWKSource
* @see com.nimbusds.jose.jwk.JWK
Expand All @@ -73,6 +59,7 @@
* Compact Serialization</a>
* @see <a target="_blank" href="https://connect2id.com/products/nimbus-jose-jwt">Nimbus
* JOSE + JWT SDK</a>
* @since 5.6
*/
public final class NimbusJwtEncoder implements JwtEncoder {

Expand All @@ -84,6 +71,8 @@ public final class NimbusJwtEncoder implements JwtEncoder {

private final Map<JWK, JWSSigner> jwsSigners = new ConcurrentHashMap<>();

private Converter<List<JWK>, JWK> jwkSelector;

private final JWKSource<SecurityContext> jwkSource;

/**
Expand All @@ -95,6 +84,10 @@ public NimbusJwtEncoder(JWKSource<SecurityContext> jwkSource) {
this.jwkSource = jwkSource;
}

public void setJwkSelector(Converter<List<JWK>, JWK> jwkSelector) {
this.jwkSelector = jwkSelector;
}

@Override
public Jwt encode(JwtEncoderParameters parameters) throws JwtEncodingException {
Assert.notNull(parameters, "parameters cannot be null");
Expand Down Expand Up @@ -123,16 +116,17 @@ private JWK selectJwk(JwsHeader headers) {
throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE,
"Failed to select a JWK signing key -> " + ex.getMessage()), ex);
}

if (jwks.size() > 1) {
throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE,
"Found multiple JWK signing keys for algorithm '" + headers.getAlgorithm().getName() + "'"));
}

if (jwks.isEmpty()) {
throw new JwtEncodingException(
String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, "Failed to select a JWK signing key"));
}
if (null != this.jwkSelector) {
return this.jwkSelector.convert(jwks);
}
if (jwks.size() > 1) {
throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE,
"Found multiple JWK signing keys for algorithm '" + headers.getAlgorithm().getName() + "'"));
}

return jwks.get(0);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,20 +107,6 @@ public void encodeWhenJwkSelectFailedThenThrowJwtEncodingException() throws Exce
.withMessageContaining("Failed to select a JWK signing key -> key source error");
}

@Test
public void encodeWhenJwkMultipleSelectedThenThrowJwtEncodingException() throws Exception {
RSAKey rsaJwk = TestJwks.DEFAULT_RSA_JWK;
this.jwkList.add(rsaJwk);
this.jwkList.add(rsaJwk);

JwsHeader jwsHeader = JwsHeader.with(SignatureAlgorithm.RS256).build();
JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build();

assertThatExceptionOfType(JwtEncodingException.class)
.isThrownBy(() -> this.jwtEncoder.encode(JwtEncoderParameters.from(jwsHeader, jwtClaimsSet)))
.withMessageContaining("Found multiple JWK signing keys for algorithm 'RS256'");
}

@Test
public void encodeWhenJwkSelectEmptyThenThrowJwtEncodingException() {
JwsHeader jwsHeader = JwsHeader.with(SignatureAlgorithm.RS256).build();
Expand Down