From 13c92ef0173f30f483eb689b051a0b12f52eef13 Mon Sep 17 00:00:00 2001 From: douxiaofeng99 <18600127780@163.com> Date: Tue, 11 Feb 2025 16:36:59 +0800 Subject: [PATCH 1/2] Support JWK Selection Strategy Closes gh-16170 Signed-off-by: douxiaofeng99 <18600127780@163.com> --- .../security/oauth2/jwt/NimbusJwtEncoder.java | 41 +++++++++++++------ 1 file changed, 28 insertions(+), 13 deletions(-) diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoder.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoder.java index 2de3e64a81..bba502dfd9 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoder.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoder.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -46,6 +46,7 @@ import com.nimbusds.jwt.JWTClaimsSet; import com.nimbusds.jwt.SignedJWT; +import org.springframework.core.convert.converter.Converter; import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; @@ -86,6 +87,19 @@ public final class NimbusJwtEncoder implements JwtEncoder { private final JWKSource jwkSource; + private Converter, JWK> jwkSelector= (jwks)->{ + if (jwks.size() > 1) { + throw new JwtEncodingException(String.format( + "Failed to select a key since there are multiple for the signing algorithm [%s]; " + + "please specify a selector in NimbusJwsEncoder#setJwkSelector",jwks.get(0).getAlgorithm())); + } + if (jwks.isEmpty()) { + throw new JwtEncodingException( + String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, "Failed to select a JWK signing key")); + } + return jwks.get(0); + }; + /** * Constructs a {@code NimbusJwtEncoder} using the provided parameters. * @param jwkSource the {@code com.nimbusds.jose.jwk.source.JWKSource} @@ -94,6 +108,18 @@ public NimbusJwtEncoder(JWKSource jwkSource) { Assert.notNull(jwkSource, "jwkSource cannot be null"); this.jwkSource = jwkSource; } + /** + * Use this strategy to reduce the list of matching JWKs down to a since one. + *

For example, you can call {@code setJwkSelector(List::getFirst)} in order + * to have this encoder select the first match. + * + *

By default, the class with throw an exception if there is more than one result. + * @since 6.5 + */ + public void setJwkSelector(Converter, JWK> jwkSelector) { + if(null!=jwkSelector) + this.jwkSelector = jwkSelector; + } @Override public Jwt encode(JwtEncoderParameters parameters) throws JwtEncodingException { @@ -123,18 +149,7 @@ private JWK selectJwk(JwsHeader headers) { throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, "Failed to select a JWK signing key -> " + ex.getMessage()), ex); } - - if (jwks.size() > 1) { - throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, - "Found multiple JWK signing keys for algorithm '" + headers.getAlgorithm().getName() + "'")); - } - - if (jwks.isEmpty()) { - throw new JwtEncodingException( - String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, "Failed to select a JWK signing key")); - } - - return jwks.get(0); + return this.jwkSelector.convert(jwks); } private String serialize(JwsHeader headers, JwtClaimsSet claims, JWK jwk) { From d799c132e62374b07573af62825950e3ed2e0465 Mon Sep 17 00:00:00 2001 From: Josh Cummings <3627351+jzheaux@users.noreply.github.com> Date: Fri, 14 Feb 2025 09:35:21 -0700 Subject: [PATCH 2/2] Polish setJwkSelector Make so that it runs only when selection is needed. Require the provided selector be non-null. Add Tests. Issue gh-16170 --- .../security/oauth2/jwt/NimbusJwtEncoder.java | 39 ++++++------ .../security/oauth2/jose/TestJwks.java | 6 +- .../oauth2/jwt/NimbusJwtEncoderTests.java | 59 ++++++++++++++++++- 3 files changed, 83 insertions(+), 21 deletions(-) diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoder.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoder.java index bba502dfd9..fb0468fa9b 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoder.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoder.java @@ -87,17 +87,12 @@ public final class NimbusJwtEncoder implements JwtEncoder { private final JWKSource jwkSource; - private Converter, JWK> jwkSelector= (jwks)->{ - if (jwks.size() > 1) { - throw new JwtEncodingException(String.format( - "Failed to select a key since there are multiple for the signing algorithm [%s]; " + - "please specify a selector in NimbusJwsEncoder#setJwkSelector",jwks.get(0).getAlgorithm())); - } - if (jwks.isEmpty()) { - throw new JwtEncodingException( - String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, "Failed to select a JWK signing key")); - } - return jwks.get(0); + private Converter, JWK> jwkSelector = (jwks) -> { + throw new JwtEncodingException( + String.format( + "Failed to select a key since there are multiple for the signing algorithm [%s]; " + + "please specify a selector in NimbusJwsEncoder#setJwkSelector", + jwks.get(0).getAlgorithm())); }; /** @@ -108,17 +103,20 @@ public NimbusJwtEncoder(JWKSource jwkSource) { Assert.notNull(jwkSource, "jwkSource cannot be null"); this.jwkSource = jwkSource; } + /** - * Use this strategy to reduce the list of matching JWKs down to a since one. - *

For example, you can call {@code setJwkSelector(List::getFirst)} in order - * to have this encoder select the first match. + * Use this strategy to reduce the list of matching JWKs when there is more than one. + *

+ * For example, you can call {@code setJwkSelector(List::getFirst)} in order to have + * this encoder select the first match. * - *

By default, the class with throw an exception if there is more than one result. + *

+ * By default, the class with throw an exception. * @since 6.5 */ public void setJwkSelector(Converter, JWK> jwkSelector) { - if(null!=jwkSelector) - this.jwkSelector = jwkSelector; + Assert.notNull(jwkSelector, "jwkSelector cannot be null"); + this.jwkSelector = jwkSelector; } @Override @@ -149,6 +147,13 @@ private JWK selectJwk(JwsHeader headers) { throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, "Failed to select a JWK signing key -> " + ex.getMessage()), ex); } + if (jwks.isEmpty()) { + throw new JwtEncodingException( + String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, "Failed to select a JWK signing key")); + } + if (jwks.size() == 1) { + return jwks.get(0); + } return this.jwkSelector.convert(jwks); } diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jose/TestJwks.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jose/TestJwks.java index 412adbfd4d..d0426a4533 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jose/TestJwks.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jose/TestJwks.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -59,6 +59,10 @@ public final class TestJwks { private TestJwks() { } + public static RSAKey.Builder rsa() { + return jwk(TestKeys.DEFAULT_PUBLIC_KEY, TestKeys.DEFAULT_PRIVATE_KEY); + } + public static RSAKey.Builder jwk(RSAPublicKey publicKey, RSAPrivateKey privateKey) { // @formatter:off return new RSAKey.Builder(publicKey) diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoderTests.java index e9825f0a35..ab17156eac 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoderTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,6 +23,7 @@ import java.util.Collections; import java.util.List; +import com.nimbusds.jose.JWSAlgorithm; import com.nimbusds.jose.KeySourceException; import com.nimbusds.jose.jwk.ECKey; import com.nimbusds.jose.jwk.JWK; @@ -39,6 +40,7 @@ import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; +import org.springframework.core.convert.converter.Converter; import org.springframework.security.oauth2.jose.TestJwks; import org.springframework.security.oauth2.jose.TestKeys; import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; @@ -51,6 +53,8 @@ import static org.mockito.BDDMockito.willAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; /** * Tests for {@link NimbusJwtEncoder}. @@ -109,7 +113,7 @@ public void encodeWhenJwkSelectFailedThenThrowJwtEncodingException() throws Exce @Test public void encodeWhenJwkMultipleSelectedThenThrowJwtEncodingException() throws Exception { - RSAKey rsaJwk = TestJwks.DEFAULT_RSA_JWK; + RSAKey rsaJwk = TestJwks.rsa().algorithm(JWSAlgorithm.RS256).build(); this.jwkList.add(rsaJwk); this.jwkList.add(rsaJwk); @@ -118,7 +122,7 @@ public void encodeWhenJwkMultipleSelectedThenThrowJwtEncodingException() throws assertThatExceptionOfType(JwtEncodingException.class) .isThrownBy(() -> this.jwtEncoder.encode(JwtEncoderParameters.from(jwsHeader, jwtClaimsSet))) - .withMessageContaining("Found multiple JWK signing keys for algorithm 'RS256'"); + .withMessageContaining("Failed to select a key since there are multiple for the signing algorithm [RS256]"); } @Test @@ -291,6 +295,55 @@ public List get(JWKSelector jwkSelector, SecurityContext context) { assertThat(jwk1.getKeyID()).isNotEqualTo(jwk2.getKeyID()); } + @Test + public void encodeWhenMultipleKeysThenJwkSelectorUsed() throws Exception { + JWK jwk = TestJwks.rsa().algorithm(JWSAlgorithm.RS256).build(); + JWKSource jwkSource = mock(JWKSource.class); + given(jwkSource.get(any(), any())).willReturn(List.of(jwk, jwk)); + Converter, JWK> selector = mock(Converter.class); + given(selector.convert(any())).willReturn(TestJwks.DEFAULT_RSA_JWK); + + NimbusJwtEncoder jwtEncoder = new NimbusJwtEncoder(jwkSource); + jwtEncoder.setJwkSelector(selector); + + JwtClaimsSet claims = JwtClaimsSet.builder().subject("sub").build(); + jwtEncoder.encode(JwtEncoderParameters.from(claims)); + + verify(selector).convert(any()); + } + + @Test + public void encodeWhenSingleKeyThenJwkSelectorIsNotUsed() throws Exception { + JWK jwk = TestJwks.rsa().algorithm(JWSAlgorithm.RS256).build(); + JWKSource jwkSource = mock(JWKSource.class); + given(jwkSource.get(any(), any())).willReturn(List.of(jwk)); + Converter, JWK> selector = mock(Converter.class); + + NimbusJwtEncoder jwtEncoder = new NimbusJwtEncoder(jwkSource); + jwtEncoder.setJwkSelector(selector); + + JwtClaimsSet claims = JwtClaimsSet.builder().subject("sub").build(); + jwtEncoder.encode(JwtEncoderParameters.from(claims)); + + verifyNoInteractions(selector); + } + + @Test + public void encodeWhenNoKeysThenJwkSelectorIsNotUsed() throws Exception { + JWKSource jwkSource = mock(JWKSource.class); + given(jwkSource.get(any(), any())).willReturn(List.of()); + Converter, JWK> selector = mock(Converter.class); + + NimbusJwtEncoder jwtEncoder = new NimbusJwtEncoder(jwkSource); + jwtEncoder.setJwkSelector(selector); + + JwtClaimsSet claims = JwtClaimsSet.builder().subject("sub").build(); + assertThatExceptionOfType(JwtEncodingException.class) + .isThrownBy(() -> jwtEncoder.encode(JwtEncoderParameters.from(claims))); + + verifyNoInteractions(selector); + } + private static final class JwkListResultCaptor implements Answer> { private List result;