Skip to content

Commit 6fbd038

Browse files
committed
Jwt client authentication converter detects new key
Closes gh-9814
1 parent 700bda6 commit 6fbd038

File tree

2 files changed

+85
-4
lines changed

2 files changed

+85
-4
lines changed

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverter.java

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ public final class NimbusJwtClientAuthenticationParametersConverter<T extends Ab
8080

8181
private final Function<ClientRegistration, JWK> jwkResolver;
8282

83-
private final Map<String, NimbusJwsEncoder> jwsEncoders = new ConcurrentHashMap<>();
83+
private final Map<String, JwsEncoderHolder> jwsEncoders = new ConcurrentHashMap<>();
8484

8585
/**
8686
* Constructs a {@code NimbusJwtClientAuthenticationParametersConverter} using the
@@ -140,12 +140,16 @@ public MultiValueMap<String, String> convert(T authorizationGrantRequest) {
140140
JoseHeader joseHeader = headersBuilder.build();
141141
JwtClaimsSet jwtClaimsSet = claimsBuilder.build();
142142

143-
NimbusJwsEncoder jwsEncoder = this.jwsEncoders.computeIfAbsent(clientRegistration.getRegistrationId(),
144-
(clientRegistrationId) -> {
143+
JwsEncoderHolder jwsEncoderHolder = this.jwsEncoders.compute(clientRegistration.getRegistrationId(),
144+
(clientRegistrationId, currentJwsEncoderHolder) -> {
145+
if (currentJwsEncoderHolder != null && currentJwsEncoderHolder.getJwk().equals(jwk)) {
146+
return currentJwsEncoderHolder;
147+
}
145148
JWKSource<SecurityContext> jwkSource = new ImmutableJWKSet<>(new JWKSet(jwk));
146-
return new NimbusJwsEncoder(jwkSource);
149+
return new JwsEncoderHolder(new NimbusJwsEncoder(jwkSource), jwk);
147150
});
148151

152+
NimbusJwsEncoder jwsEncoder = jwsEncoderHolder.getJwsEncoder();
149153
Jwt jws = jwsEncoder.encode(joseHeader, jwtClaimsSet);
150154

151155
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
@@ -180,4 +184,25 @@ else if (KeyType.OCT.equals(jwk.getKeyType())) {
180184
return jwsAlgorithm;
181185
}
182186

187+
private static final class JwsEncoderHolder {
188+
189+
private final NimbusJwsEncoder jwsEncoder;
190+
191+
private final JWK jwk;
192+
193+
private JwsEncoderHolder(NimbusJwsEncoder jwsEncoder, JWK jwk) {
194+
this.jwsEncoder = jwsEncoder;
195+
this.jwk = jwk;
196+
}
197+
198+
private NimbusJwsEncoder getJwsEncoder() {
199+
return this.jwsEncoder;
200+
}
201+
202+
private JWK getJwk() {
203+
return this.jwk;
204+
}
205+
206+
}
207+
183208
}

oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverterTests.java

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,12 @@
1616

1717
package org.springframework.security.oauth2.client.endpoint;
1818

19+
import java.security.KeyPair;
20+
import java.security.KeyPairGenerator;
21+
import java.security.interfaces.RSAPrivateKey;
22+
import java.security.interfaces.RSAPublicKey;
1923
import java.util.Collections;
24+
import java.util.UUID;
2025
import java.util.function.Function;
2126

2227
import com.nimbusds.jose.jwk.JWK;
@@ -42,6 +47,7 @@
4247
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
4348
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
4449
import static org.mockito.ArgumentMatchers.any;
50+
import static org.mockito.ArgumentMatchers.eq;
4551
import static org.mockito.BDDMockito.given;
4652
import static org.mockito.Mockito.mock;
4753
import static org.mockito.Mockito.verifyNoInteractions;
@@ -172,4 +178,54 @@ public void convertWhenClientSecretJwtClientAuthenticationMethodThenCustomized()
172178
assertThat(jws.getExpiresAt()).isNotNull();
173179
}
174180

181+
// gh-9814
182+
@Test
183+
public void convertWhenClientKeyChangesThenNewKeyUsed() throws Exception {
184+
// @formatter:off
185+
ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials()
186+
.clientAuthenticationMethod(ClientAuthenticationMethod.PRIVATE_KEY_JWT)
187+
.build();
188+
// @formatter:on
189+
190+
RSAKey rsaJwk1 = TestJwks.DEFAULT_RSA_JWK;
191+
given(this.jwkResolver.apply(eq(clientRegistration))).willReturn(rsaJwk1);
192+
193+
OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest(
194+
clientRegistration);
195+
MultiValueMap<String, String> parameters = this.converter.convert(clientCredentialsGrantRequest);
196+
197+
String encodedJws = parameters.getFirst(OAuth2ParameterNames.CLIENT_ASSERTION);
198+
NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withPublicKey(rsaJwk1.toRSAPublicKey()).build();
199+
jwtDecoder.decode(encodedJws);
200+
201+
RSAKey rsaJwk2 = generateRsaJwk();
202+
given(this.jwkResolver.apply(eq(clientRegistration))).willReturn(rsaJwk2);
203+
204+
parameters = this.converter.convert(clientCredentialsGrantRequest);
205+
206+
encodedJws = parameters.getFirst(OAuth2ParameterNames.CLIENT_ASSERTION);
207+
jwtDecoder = NimbusJwtDecoder.withPublicKey(rsaJwk2.toRSAPublicKey()).build();
208+
jwtDecoder.decode(encodedJws);
209+
}
210+
211+
private static RSAKey generateRsaJwk() {
212+
KeyPair keyPair;
213+
try {
214+
KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("RSA");
215+
keyPairGenerator.initialize(2048);
216+
keyPair = keyPairGenerator.generateKeyPair();
217+
}
218+
catch (Exception ex) {
219+
throw new IllegalStateException(ex);
220+
}
221+
RSAPublicKey publicKey = (RSAPublicKey) keyPair.getPublic();
222+
RSAPrivateKey privateKey = (RSAPrivateKey) keyPair.getPrivate();
223+
// @formatter:off
224+
return new RSAKey.Builder(publicKey)
225+
.privateKey(privateKey)
226+
.keyID(UUID.randomUUID().toString())
227+
.build();
228+
// @formatter:on
229+
}
230+
175231
}

0 commit comments

Comments
 (0)