Skip to content

Commit 2f9344d

Browse files
authored
Merge branch 'spring-projects:main' into main
2 parents d656dc5 + 51ce91f commit 2f9344d

File tree

4 files changed

+99
-25
lines changed

4 files changed

+99
-25
lines changed

config/src/test/java/org/springframework/security/SpringSecurityCoreVersionSerializableTests.java

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import java.nio.file.Paths;
3535
import java.time.Instant;
3636
import java.util.ArrayList;
37+
import java.util.Arrays;
3738
import java.util.Collection;
3839
import java.util.Date;
3940
import java.util.HashMap;
@@ -42,7 +43,6 @@
4243
import java.util.Locale;
4344
import java.util.Map;
4445
import java.util.Set;
45-
import java.util.stream.Collectors;
4646
import java.util.stream.Stream;
4747

4848
import jakarta.servlet.http.Cookie;
@@ -762,18 +762,14 @@ static Stream<Path> getFilesToDeserialize() throws IOException {
762762
}
763763

764764
@Test
765-
void listClassesMissingSerialVersion() throws Exception {
765+
void allSerializableClassesShouldHaveSerialVersionOrSuppressWarnings() throws Exception {
766766
ClassPathScanningCandidateComponentProvider provider = new ClassPathScanningCandidateComponentProvider(false);
767767
provider.addIncludeFilter(new AssignableTypeFilter(Serializable.class));
768768
List<Class<?>> classes = new ArrayList<>();
769769

770770
Set<BeanDefinition> components = provider.findCandidateComponents("org/springframework/security");
771771
for (BeanDefinition component : components) {
772772
Class<?> clazz = Class.forName(component.getBeanClassName());
773-
boolean isAbstract = Modifier.isAbstract(clazz.getModifiers());
774-
if (isAbstract) {
775-
continue;
776-
}
777773
if (clazz.isEnum()) {
778774
continue;
779775
}
@@ -783,15 +779,16 @@ void listClassesMissingSerialVersion() throws Exception {
783779
boolean hasSerialVersion = Stream.of(clazz.getDeclaredFields())
784780
.map(Field::getName)
785781
.anyMatch((n) -> n.equals("serialVersionUID"));
786-
if (!hasSerialVersion) {
782+
SuppressWarnings suppressWarnings = clazz.getAnnotation(SuppressWarnings.class);
783+
boolean hasSerialIgnore = suppressWarnings == null
784+
|| Arrays.asList(suppressWarnings.value()).contains("Serial");
785+
if (!hasSerialVersion && !hasSerialIgnore) {
787786
classes.add(clazz);
788787
}
789788
}
790-
if (!classes.isEmpty()) {
791-
System.out
792-
.println("Found " + classes.size() + " Serializable classes that don't declare a seriallVersionUID");
793-
System.out.println(classes.stream().map(Class::getName).collect(Collectors.joining("\r\n")));
794-
}
789+
assertThat(classes)
790+
.describedAs("Found Serializable classes that are either missing a serialVersionUID or a @SuppressWarnings")
791+
.isEmpty();
795792
}
796793

797794
static Stream<Class<?>> getClassesToSerialize() throws Exception {

oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoder.java

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2021 the original author or authors.
2+
* Copyright 2002-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -46,6 +46,7 @@
4646
import com.nimbusds.jwt.JWTClaimsSet;
4747
import com.nimbusds.jwt.SignedJWT;
4848

49+
import org.springframework.core.convert.converter.Converter;
4950
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
5051
import org.springframework.util.Assert;
5152
import org.springframework.util.CollectionUtils;
@@ -86,6 +87,14 @@ public final class NimbusJwtEncoder implements JwtEncoder {
8687

8788
private final JWKSource<SecurityContext> jwkSource;
8889

90+
private Converter<List<JWK>, JWK> jwkSelector = (jwks) -> {
91+
throw new JwtEncodingException(
92+
String.format(
93+
"Failed to select a key since there are multiple for the signing algorithm [%s]; "
94+
+ "please specify a selector in NimbusJwsEncoder#setJwkSelector",
95+
jwks.get(0).getAlgorithm()));
96+
};
97+
8998
/**
9099
* Constructs a {@code NimbusJwtEncoder} using the provided parameters.
91100
* @param jwkSource the {@code com.nimbusds.jose.jwk.source.JWKSource}
@@ -95,6 +104,21 @@ public NimbusJwtEncoder(JWKSource<SecurityContext> jwkSource) {
95104
this.jwkSource = jwkSource;
96105
}
97106

107+
/**
108+
* Use this strategy to reduce the list of matching JWKs when there is more than one.
109+
* <p>
110+
* For example, you can call {@code setJwkSelector(List::getFirst)} in order to have
111+
* this encoder select the first match.
112+
*
113+
* <p>
114+
* By default, the class with throw an exception.
115+
* @since 6.5
116+
*/
117+
public void setJwkSelector(Converter<List<JWK>, JWK> jwkSelector) {
118+
Assert.notNull(jwkSelector, "jwkSelector cannot be null");
119+
this.jwkSelector = jwkSelector;
120+
}
121+
98122
@Override
99123
public Jwt encode(JwtEncoderParameters parameters) throws JwtEncodingException {
100124
Assert.notNull(parameters, "parameters cannot be null");
@@ -123,18 +147,14 @@ private JWK selectJwk(JwsHeader headers) {
123147
throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE,
124148
"Failed to select a JWK signing key -> " + ex.getMessage()), ex);
125149
}
126-
127-
if (jwks.size() > 1) {
128-
throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE,
129-
"Found multiple JWK signing keys for algorithm '" + headers.getAlgorithm().getName() + "'"));
130-
}
131-
132150
if (jwks.isEmpty()) {
133151
throw new JwtEncodingException(
134152
String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, "Failed to select a JWK signing key"));
135153
}
136-
137-
return jwks.get(0);
154+
if (jwks.size() == 1) {
155+
return jwks.get(0);
156+
}
157+
return this.jwkSelector.convert(jwks);
138158
}
139159

