Skip to content

Commit 37b893a

Browse files
committed
Extract Placeholder Resolution
Closes gh-12842
1 parent 42cece2 commit 37b893a

File tree

3 files changed

+195
-61
lines changed

3 files changed

+195
-61
lines changed

saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/DefaultRelyingPartyRegistrationResolver.java

Lines changed: 10 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -16,24 +16,17 @@
1616

1717
package org.springframework.security.saml2.provider.service.web;
1818

19-
import java.util.HashMap;
20-
import java.util.Map;
21-
import java.util.function.Function;
22-
2319
import jakarta.servlet.http.HttpServletRequest;
2420
import org.apache.commons.logging.Log;
2521
import org.apache.commons.logging.LogFactory;
2622

2723
import org.springframework.core.convert.converter.Converter;
2824
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
2925
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
30-
import org.springframework.security.web.util.UrlUtils;
26+
import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationPlaceholderResolvers.UriResolver;
3127
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
3228
import org.springframework.security.web.util.matcher.RequestMatcher;
3329
import org.springframework.util.Assert;
34-
import org.springframework.util.StringUtils;
35-
import org.springframework.web.util.UriComponents;
36-
import org.springframework.web.util.UriComponentsBuilder;
3730

3831
/**
3932
* A {@link Converter} that resolves a {@link RelyingPartyRegistration} by extracting the
@@ -48,8 +41,6 @@ public final class DefaultRelyingPartyRegistrationResolver
4841

4942
private Log logger = LogFactory.getLog(getClass());
5043

51-
private static final char PATH_DELIMITER = '/';
52-
5344
private final RelyingPartyRegistrationRepository relyingPartyRegistrationRepository;
5445

5546
private final RequestMatcher registrationRequestMatcher = new AntPathRequestMatcher("/**/{registrationId}");
@@ -87,61 +78,19 @@ public RelyingPartyRegistration resolve(HttpServletRequest request, String relyi
8778
}
8879
return null;
8980
}
90-
RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationRepository
81+
RelyingPartyRegistration registration = this.relyingPartyRegistrationRepository
9182
.findByRegistrationId(relyingPartyRegistrationId);
92-
if (relyingPartyRegistration == null) {
83+
if (registration == null) {
9384
return null;
9485
}
95-
String applicationUri = getApplicationUri(request);
96-
Function<String, String> templateResolver = templateResolver(applicationUri, relyingPartyRegistration);
97-
String relyingPartyEntityId = templateResolver.apply(relyingPartyRegistration.getEntityId());
98-
String assertionConsumerServiceLocation = templateResolver
99-
.apply(relyingPartyRegistration.getAssertionConsumerServiceLocation());
100-
String singleLogoutServiceLocation = templateResolver
101-
.apply(relyingPartyRegistration.getSingleLogoutServiceLocation());
102-
String singleLogoutServiceResponseLocation = templateResolver
103-
.apply(relyingPartyRegistration.getSingleLogoutServiceResponseLocation());
104-
return relyingPartyRegistration.mutate().entityId(relyingPartyEntityId)
105-
.assertionConsumerServiceLocation(assertionConsumerServiceLocation)
106-
.singleLogoutServiceLocation(singleLogoutServiceLocation)
107-
.singleLogoutServiceResponseLocation(singleLogoutServiceResponseLocation).build();
108-
}
109-
110-
private Function<String, String> templateResolver(String applicationUri, RelyingPartyRegistration relyingParty) {
111-
return (template) -> resolveUrlTemplate(template, applicationUri, relyingParty);
112-
}
113-
114-
private static String resolveUrlTemplate(String template, String baseUrl, RelyingPartyRegistration relyingParty) {
115-
if (template == null) {
116-
return null;
117-
}
118-
String entityId = relyingParty.getAssertingPartyDetails().getEntityId();
119-
String registrationId = relyingParty.getRegistrationId();
120-
Map<String, String> uriVariables = new HashMap<>();
121-
UriComponents uriComponents = UriComponentsBuilder.fromHttpUrl(baseUrl).replaceQuery(null).fragment(null)
86+
UriResolver uriResolver = RelyingPartyRegistrationPlaceholderResolvers.uriResolver(request, registration);
87+
return registration.mutate().entityId(uriResolver.resolve(registration.getEntityId()))
88+
.assertionConsumerServiceLocation(
89+
uriResolver.resolve(registration.getAssertionConsumerServiceLocation()))
90+
.singleLogoutServiceLocation(uriResolver.resolve(registration.getSingleLogoutServiceLocation()))
91+
.singleLogoutServiceResponseLocation(
92+
uriResolver.resolve(registration.getSingleLogoutServiceResponseLocation()))
12293
.build();
123-
String scheme = uriComponents.getScheme();
124-
uriVariables.put("baseScheme", (scheme != null) ? scheme : "");
125-
String host = uriComponents.getHost();
126-
uriVariables.put("baseHost", (host != null) ? host : "");
127-
// following logic is based on HierarchicalUriComponents#toUriString()
128-
int port = uriComponents.getPort();
129-
uriVariables.put("basePort", (port == -1) ? "" : ":" + port);
130-
String path = uriComponents.getPath();
131-
if (StringUtils.hasLength(path) && path.charAt(0) != PATH_DELIMITER) {
132-
path = PATH_DELIMITER + path;
133-
}
134-
uriVariables.put("basePath", (path != null) ? path : "");
135-
uriVariables.put("baseUrl", uriComponents.toUriString());
136-
uriVariables.put("entityId", StringUtils.hasText(entityId) ? entityId : "");
137-
uriVariables.put("registrationId", StringUtils.hasText(registrationId) ? registrationId : "");
138-
return UriComponentsBuilder.fromUriString(template).buildAndExpand(uriVariables).toUriString();
139-
}
140-
141-
private static String getApplicationUri(HttpServletRequest request) {
142-
UriComponents uriComponents = UriComponentsBuilder.fromHttpUrl(UrlUtils.buildFullRequestUrl(request))
143-
.replacePath(request.getContextPath()).replaceQuery(null).fragment(null).build();
144-
return uriComponents.toUriString();
14594
}
14695

