|
16 | 16 |
|
17 | 17 | package org.springframework.security.oauth2.client.endpoint;
|
18 | 18 |
|
| 19 | +import java.security.KeyPair; |
| 20 | +import java.security.KeyPairGenerator; |
| 21 | +import java.security.interfaces.RSAPrivateKey; |
| 22 | +import java.security.interfaces.RSAPublicKey; |
19 | 23 | import java.util.Collections;
|
| 24 | +import java.util.UUID; |
20 | 25 | import java.util.function.Function;
|
21 | 26 |
|
22 | 27 | import com.nimbusds.jose.jwk.JWK;
|
|
42 | 47 | import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
|
43 | 48 | import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
|
44 | 49 | import static org.mockito.ArgumentMatchers.any;
|
| 50 | +import static org.mockito.ArgumentMatchers.eq; |
45 | 51 | import static org.mockito.BDDMockito.given;
|
46 | 52 | import static org.mockito.Mockito.mock;
|
47 | 53 | import static org.mockito.Mockito.verifyNoInteractions;
|
@@ -172,4 +178,54 @@ public void convertWhenClientSecretJwtClientAuthenticationMethodThenCustomized()
|
172 | 178 | assertThat(jws.getExpiresAt()).isNotNull();
|
173 | 179 | }
|
174 | 180 |
|
| 181 | + // gh-9814 |
| 182 | + @Test |
| 183 | + public void convertWhenClientKeyChangesThenNewKeyUsed() throws Exception { |
| 184 | + // @formatter:off |
| 185 | + ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials() |
| 186 | + .clientAuthenticationMethod(ClientAuthenticationMethod.PRIVATE_KEY_JWT) |
| 187 | + .build(); |
| 188 | + // @formatter:on |
| 189 | + |
| 190 | + RSAKey rsaJwk1 = TestJwks.DEFAULT_RSA_JWK; |
| 191 | + given(this.jwkResolver.apply(eq(clientRegistration))).willReturn(rsaJwk1); |
| 192 | + |
| 193 | + OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest( |
| 194 | + clientRegistration); |
| 195 | + MultiValueMap<String, String> parameters = this.converter.convert(clientCredentialsGrantRequest); |
| 196 | + |
| 197 | + String encodedJws = parameters.getFirst(OAuth2ParameterNames.CLIENT_ASSERTION); |
| 198 | + NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withPublicKey(rsaJwk1.toRSAPublicKey()).build(); |
| 199 | + jwtDecoder.decode(encodedJws); |
| 200 | + |
| 201 | + RSAKey rsaJwk2 = generateRsaJwk(); |
| 202 | + given(this.jwkResolver.apply(eq(clientRegistration))).willReturn(rsaJwk2); |
| 203 | + |
| 204 | + parameters = this.converter.convert(clientCredentialsGrantRequest); |
| 205 | + |
| 206 | + encodedJws = parameters.getFirst(OAuth2ParameterNames.CLIENT_ASSERTION); |
| 207 | + jwtDecoder = NimbusJwtDecoder.withPublicKey(rsaJwk2.toRSAPublicKey()).build(); |
| 208 | + jwtDecoder.decode(encodedJws); |
| 209 | + } |
| 210 | + |
| 211 | + private static RSAKey generateRsaJwk() { |
| 212 | + KeyPair keyPair; |
| 213 | + try { |
| 214 | + KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("RSA"); |
| 215 | + keyPairGenerator.initialize(2048); |
| 216 | + keyPair = keyPairGenerator.generateKeyPair(); |
| 217 | + } |
| 218 | + catch (Exception ex) { |
| 219 | + throw new IllegalStateException(ex); |
| 220 | + } |
| 221 | + RSAPublicKey publicKey = (RSAPublicKey) keyPair.getPublic(); |
| 222 | + RSAPrivateKey privateKey = (RSAPrivateKey) keyPair.getPrivate(); |
| 223 | + // @formatter:off |
| 224 | + return new RSAKey.Builder(publicKey) |
| 225 | + .privateKey(privateKey) |
| 226 | + .keyID(UUID.randomUUID().toString()) |
| 227 | + .build(); |
| 228 | + // @formatter:on |
| 229 | + } |
| 230 | + |
175 | 231 | }
|
0 commit comments