140160
private String serialize(JwsHeader headers, JwtClaimsSet claims, JWK jwk) {

oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jose/TestJwks.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2021 the original author or authors.
2+
* Copyright 2002-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -59,6 +59,10 @@ public final class TestJwks {
5959
private TestJwks() {
6060
}
6161

62+
public static RSAKey.Builder rsa() {
63+
return jwk(TestKeys.DEFAULT_PUBLIC_KEY, TestKeys.DEFAULT_PRIVATE_KEY);
64+
}
65+
6266
public static RSAKey.Builder jwk(RSAPublicKey publicKey, RSAPrivateKey privateKey) {
6367
// @formatter:off
6468
return new RSAKey.Builder(publicKey)

oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoderTests.java

Lines changed: 56 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2021 the original author or authors.
2+
* Copyright 2002-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -23,6 +23,7 @@
2323
import java.util.Collections;
2424
import java.util.List;
2525

26+
import com.nimbusds.jose.JWSAlgorithm;
2627
import com.nimbusds.jose.KeySourceException;
2728
import com.nimbusds.jose.jwk.ECKey;
2829
import com.nimbusds.jose.jwk.JWK;
@@ -39,6 +40,7 @@
3940
import org.mockito.invocation.InvocationOnMock;
4041
import org.mockito.stubbing.Answer;
4142

43+
import org.springframework.core.convert.converter.Converter;
4244
import org.springframework.security.oauth2.jose.TestJwks;
4345
import org.springframework.security.oauth2.jose.TestKeys;
4446
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
@@ -51,6 +53,8 @@
5153
import static org.mockito.BDDMockito.willAnswer;
5254
import static org.mockito.Mockito.mock;
5355
import static org.mockito.Mockito.spy;
56+
import static org.mockito.Mockito.verify;
57+
import static org.mockito.Mockito.verifyNoInteractions;
5458

5559
/**
5660
* Tests for {@link NimbusJwtEncoder}.
@@ -109,7 +113,7 @@ public void encodeWhenJwkSelectFailedThenThrowJwtEncodingException() throws Exce
109113

110114
@Test
111115
public void encodeWhenJwkMultipleSelectedThenThrowJwtEncodingException() throws Exception {
112-
RSAKey rsaJwk = TestJwks.DEFAULT_RSA_JWK;
116+
RSAKey rsaJwk = TestJwks.rsa().algorithm(JWSAlgorithm.RS256).build();
113117
this.jwkList.add(rsaJwk);
114118
this.jwkList.add(rsaJwk);
115119

@@ -118,7 +122,7 @@ public void encodeWhenJwkMultipleSelectedThenThrowJwtEncodingException() throws
118122

119123
assertThatExceptionOfType(JwtEncodingException.class)
120124
.isThrownBy(() -> this.jwtEncoder.encode(JwtEncoderParameters.from(jwsHeader, jwtClaimsSet)))
121-
.withMessageContaining("Found multiple JWK signing keys for algorithm 'RS256'");
125+
.withMessageContaining("Failed to select a key since there are multiple for the signing algorithm [RS256]");
122126
}
123127

124128
@Test
@@ -291,6 +295,55 @@ public List<JWK> get(JWKSelector jwkSelector, SecurityContext context) {
291295
assertThat(jwk1.getKeyID()).isNotEqualTo(jwk2.getKeyID());
292296
}
293297

298+
@Test
299+
public void encodeWhenMultipleKeysThenJwkSelectorUsed() throws Exception {
300+
JWK jwk = TestJwks.rsa().algorithm(JWSAlgorithm.RS256).build();
301+
JWKSource<SecurityContext> jwkSource = mock(JWKSource.class);
302+
given(jwkSource.get(any(), any())).willReturn(List.of(jwk, jwk));
303+
Converter<List<JWK>, JWK> selector = mock(Converter.class);
304+
given(selector.convert(any())).willReturn(TestJwks.DEFAULT_RSA_JWK);
305+
306+
NimbusJwtEncoder jwtEncoder = new NimbusJwtEncoder(jwkSource);
307+
jwtEncoder.setJwkSelector(selector);
308+
309+
JwtClaimsSet claims = JwtClaimsSet.builder().subject("sub").build();
310+
jwtEncoder.encode(JwtEncoderParameters.from(claims));
311+
312+
verify(selector).convert(any());
313+
}
314+
315+
@Test
316+
public void encodeWhenSingleKeyThenJwkSelectorIsNotUsed() throws Exception {
317+
JWK jwk = TestJwks.rsa().algorithm(JWSAlgorithm.RS256).build();
318+
JWKSource<SecurityContext> jwkSource = mock(JWKSource.class);
319+
given(jwkSource.get(any(), any())).willReturn(List.of(jwk));
320+
Converter<List<JWK>, JWK> selector = mock(Converter.class);
321+
322+
NimbusJwtEncoder jwtEncoder = new NimbusJwtEncoder(jwkSource);
323+
jwtEncoder.setJwkSelector(selector);
324+
325+
JwtClaimsSet claims = JwtClaimsSet.builder().subject("sub").build();
326+
jwtEncoder.encode(JwtEncoderParameters.from(claims));
327+
328+
verifyNoInteractions(selector);
329+
}
330+
331+
@Test
332+
public void encodeWhenNoKeysThenJwkSelectorIsNotUsed() throws Exception {
333+
JWKSource<SecurityContext> jwkSource = mock(JWKSource.class);
334+
given(jwkSource.get(any(), any())).willReturn(List.of());
335+
Converter<List<JWK>, JWK> selector = mock(Converter.class);
336+
337+
NimbusJwtEncoder jwtEncoder = new NimbusJwtEncoder(jwkSource);
338+
jwtEncoder.setJwkSelector(selector);
339+
340+
JwtClaimsSet claims = JwtClaimsSet.builder().subject("sub").build();
341+
assertThatExceptionOfType(JwtEncodingException.class)
342+
.isThrownBy(() -> jwtEncoder.encode(JwtEncoderParameters.from(claims)));
343+
344+
verifyNoInteractions(selector);
345+
}
346+
294347
private static final class JwkListResultCaptor implements Answer<List<JWK>> {
295348

296349
private List<JWK> result;

0 commit comments

Comments
 (0)