Skip to content

Commit 015281f

Browse files
committed
Add DefaultRelyingPartyRegistrationResolver
Closes gh-8887
1 parent a402c38 commit 015281f

File tree

2 files changed

+209
-0
lines changed

2 files changed

+209
-0
lines changed
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
/*
2+
* Copyright 2002-2020 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.security.saml2.provider.service.web;
18+
19+
import java.util.HashMap;
20+
import java.util.Map;
21+
import java.util.function.Function;
22+
import javax.servlet.http.HttpServletRequest;
23+
24+
import org.springframework.core.convert.converter.Converter;
25+
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
26+
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
27+
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
28+
import org.springframework.security.web.util.matcher.RequestMatcher;
29+
import org.springframework.util.Assert;
30+
import org.springframework.util.StringUtils;
31+
import org.springframework.web.util.UriComponents;
32+
import org.springframework.web.util.UriComponentsBuilder;
33+
34+
import static org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.withRelyingPartyRegistration;
35+
import static org.springframework.security.web.util.UrlUtils.buildFullRequestUrl;
36+
import static org.springframework.web.util.UriComponentsBuilder.fromHttpUrl;
37+
38+
/**
39+
* A {@link Converter} that resolves a {@link RelyingPartyRegistration} by extracting the
40+
* registration id from the request, querying a {@link RelyingPartyRegistrationRepository},
41+
* and resolving any template values.
42+
*
43+
* @since 5.4
44+
* @author Josh Cummings
45+
*/
46+
public final class DefaultRelyingPartyRegistrationResolver
47+
implements Converter<HttpServletRequest, RelyingPartyRegistration> {
48+
49+
private static final char PATH_DELIMITER = '/';
50+
51+
private final RelyingPartyRegistrationRepository relyingPartyRegistrationRepository;
52+
private final Converter<HttpServletRequest, String> registrationIdResolver = new RegistrationIdResolver();
53+
54+
public DefaultRelyingPartyRegistrationResolver
55+
(RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) {
56+
57+
Assert.notNull(relyingPartyRegistrationRepository, "relyingPartyRegistrationRepository cannot be null");
58+
this.relyingPartyRegistrationRepository = relyingPartyRegistrationRepository;
59+
}
60+
61+
@Override
62+
public RelyingPartyRegistration convert(HttpServletRequest request) {
63+
String registrationId = this.registrationIdResolver.convert(request);
64+
if (registrationId == null) {
65+
return null;
66+
}
67+
RelyingPartyRegistration relyingPartyRegistration =
68+
this.relyingPartyRegistrationRepository.findByRegistrationId(registrationId);
69+
if (relyingPartyRegistration == null) {
70+
return null;
71+
}
72+
73+
String applicationUri = getApplicationUri(request);
74+
Function<String, String> templateResolver = templateResolver(applicationUri, relyingPartyRegistration);
75+
String relyingPartyEntityId = templateResolver.apply(relyingPartyRegistration.getEntityId());
76+
String assertionConsumerServiceLocation = templateResolver.apply(
77+
relyingPartyRegistration.getAssertionConsumerServiceLocation());
78+
return withRelyingPartyRegistration(relyingPartyRegistration)
79+
.entityId(relyingPartyEntityId)
80+
.assertionConsumerServiceLocation(assertionConsumerServiceLocation)
81+
.build();
82+
}
83+
84+
private Function<String, String> templateResolver(String applicationUri, RelyingPartyRegistration relyingParty) {
85+
return template -> resolveUrlTemplate(template, applicationUri, relyingParty);
86+
}
87+
88+
private static String resolveUrlTemplate(String template, String baseUrl, RelyingPartyRegistration relyingParty) {
89+
String entityId = relyingParty.getAssertingPartyDetails().getEntityId();
90+
String registrationId = relyingParty.getRegistrationId();
91+
Map<String, String> uriVariables = new HashMap<>();
92+
UriComponents uriComponents = UriComponentsBuilder.fromHttpUrl(baseUrl)
93+
.replaceQuery(null)
94+
.fragment(null)
95+
.build();
96+
String scheme = uriComponents.getScheme();
97+
uriVariables.put("baseScheme", scheme == null ? "" : scheme);
98+
String host = uriComponents.getHost();
99+
uriVariables.put("baseHost", host == null ? "" : host);
100+
// following logic is based on HierarchicalUriComponents#toUriString()
101+
int port = uriComponents.getPort();
102+
uriVariables.put("basePort", port == -1 ? "" : ":" + port);
103+
String path = uriComponents.getPath();
104+
if (StringUtils.hasLength(path) && path.charAt(0) != PATH_DELIMITER) {
105+
path = PATH_DELIMITER + path;
106+
}
107+
uriVariables.put("basePath", path == null ? "" : path);
108+
uriVariables.put("baseUrl", uriComponents.toUriString());
109+
uriVariables.put("entityId", StringUtils.hasText(entityId) ? entityId : "");
110+
uriVariables.put("registrationId", StringUtils.hasText(registrationId) ? registrationId : "");
111+
112+
return UriComponentsBuilder.fromUriString(template)
113+
.buildAndExpand(uriVariables)
114+
.toUriString();
115+
}
116+
117+
private static String getApplicationUri(HttpServletRequest request) {
118+
UriComponents uriComponents = fromHttpUrl(buildFullRequestUrl(request))
119+
.replacePath(request.getContextPath())
120+
.replaceQuery(null)
121+
.fragment(null)
122+
.build();
123+
return uriComponents.toUriString();
124+
}
125+
126+
private static class RegistrationIdResolver implements Converter<HttpServletRequest, String> {
127+
private final RequestMatcher requestMatcher = new AntPathRequestMatcher("/**/{registrationId}");
128+
129+
@Override
130+
public String convert(HttpServletRequest request) {
131+
RequestMatcher.MatchResult result = this.requestMatcher.matcher(request);
132+
return result.getVariables().get("registrationId");
133+
}
134+
}
135+
}
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
/*
2+
* Copyright 2002-2020 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.security.saml2.provider.service.web;
18+
19+
import org.junit.Test;
20+
21+
import org.springframework.mock.web.MockHttpServletRequest;
22+
import org.springframework.security.saml2.provider.service.registration.InMemoryRelyingPartyRegistrationRepository;
23+
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
24+
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
25+
26+
import static org.assertj.core.api.Assertions.assertThat;
27+
import static org.assertj.core.api.Assertions.assertThatCode;
28+
import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.relyingPartyRegistration;
29+
30+
/**
31+
* Tests for {@link DefaultRelyingPartyRegistrationResolver}
32+
*/
33+
public class DefaultRelyingPartyRegistrationResolverTests {
34+
private final RelyingPartyRegistration registration = relyingPartyRegistration().build();
35+
private final RelyingPartyRegistrationRepository repository =
36+
new InMemoryRelyingPartyRegistrationRepository(this.registration);
37+
private final DefaultRelyingPartyRegistrationResolver resolver =
38+
new DefaultRelyingPartyRegistrationResolver(this.repository);
39+
40+
@Test
41+
public void resolveWhenRequestContainsRegistrationIdThenResolves() {
42+
MockHttpServletRequest request = new MockHttpServletRequest();
43+
request.setPathInfo("/some/path/" + this.registration.getRegistrationId());
44+
RelyingPartyRegistration registration = this.resolver.convert(request);
45+
assertThat(registration).isNotNull();
46+
assertThat(registration.getRegistrationId())
47+
.isEqualTo(this.registration.getRegistrationId());
48+
assertThat(registration.getEntityId())
49+
.isEqualTo("http://localhost/saml2/service-provider-metadata/" + this.registration.getRegistrationId());
50+
assertThat(registration.getAssertionConsumerServiceLocation())
51+
.isEqualTo("http://localhost/login/saml2/sso/" + this.registration.getRegistrationId());
52+
}
53+
54+
@Test
55+
public void resolveWhenRequestContainsInvalidRegistrationIdThenNull() {
56+
MockHttpServletRequest request = new MockHttpServletRequest();
57+
request.setPathInfo("/some/path/not-" + this.registration.getRegistrationId());
58+
RelyingPartyRegistration registration = this.resolver.convert(request);
59+
assertThat(registration).isNull();
60+
}
61+
62+
@Test
63+
public void resolveWhenRequestIsMissingRegistrationIdThenNull() {
64+
MockHttpServletRequest request = new MockHttpServletRequest();
65+
RelyingPartyRegistration registration = this.resolver.convert(request);
66+
assertThat(registration).isNull();
67+
}
68+
69+
@Test
70+
public void constructorWhenNullRelyingPartyRegistrationThenIllegalArgument() {
71+
assertThatCode(() -> new DefaultRelyingPartyRegistrationResolver(null))
72+
.isInstanceOf(IllegalArgumentException.class);
73+
}
74+
}

0 commit comments

Comments
 (0)