14796
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
/*
2+
* Copyright 2002-2023 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.security.saml2.provider.service.web;
18+
19+
import java.util.HashMap;
20+
import java.util.Map;
21+
22+
import jakarta.servlet.http.HttpServletRequest;
23+
24+
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
25+
import org.springframework.security.web.util.UrlUtils;
26+
import org.springframework.util.StringUtils;
27+
import org.springframework.web.util.UriComponents;
28+
import org.springframework.web.util.UriComponentsBuilder;
29+
30+
/**
31+
* A factory for creating placeholder resolvers for {@link RelyingPartyRegistration}
32+
* templates. Supports {@code baseUrl}, {@code baseScheme}, {@code baseHost},
33+
* {@code basePort}, {@code basePath}, {@code registrationId},
34+
* {@code relyingPartyEntityId}, and {@code assertingPartyEntityId}
35+
*
36+
* @author Josh Cummings
37+
* @since 6.1
38+
*/
39+
public final class RelyingPartyRegistrationPlaceholderResolvers {
40+
41+
private static final char PATH_DELIMITER = '/';
42+
43+
private RelyingPartyRegistrationPlaceholderResolvers() {
44+
45+
}
46+
47+
/**
48+
* Create a resolver based on the given {@link HttpServletRequest}. Given the request,
49+
* placeholders {@code baseUrl}, {@code baseScheme}, {@code baseHost},
50+
* {@code basePort}, and {@code basePath} are resolved.
51+
* @param request the HTTP request
52+
* @return a resolver that can resolve {@code baseUrl}, {@code baseScheme},
53+
* {@code baseHost}, {@code basePort}, and {@code basePath} placeholders
54+
*/
55+
public static UriResolver uriResolver(HttpServletRequest request) {
56+
return new UriResolver(uriVariables(request));
57+
}
58+
59+
/**
60+
* Create a resolver based on the given {@link HttpServletRequest}. Given the request,
61+
* placeholders {@code baseUrl}, {@code baseScheme}, {@code baseHost},
62+
* {@code basePort}, {@code basePath}, {@code registrationId},
63+
* {@code assertingPartyEntityId}, and {@code relyingPartyEntityId} are resolved.
64+
* @param request the HTTP request
65+
* @return a resolver that can resolve {@code baseUrl}, {@code baseScheme},
66+
* {@code baseHost}, {@code basePort}, {@code basePath}, {@code registrationId},
67+
* {@code relyingPartyEntityId}, and {@code assertingPartyEntityId} placeholders
68+
*/
69+
public static UriResolver uriResolver(HttpServletRequest request, RelyingPartyRegistration registration) {
70+
String relyingPartyEntityId = registration.getEntityId();
71+
String assertingPartyEntityId = registration.getAssertingPartyDetails().getEntityId();
72+
String registrationId = registration.getRegistrationId();
73+
Map<String, String> uriVariables = uriVariables(request);
74+
uriVariables.put("relyingPartyEntityId", StringUtils.hasText(relyingPartyEntityId) ? relyingPartyEntityId : "");
75+
uriVariables.put("assertingPartyEntityId",
76+
StringUtils.hasText(assertingPartyEntityId) ? assertingPartyEntityId : "");
77+
uriVariables.put("entityId", StringUtils.hasText(assertingPartyEntityId) ? assertingPartyEntityId : "");
78+
uriVariables.put("registrationId", StringUtils.hasText(registrationId) ? registrationId : "");
79+
return new UriResolver(uriVariables);
80+
}
81+
82+
private static Map<String, String> uriVariables(HttpServletRequest request) {
83+
String baseUrl = getApplicationUri(request);
84+
Map<String, String> uriVariables = new HashMap<>();
85+
UriComponents uriComponents = UriComponentsBuilder.fromHttpUrl(baseUrl).replaceQuery(null).fragment(null)
86+
.build();
87+
String scheme = uriComponents.getScheme();
88+
uriVariables.put("baseScheme", (scheme != null) ? scheme : "");
89+
String host = uriComponents.getHost();
90+
uriVariables.put("baseHost", (host != null) ? host : "");
91+
// following logic is based on HierarchicalUriComponents#toUriString()
92+
int port = uriComponents.getPort();
93+
uriVariables.put("basePort", (port == -1) ? "" : ":" + port);
94+
String path = uriComponents.getPath();
95+
if (StringUtils.hasLength(path) && path.charAt(0) != PATH_DELIMITER) {
96+
path = PATH_DELIMITER + path;
97+
}
98+
uriVariables.put("basePath", (path != null) ? path : "");
99+
uriVariables.put("baseUrl", uriComponents.toUriString());
100+
return uriVariables;
101+
}
102+
103+
private static String getApplicationUri(HttpServletRequest request) {
104+
UriComponents uriComponents = UriComponentsBuilder.fromHttpUrl(UrlUtils.buildFullRequestUrl(request))
105+
.replacePath(request.getContextPath()).replaceQuery(null).fragment(null).build();
106+
return uriComponents.toUriString();
107+
}
108+
109+
/**
110+
* A class for resolving {@link RelyingPartyRegistration} URIs
111+
*/
112+
public static final class UriResolver {
113+
114+
private final Map<String, String> uriVariables;
115+
116+
private UriResolver(Map<String, String> uriVariables) {
117+
this.uriVariables = uriVariables;
118+
}
119+
120+
public String resolve(String uri) {
121+
if (uri == null) {
122+
return null;
123+
}
124+
return UriComponentsBuilder.fromUriString(uri).buildAndExpand(this.uriVariables).toUriString();
125+
}
126+
127+
}
128+
129+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/*
2+
* Copyright 2002-2023 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.security.saml2.provider.service.web;
18+
19+
import org.junit.jupiter.api.Test;
20+
21+
import org.springframework.mock.web.MockHttpServletRequest;
22+
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
23+
import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
24+
import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationPlaceholderResolvers.UriResolver;
25+
26+
import static org.assertj.core.api.Assertions.assertThat;
27+
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
28+
29+
/**
30+
* Tests for {@link RelyingPartyRegistrationPlaceholderResolvers}
31+
*/
32+
public class RelyingPartyRegistrationPlaceholderResolversTests {
33+
34+
@Test
35+
void uriResolverGivenRequestCreatesResolver() {
36+
MockHttpServletRequest request = new MockHttpServletRequest();
37+
UriResolver uriResolver = RelyingPartyRegistrationPlaceholderResolvers.uriResolver(request);
38+
String resolved = uriResolver.resolve("{baseUrl}/extension");
39+
assertThat(resolved).isEqualTo("http://localhost/extension");
40+
assertThatExceptionOfType(IllegalArgumentException.class)
41+
.isThrownBy(() -> uriResolver.resolve("{baseUrl}/extension/{registrationId}"));
42+
}
43+
44+
@Test
45+
void uriResolverGivenRequestAndRegistrationCreatesResolver() {
46+
MockHttpServletRequest request = new MockHttpServletRequest();
47+
RelyingPartyRegistration registration = TestRelyingPartyRegistrations.relyingPartyRegistration()
48+
.entityId("http://sp.example.org").build();
49+
UriResolver uriResolver = RelyingPartyRegistrationPlaceholderResolvers.uriResolver(request, registration);
50+
String resolved = uriResolver.resolve("{baseUrl}/extension/{registrationId}");
51+
assertThat(resolved).isEqualTo("http://localhost/extension/simplesamlphp");
52+
resolved = uriResolver.resolve("{relyingPartyEntityId}/extension");
53+
assertThat(resolved).isEqualTo("http://sp.example.org/extension");
54+
}
55+
56+
}

0 commit comments

Comments
 (0)