Skip to content

Commit 2892885

Browse files
author
douxf
committed
enhancement for NimbusJwtEncoder to supporting key rotation
close 16170
1 parent edde7db commit 2892885

File tree

2 files changed

+50
-41
lines changed

2 files changed

+50
-41
lines changed

oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoder.java

Lines changed: 29 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -16,41 +16,28 @@
1616

1717
package org.springframework.security.oauth2.jwt;
1818

19-
import java.net.URI;
20-
import java.net.URL;
21-
import java.time.Instant;
22-
import java.util.ArrayList;
23-
import java.util.Date;
24-
import java.util.HashMap;
25-
import java.util.List;
26-
import java.util.Map;
27-
import java.util.Set;
28-
import java.util.concurrent.ConcurrentHashMap;
29-
30-
import com.nimbusds.jose.JOSEException;
31-
import com.nimbusds.jose.JOSEObjectType;
32-
import com.nimbusds.jose.JWSAlgorithm;
33-
import com.nimbusds.jose.JWSHeader;
34-
import com.nimbusds.jose.JWSSigner;
19+
import com.nimbusds.jose.*;
3520
import com.nimbusds.jose.crypto.factories.DefaultJWSSignerFactory;
36-
import com.nimbusds.jose.jwk.JWK;
37-
import com.nimbusds.jose.jwk.JWKMatcher;
38-
import com.nimbusds.jose.jwk.JWKSelector;
39-
import com.nimbusds.jose.jwk.KeyType;
40-
import com.nimbusds.jose.jwk.KeyUse;
21+
import com.nimbusds.jose.jwk.*;
4122
import com.nimbusds.jose.jwk.source.JWKSource;
4223
import com.nimbusds.jose.proc.SecurityContext;
4324
import com.nimbusds.jose.produce.JWSSignerFactory;
4425
import com.nimbusds.jose.util.Base64;
4526
import com.nimbusds.jose.util.Base64URL;
4627
import com.nimbusds.jwt.JWTClaimsSet;
4728
import com.nimbusds.jwt.SignedJWT;
48-
29+
import org.springframework.core.convert.converter.Converter;
4930
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
5031
import org.springframework.util.Assert;
5132
import org.springframework.util.CollectionUtils;
5233
import org.springframework.util.StringUtils;
5334

