|
16 | 16 |
|
17 | 17 | package org.springframework.security.config.annotation.web.configurers.saml2;
|
18 | 18 |
|
| 19 | +import java.util.ArrayList; |
19 | 20 | import java.util.LinkedHashMap;
|
| 21 | +import java.util.List; |
20 | 22 | import java.util.Map;
|
21 | 23 |
|
| 24 | +import jakarta.servlet.http.HttpServletRequest; |
| 25 | + |
22 | 26 | import org.springframework.beans.factory.NoSuchBeanDefinitionException;
|
23 | 27 | import org.springframework.context.ApplicationContext;
|
24 | 28 | import org.springframework.security.authentication.AuthenticationManager;
|
|
33 | 37 | import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationProvider;
|
34 | 38 | import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
|
35 | 39 | import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
|
| 40 | +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrations; |
36 | 41 | import org.springframework.security.saml2.provider.service.web.HttpSessionSaml2AuthenticationRequestRepository;
|
37 | 42 | import org.springframework.security.saml2.provider.service.web.OpenSamlAuthenticationTokenConverter;
|
38 | 43 | import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestRepository;
|
|
50 | 55 | import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
|
51 | 56 | import org.springframework.security.web.util.matcher.NegatedRequestMatcher;
|
52 | 57 | import org.springframework.security.web.util.matcher.OrRequestMatcher;
|
| 58 | +import org.springframework.security.web.util.matcher.ParameterRequestMatcher; |
53 | 59 | import org.springframework.security.web.util.matcher.RequestHeaderRequestMatcher;
|
54 | 60 | import org.springframework.security.web.util.matcher.RequestMatcher;
|
55 | 61 | import org.springframework.security.web.util.matcher.RequestMatchers;
|
@@ -113,6 +119,8 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>>
|
113 | 119 |
|
114 | 120 | private String authenticationRequestUri = Saml2AuthenticationRequestResolver.DEFAULT_AUTHENTICATION_REQUEST_URI;
|
115 | 121 |
|
| 122 | + private String[] authenticationRequestParams = new String[0]; |
| 123 | + |
116 | 124 | private Saml2AuthenticationRequestResolver authenticationRequestResolver;
|
117 | 125 |
|
118 | 126 | private RequestMatcher loginProcessingUrl = RequestMatchers.anyOf(
|
@@ -196,11 +204,30 @@ public Saml2LoginConfigurer<B> authenticationRequestResolver(
|
196 | 204 | * Request
|
197 | 205 | * @return the {@link Saml2LoginConfigurer} for further configuration
|
198 | 206 | * @since 6.0
|
| 207 | + * @deprecated Use {@link #authenticationRequestUriQuery} instead |
199 | 208 | */
|
200 | 209 | public Saml2LoginConfigurer<B> authenticationRequestUri(String authenticationRequestUri) {
|
201 |
| - Assert.state(authenticationRequestUri.contains("{registrationId}"), |
202 |
| - "authenticationRequestUri must contain {registrationId} path variable"); |
203 |
| - this.authenticationRequestUri = authenticationRequestUri; |
| 210 | + return authenticationRequestUriQuery(authenticationRequestUri); |
| 211 | + } |
| 212 | + |
| 213 | + /** |
| 214 | + * Customize the URL that the SAML Authentication Request will be sent to. |
| 215 | + * This method also supports query parameters like so: <pre> |
| 216 | + * authenticationRequestUriQuery("/saml/authenticate?registrationId={registrationId}") |
| 217 | + * </pre> |
| 218 | + * {@link RelyingPartyRegistrations} |
| 219 | + * @param authenticationRequestUriQuery the URI and query to use for the SAML 2.0 |
| 220 | + * Authentication Request |
| 221 | + * @return the {@link Saml2LoginConfigurer} for further configuration |
| 222 | + * @since 6.0 |
| 223 | + */ |
| 224 | + public Saml2LoginConfigurer<B> authenticationRequestUriQuery(String authenticationRequestUriQuery) { |
| 225 | + Assert.state(authenticationRequestUriQuery.contains("{registrationId}"), |
| 226 | + "authenticationRequestUri must contain {registrationId} path variable or query value"); |
| 227 | + String[] parts = authenticationRequestUriQuery.split("[?&]"); |
| 228 | + this.authenticationRequestUri = parts[0]; |
| 229 | + this.authenticationRequestParams = new String[parts.length - 1]; |
| 230 | + System.arraycopy(parts, 1, this.authenticationRequestParams, 0, parts.length - 1); |
204 | 231 | return this;
|
205 | 232 | }
|
206 | 233 |
|
@@ -255,7 +282,7 @@ public void init(B http) throws Exception {
|
255 | 282 | }
|
256 | 283 | else {
|
257 | 284 | Map<String, String> providerUrlMap = getIdentityProviderUrlMap(this.authenticationRequestUri,
|
258 |
| - this.relyingPartyRegistrationRepository); |
| 285 | + this.authenticationRequestParams, this.relyingPartyRegistrationRepository); |
259 | 286 | boolean singleProvider = providerUrlMap.size() == 1;
|
260 | 287 | if (singleProvider) {
|
261 | 288 | // Setup auto-redirect to provider login page
|
@@ -336,8 +363,14 @@ private Saml2AuthenticationRequestResolver getAuthenticationRequestResolver(B ht
|
336 | 363 | }
|
337 | 364 | OpenSaml4AuthenticationRequestResolver openSaml4AuthenticationRequestResolver = new OpenSaml4AuthenticationRequestResolver(
|
338 | 365 | relyingPartyRegistrationRepository(http));
|
339 |
| - openSaml4AuthenticationRequestResolver |
340 |
| - .setRequestMatcher(new AntPathRequestMatcher(this.authenticationRequestUri)); |
| 366 | + if (this.authenticationRequestParams.length > 0) { |
| 367 | + openSaml4AuthenticationRequestResolver.setRequestMatcher( |
| 368 | + new AntPathQueryRequestMatcher(this.authenticationRequestUri, this.authenticationRequestParams)); |
| 369 | + } |
| 370 | + else { |
| 371 | + openSaml4AuthenticationRequestResolver |
| 372 | + .setRequestMatcher(new AntPathRequestMatcher(this.authenticationRequestUri)); |
| 373 | + } |
341 | 374 | return openSaml4AuthenticationRequestResolver;
|
342 | 375 | }
|
343 | 376 |
|
@@ -382,20 +415,28 @@ private void initDefaultLoginFilter(B http) {
|
382 | 415 | return;
|
383 | 416 | }
|
384 | 417 | loginPageGeneratingFilter.setSaml2LoginEnabled(true);
|
385 |
| - loginPageGeneratingFilter.setSaml2AuthenticationUrlToProviderName( |
386 |
| - this.getIdentityProviderUrlMap(this.authenticationRequestUri, this.relyingPartyRegistrationRepository)); |
| 418 | + loginPageGeneratingFilter |
| 419 | + .setSaml2AuthenticationUrlToProviderName(this.getIdentityProviderUrlMap(this.authenticationRequestUri, |
| 420 | + this.authenticationRequestParams, this.relyingPartyRegistrationRepository)); |
387 | 421 | loginPageGeneratingFilter.setLoginPageUrl(this.getLoginPage());
|
388 | 422 | loginPageGeneratingFilter.setFailureUrl(this.getFailureUrl());
|
389 | 423 | }
|
390 | 424 |
|
391 | 425 | @SuppressWarnings("unchecked")
|
392 |
| - private Map<String, String> getIdentityProviderUrlMap(String authRequestPrefixUrl, |
| 426 | + private Map<String, String> getIdentityProviderUrlMap(String authRequestPrefixUrl, String[] authRequestQueryParams, |
393 | 427 | RelyingPartyRegistrationRepository idpRepo) {
|
394 | 428 | Map<String, String> idps = new LinkedHashMap<>();
|
395 | 429 | if (idpRepo instanceof Iterable) {
|
396 | 430 | Iterable<RelyingPartyRegistration> repo = (Iterable<RelyingPartyRegistration>) idpRepo;
|
397 |
| - repo.forEach((p) -> idps.put(authRequestPrefixUrl.replace("{registrationId}", p.getRegistrationId()), |
398 |
| - p.getRegistrationId())); |
| 431 | + StringBuilder authRequestQuery = new StringBuilder("?"); |
| 432 | + for (String authRequestQueryParam : authRequestQueryParams) { |
| 433 | + authRequestQuery.append(authRequestQueryParam + "&"); |
| 434 | + } |
| 435 | + authRequestQuery.deleteCharAt(authRequestQuery.length() - 1); |
| 436 | + String authenticationRequestUriQuery = authRequestPrefixUrl + authRequestQuery; |
| 437 | + repo.forEach( |
| 438 | + (p) -> idps.put(authenticationRequestUriQuery.replace("{registrationId}", p.getRegistrationId()), |
| 439 | + p.getRegistrationId())); |
399 | 440 | }
|
400 | 441 | return idps;
|
401 | 442 | }
|
@@ -437,4 +478,35 @@ private <C> void setSharedObject(B http, Class<C> clazz, C object) {
|
437 | 478 | }
|
438 | 479 | }
|
439 | 480 |
|
| 481 | + static class AntPathQueryRequestMatcher implements RequestMatcher { |
| 482 | + |
| 483 | + private final RequestMatcher matcher; |
| 484 | + |
| 485 | + AntPathQueryRequestMatcher(String path, String... params) { |
| 486 | + List<RequestMatcher> matchers = new ArrayList<>(); |
| 487 | + matchers.add(new AntPathRequestMatcher(path)); |
| 488 | + for (String param : params) { |
| 489 | + String[] parts = param.split("="); |
| 490 | + if (parts.length == 1) { |
| 491 | + matchers.add(new ParameterRequestMatcher(parts[0])); |
| 492 | + } |
| 493 | + else { |
| 494 | + matchers.add(new ParameterRequestMatcher(parts[0], parts[1])); |
| 495 | + } |
| 496 | + } |
| 497 | + this.matcher = new AndRequestMatcher(matchers); |
| 498 | + } |
| 499 | + |
| 500 | + @Override |
| 501 | + public boolean matches(HttpServletRequest request) { |
| 502 | + return matcher(request).isMatch(); |
| 503 | + } |
| 504 | + |
| 505 | + @Override |
| 506 | + public MatchResult matcher(HttpServletRequest request) { |
| 507 | + return this.matcher.matcher(request); |
| 508 | + } |
| 509 | + |
| 510 | + } |
| 511 | + |
440 | 512 | }
|
0 commit comments