Skip to content

Commit b699094

Browse files
committed
Polish 'Choose SAML party based on entity ID rather than always using first'
See gh-35902
1 parent 864af59 commit b699094

File tree

2 files changed

+47
-47
lines changed

2 files changed

+47
-47
lines changed

spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/security/saml2/Saml2RelyingPartyRegistrationConfiguration.java

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import java.security.cert.CertificateFactory;
2121
import java.security.cert.X509Certificate;
2222
import java.security.interfaces.RSAPrivateKey;
23+
import java.util.Collection;
2324
import java.util.List;
2425
import java.util.Map;
2526
import java.util.function.Consumer;
@@ -63,6 +64,7 @@
6364
* @author Madhura Bhave
6465
* @author Phillip Webb
6566
* @author Moritz Halbritter
67+
* @author Lasse Lindqvist
6668
*/
6769
@Configuration(proxyBeanMethods = false)
6870
@Conditional(RegistrationConfiguredCondition.class)
@@ -88,14 +90,8 @@ private RelyingPartyRegistration asRegistration(Map.Entry<String, Registration>
8890
private RelyingPartyRegistration asRegistration(String id, Registration properties) {
8991
AssertingPartyProperties assertingParty = new AssertingPartyProperties(properties, id);
9092
boolean usingMetadata = StringUtils.hasText(assertingParty.getMetadataUri());
91-
Builder builder = (usingMetadata) ? RelyingPartyRegistrations
92-
.collectionFromMetadataLocation(properties.getAssertingparty().getMetadataUri())
93-
.stream()
94-
.filter(b -> entityIdsMatch(properties, b))
95-
.findFirst()
96-
.orElseThrow(() -> new IllegalStateException(
97-
"No relying party with entity-id " + properties.getEntityId() + " found."))
98-
.registrationId(id) : RelyingPartyRegistration.withRegistrationId(id);
93+
Builder builder = (!usingMetadata) ? RelyingPartyRegistration.withRegistrationId(id)
94+
: createBuilderUsingMetadata(id, assertingParty).registrationId(id);
9995
builder.assertionConsumerServiceLocation(properties.getAcs().getLocation());
10096
builder.assertionConsumerServiceBinding(properties.getAcs().getBinding());
10197
builder.assertingPartyDetails(mapAssertingParty(properties, id, usingMetadata));
@@ -124,17 +120,23 @@ private RelyingPartyRegistration asRegistration(String id, Registration properti
124120
return registration;
125121
}
126122

127-
/**
128-
* Tests if the builder would have the correct entity-id. If no entity-id is given in
129-
* properties, any builder passes the test.
130-
* @param properties the properties
131-
* @param b the builder
132-
* @return true if the builder passes the test
133-
*/
134-
private boolean entityIdsMatch(Registration properties, Builder b) {
135-
RelyingPartyRegistration rpr = b.build();
136-
return properties.getAssertingparty().getEntityId() == null
137-
|| properties.getAssertingparty().getEntityId().equals(rpr.getAssertingPartyDetails().getEntityId());
123+
private RelyingPartyRegistration.Builder createBuilderUsingMetadata(String id,
124+
AssertingPartyProperties properties) {
125+
String requiredEntityId = properties.getEntityId();
126+
Collection<Builder> candidates = RelyingPartyRegistrations
127+
.collectionFromMetadataLocation(properties.getMetadataUri());
128+
for (RelyingPartyRegistration.Builder candidate : candidates) {
129+
if (requiredEntityId == null || requiredEntityId.equals(getEntityId(candidate))) {
130+
return candidate;
131+
}
132+
}
133+
throw new IllegalStateException("No relying party with Entity ID '" + requiredEntityId + "' found");
134+
}
135+
136+
private Object getEntityId(RelyingPartyRegistration.Builder candidate) {
137+
String[] result = new String[1];
138+
candidate.assertingPartyDetails((builder) -> result[0] = builder.build().getEntityId());
139+
return result[0];
138140
}
139141

140142
private Consumer<AssertingPartyDetails.Builder> mapAssertingParty(Registration registration, String id,

spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/security/saml2/Saml2RelyingPartyAutoConfigurationTests.java

Lines changed: 26 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package org.springframework.boot.autoconfigure.security.saml2;
1818

19+
import java.io.IOException;
1920
import java.io.InputStream;
2021
import java.util.List;
2122

@@ -55,6 +56,7 @@
5556
*
5657
* @author Madhura Bhave
5758
* @author Moritz Halbritter
59+
* @author Lasse Lindqvist
5860
*/
5961
class Saml2RelyingPartyAutoConfigurationTests {
6062

@@ -402,41 +404,37 @@ void samlLogoutShouldBeConfigured() {
402404
this.contextRunner.withPropertyValues(getPropertyValues(false))
403405
.run((context) -> assertThat(hasFilter(context, Saml2LogoutRequestFilter.class)).isTrue());
404406
}
405-
407+
406408
@Test
407-
void autoconfigurationShouldWorkWithMultipleProvidersWithNoEntityId() throws Exception {
408-
try (MockWebServer server = new MockWebServer()) {
409-
server.start();
410-
String metadataUrl = server.url("").toString();
411-
setupMockResponse(server, new ClassPathResource("saml/idp-metadata-with-multiple-providers"));
412-
this.contextRunner.withPropertyValues(PREFIX + ".foo.assertingparty.metadata-uri=" + metadataUrl)
413-
.run((context) -> {
414-
assertThat(context).hasSingleBean(RelyingPartyRegistrationRepository.class);
415-
assertThat(server.getRequestCount()).isOne();
416-
RelyingPartyRegistrationRepository repository = context.getBean(RelyingPartyRegistrationRepository.class);
417-
RelyingPartyRegistration registration = repository.findByRegistrationId("foo");
418-
assertThat(registration.getAssertingPartyDetails().getEntityId())
419-
.isEqualTo("https://idp.example.com/idp/shibboleth");
420-
});
421-
}
409+
void autoconfigurationWhenMultipleProvidersAndNoSpecifiedEntityId() throws Exception {
410+
testMultipleProviders(null, "https://idp.example.com/idp/shibboleth");
422411
}
423-
412+
424413
@Test
425-
void autoconfigurationShouldWorkWithMultipleProviders() throws Exception {
414+
void autoconfigurationWhenMultipleProvidersAndSpecifiedEntityId() throws Exception {
415+
testMultipleProviders("https://idp.example.com/idp/shibboleth", "https://idp.example.com/idp/shibboleth");
416+
testMultipleProviders("https://idp2.example.com/idp/shibboleth", "https://idp2.example.com/idp/shibboleth");
417+
}
418+
419+
private void testMultipleProviders(String specifiedEntityId, String expected) throws IOException, Exception {
426420
try (MockWebServer server = new MockWebServer()) {
427421
server.start();
428422
String metadataUrl = server.url("").toString();
429423
setupMockResponse(server, new ClassPathResource("saml/idp-metadata-with-multiple-providers"));
430-
this.contextRunner.withPropertyValues(PREFIX + ".foo.assertingparty.metadata-uri=" + metadataUrl,
431-
PREFIX + ".foo.assertingparty.entity-id=https://idp2.example.com/idp/shibboleth")
432-
.run((context) -> {
433-
assertThat(context).hasSingleBean(RelyingPartyRegistrationRepository.class);
434-
assertThat(server.getRequestCount()).isOne();
435-
RelyingPartyRegistrationRepository repository = context.getBean(RelyingPartyRegistrationRepository.class);
436-
RelyingPartyRegistration registration = repository.findByRegistrationId("foo");
437-
assertThat(registration.getAssertingPartyDetails().getEntityId())
438-
.isEqualTo("https://idp2.example.com/idp/shibboleth");
439-
});
424+
WebApplicationContextRunner contextRunner = this.contextRunner
425+
.withPropertyValues(PREFIX + ".foo.assertingparty.metadata-uri=" + metadataUrl);
426+
if (specifiedEntityId != null) {
427+
contextRunner = contextRunner
428+
.withPropertyValues(PREFIX + ".foo.assertingparty.entity-id=" + specifiedEntityId);
429+
}
430+
contextRunner.run((context) -> {
431+
assertThat(context).hasSingleBean(RelyingPartyRegistrationRepository.class);
432+
assertThat(server.getRequestCount()).isOne();
433+
RelyingPartyRegistrationRepository repository = context
434+
.getBean(RelyingPartyRegistrationRepository.class);
435+
RelyingPartyRegistration registration = repository.findByRegistrationId("foo");
436+
assertThat(registration.getAssertingPartyDetails().getEntityId()).isEqualTo(expected);
437+
});
440438
}
441439
}
442440

0 commit comments

Comments
 (0)