35+
import java.net.URI;
36+
import java.net.URL;
37+
import java.time.Instant;
38+
import java.util.*;
39+
import java.util.concurrent.ConcurrentHashMap;
40+
5441
/**
5542
* An implementation of a {@link JwtEncoder} that encodes a JSON Web Token (JWT) using the
5643
* JSON Web Signature (JWS) Compact Serialization format. The private/secret key used for
@@ -61,7 +48,6 @@
6148
* <b>NOTE:</b> This implementation uses the Nimbus JOSE + JWT SDK.
6249
*
6350
* @author Joe Grandja
64-
* @since 5.6
6551
* @see JwtEncoder
6652
* @see com.nimbusds.jose.jwk.source.JWKSource
6753
* @see com.nimbusds.jose.jwk.JWK
@@ -73,6 +59,7 @@
7359
* Compact Serialization</a>
7460
* @see <a target="_blank" href="https://connect2id.com/products/nimbus-jose-jwt">Nimbus
7561
* JOSE + JWT SDK</a>
62+
* @since 5.6
7663
*/
7764
public final class NimbusJwtEncoder implements JwtEncoder {
7865

@@ -84,6 +71,8 @@ public final class NimbusJwtEncoder implements JwtEncoder {
8471

8572
private final Map<JWK, JWSSigner> jwsSigners = new ConcurrentHashMap<>();
8673

74+
private Converter<List<JWK>, JWK> jwkSelector;
75+
8776
private final JWKSource<SecurityContext> jwkSource;
8877

8978
/**
@@ -95,6 +84,10 @@ public NimbusJwtEncoder(JWKSource<SecurityContext> jwkSource) {
9584
this.jwkSource = jwkSource;
9685
}
9786

87+
public void setJwkSelector(Converter<List<JWK>, JWK> jwkSelector) {
88+
this.jwkSelector = jwkSelector;
89+
}
90+
9891
@Override
9992
public Jwt encode(JwtEncoderParameters parameters) throws JwtEncodingException {
10093
Assert.notNull(parameters, "parameters cannot be null");
@@ -118,21 +111,21 @@ private JWK selectJwk(JwsHeader headers) {
118111
try {
119112
JWKSelector jwkSelector = new JWKSelector(createJwkMatcher(headers));
120113
jwks = this.jwkSource.get(jwkSelector, null);
121-
}
122-
catch (Exception ex) {
114+
} catch (Exception ex) {
123115
throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE,
124116
"Failed to select a JWK signing key -> " + ex.getMessage()), ex);
125117
}
126-
127-
if (jwks.size() > 1) {
128-
throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE,
129-
"Found multiple JWK signing keys for algorithm '" + headers.getAlgorithm().getName() + "'"));
130-
}
131-
132118
if (jwks.isEmpty()) {
133119
throw new JwtEncodingException(
134120
String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, "Failed to select a JWK signing key"));
135121
}
122+
if (null != this.jwkSelector) {
123+
return this.jwkSelector.convert(jwks);
124+
}
125+
if (jwks.size() > 1) {
126+
throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE,
127+
"Found multiple JWK signing keys for algorithm '" + headers.getAlgorithm().getName() + "'"));
128+
}
136129

137130
return jwks.get(0);
138131
}
@@ -146,8 +139,7 @@ private String serialize(JwsHeader headers, JwtClaimsSet claims, JWK jwk) {
146139
SignedJWT signedJwt = new SignedJWT(jwsHeader, jwtClaimsSet);
147140
try {
148141
signedJwt.sign(jwsSigner);
149-
}
150-
catch (JOSEException ex) {
142+
} catch (JOSEException ex) {
151143
throw new JwtEncodingException(
152144
String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, "Failed to sign the JWT -> " + ex.getMessage()), ex);
153145
}
@@ -167,8 +159,7 @@ private static JWKMatcher createJwkMatcher(JwsHeader headers) {
167159
.x509CertSHA256Thumbprint(Base64URL.from(headers.getX509SHA256Thumbprint()))
168160
.build();
169161
// @formatter:on
170-
}
171-
else if (JWSAlgorithm.Family.HMAC_SHA.contains(jwsAlgorithm)) {
162+
} else if (JWSAlgorithm.Family.HMAC_SHA.contains(jwsAlgorithm)) {
172163
// @formatter:off
173164
return new JWKMatcher.Builder()
174165
.keyType(KeyType.forAlgorithm(jwsAlgorithm))
@@ -206,8 +197,7 @@ private static JwsHeader addKeyIdentifierHeadersIfNecessary(JwsHeader headers, J
206197
private static JWSSigner createSigner(JWK jwk) {
207198
try {
208199
return JWS_SIGNER_FACTORY.createJWSSigner(jwk);
209-
}
210-
catch (JOSEException ex) {
200+
} catch (JOSEException ex) {
211201
throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE,
212202
"Failed to create a JWS Signer -> " + ex.getMessage()), ex);
213203
}
@@ -224,8 +214,7 @@ private static JWSHeader convert(JwsHeader headers) {
224214
if (!CollectionUtils.isEmpty(jwk)) {
225215
try {
226216
builder.jwk(JWK.parse(jwk));
227-
}
228-
catch (Exception ex) {
217+
} catch (Exception ex) {
229218
throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE,
230219
"Unable to convert '" + JoseHeaderNames.JWK + "' JOSE header"), ex);
231220
}
@@ -342,8 +331,7 @@ private static JWTClaimsSet convert(JwtClaimsSet claims) {
342331
private static URI convertAsURI(String header, URL url) {
343332
try {
344333
return url.toURI();
345-
}
346-
catch (Exception ex) {
334+
} catch (Exception ex) {
347335
throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE,
348336
"Unable to convert '" + header + "' JOSE header to a URI"), ex);
349337
}

oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoderTests.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,27 @@ public void encodeWhenJwkMultipleSelectedThenThrowJwtEncodingException() throws
121121
.withMessageContaining("Found multiple JWK signing keys for algorithm 'RS256'");
122122
}
123123

124+
@Test
125+
public void encodeWhenJwkMultipleSelectedWithJwkSelector() throws Exception {
126+
RSAKey rsaJwk = TestJwks.DEFAULT_RSA_JWK;
127+
this.jwkList.add(rsaJwk);
128+
this.jwkList.add(rsaJwk);
129+
this.jwtEncoder.setJwkSelector(jwkSelector -> jwkSelector.get(0));
130+
131+
JwsHeader jwsHeader = JwsHeader.with(SignatureAlgorithm.RS256).build();
132+
JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build();
133+
134+
Jwt encodedJws = this.jwtEncoder.encode(JwtEncoderParameters.from(jwtClaimsSet));
135+
assertThat(encodedJws.getHeaders()).containsEntry(JoseHeaderNames.ALG, SignatureAlgorithm.RS256);
136+
137+
this.jwtEncoder.setJwkSelector(jwkSelector -> jwkSelector.get(jwkSelector.size()-1));
138+
jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build();
139+
encodedJws = this.jwtEncoder.encode(JwtEncoderParameters.from(jwtClaimsSet));
140+
assertThat(encodedJws.getHeaders()).containsEntry(JoseHeaderNames.ALG, SignatureAlgorithm.RS256);
141+
this.jwtEncoder.setJwkSelector(null);
142+
}
143+
144+
124145
@Test
125146
public void encodeWhenJwkSelectEmptyThenThrowJwtEncodingException() {
126147
JwsHeader jwsHeader = JwsHeader.with(SignatureAlgorithm.RS256).build();

0 commit comments

Comments
 (0)