Skip to content

Commit ffd12ee

Browse files
committed
Refine requestMatcher Validation Rules
Closes gh-14078
1 parent 3f64c6d commit ffd12ee

File tree

4 files changed

+292
-17
lines changed

4 files changed

+292
-17
lines changed

config/src/main/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistry.java

Lines changed: 130 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2022 the original author or authors.
2+
* Copyright 2002-2023 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -26,6 +26,7 @@
2626
import javax.servlet.DispatcherType;
2727
import javax.servlet.ServletContext;
2828
import javax.servlet.ServletRegistration;
29+
import javax.servlet.http.HttpServletRequest;
2930

3031
import org.springframework.beans.factory.NoSuchBeanDefinitionException;
3132
import org.springframework.context.ApplicationContext;
@@ -321,11 +322,30 @@ public C requestMatchers(HttpMethod method, String... patterns) {
321322
if (!hasDispatcherServlet(registrations)) {
322323
return requestMatchers(RequestMatchers.antMatchersAsArray(method, patterns));
323324
}
324-
if (registrations.size() > 1) {
325-
String errorMessage = computeErrorMessage(registrations.values());
326-
throw new IllegalArgumentException(errorMessage);
325+
ServletRegistration dispatcherServlet = requireOneRootDispatcherServlet(registrations);
326+
if (dispatcherServlet != null) {
327+
if (registrations.size() == 1) {
328+
return requestMatchers(createMvcMatchers(method, patterns).toArray(new RequestMatcher[0]));
329+
}
330+
List<RequestMatcher> matchers = new ArrayList<>();
331+
for (String pattern : patterns) {
332+
AntPathRequestMatcher ant = new AntPathRequestMatcher(pattern, (method != null) ? method.name() : null);
333+
MvcRequestMatcher mvc = createMvcMatchers(method, pattern).get(0);
334+
matchers.add(new DispatcherServletDelegatingRequestMatcher(ant, mvc, servletContext));
335+
}
336+
return requestMatchers(matchers.toArray(new RequestMatcher[0]));
327337
}
328-
return requestMatchers(createMvcMatchers(method, patterns).toArray(new RequestMatcher[0]));
338+
dispatcherServlet = requireOnlyPathMappedDispatcherServlet(registrations);
339+
if (dispatcherServlet != null) {
340+
String mapping = dispatcherServlet.getMappings().iterator().next();
341+
List<MvcRequestMatcher> matchers = createMvcMatchers(method, patterns);
342+
for (MvcRequestMatcher matcher : matchers) {
343+
matcher.setServletPath(mapping.substring(0, mapping.length() - 2));
344+
}
345+
return requestMatchers(matchers.toArray(new RequestMatcher[0]));
346+
}
347+
String errorMessage = computeErrorMessage(registrations.values());
348+
throw new IllegalArgumentException(errorMessage);
329349
}
330350

331351
private Map<String, ? extends ServletRegistration> mappableServletRegistrations(ServletContext servletContext) {
@@ -343,22 +363,66 @@ private boolean hasDispatcherServlet(Map<String, ? extends ServletRegistration>
343363
if (registrations == null) {
344364
return false;
345365
}
346-
Class<?> dispatcherServlet = ClassUtils.resolveClassName("org.springframework.web.servlet.DispatcherServlet",
347-
null);
348366
for (ServletRegistration registration : registrations.values()) {
349-
try {
350-
Class<?> clazz = Class.forName(registration.getClassName());
351-
if (dispatcherServlet.isAssignableFrom(clazz)) {
352-
return true;
353-
}
354-
}
355-
catch (ClassNotFoundException ex) {
356-
return false;
367+
if (isDispatcherServlet(registration)) {
368+
return true;
357369
}
358370
}
359371
return false;
360372
}
361373

374+
private ServletRegistration requireOneRootDispatcherServlet(
375+
Map<String, ? extends ServletRegistration> registrations) {
376+
ServletRegistration rootDispatcherServlet = null;
377+
for (ServletRegistration registration : registrations.values()) {
378+
if (!isDispatcherServlet(registration)) {
379+
continue;
380+
}
381+
if (registration.getMappings().size() > 1) {
382+
return null;
383+
}
384+
if (!"/".equals(registration.getMappings().iterator().next())) {
385+
return null;
386+
}
387+
rootDispatcherServlet = registration;
388+
}
389+
return rootDispatcherServlet;
390+
}
391+
392+
private ServletRegistration requireOnlyPathMappedDispatcherServlet(
393+
Map<String, ? extends ServletRegistration> registrations) {
394+
ServletRegistration pathDispatcherServlet = null;
395+
for (ServletRegistration registration : registrations.values()) {
396+
if (!isDispatcherServlet(registration)) {
397+
return null;
398+
}
399+
if (registration.getMappings().size() > 1) {
400+
return null;
401+
}
402+
String mapping = registration.getMappings().iterator().next();
403+
if (!mapping.startsWith("/") || !mapping.endsWith("/*")) {
404+
return null;
405+
}
406+
if (pathDispatcherServlet != null) {
407+
return null;
408+
}
409+
pathDispatcherServlet = registration;
410+
}
411+
return pathDispatcherServlet;
412+
}
413+
414+
private boolean isDispatcherServlet(ServletRegistration registration) {
415+
Class<?> dispatcherServlet = ClassUtils.resolveClassName("org.springframework.web.servlet.DispatcherServlet",
416+
null);
417+
try {
418+
Class<?> clazz = Class.forName(registration.getClassName());
419+
return dispatcherServlet.isAssignableFrom(clazz);
420+
}
421+
catch (ClassNotFoundException ex) {
422+
return false;
423+
}
424+
}
425+
362426
private String computeErrorMessage(Collection<? extends ServletRegistration> registrations) {
363427
String template = "This method cannot decide whether these patterns are Spring MVC patterns or not. "
364428
+ "If this endpoint is a Spring MVC endpoint, please use requestMatchers(MvcRequestMatcher); "
@@ -498,4 +562,55 @@ static List<RequestMatcher> regexMatchers(String... regexPatterns) {
498562

499563
}
500564

565+
static class DispatcherServletDelegatingRequestMatcher implements RequestMatcher {
566+
567+
private final AntPathRequestMatcher ant;
568+
569+
private final MvcRequestMatcher mvc;
570+
571+
private final ServletContext servletContext;
572+
573+
DispatcherServletDelegatingRequestMatcher(AntPathRequestMatcher ant, MvcRequestMatcher mvc,
574+
ServletContext servletContext) {
575+
this.ant = ant;
576+
this.mvc = mvc;
577+
this.servletContext = servletContext;
578+
}
579+
580+
@Override
581+
public boolean matches(HttpServletRequest request) {
582+
String name = request.getHttpServletMapping().getServletName();
583+
ServletRegistration registration = this.servletContext.getServletRegistration(name);
584+
Assert.notNull(registration, "Failed to find servlet [" + name + "] in the servlet context");
585+
if (isDispatcherServlet(registration)) {
586+
return this.mvc.matches(request);
587+
}
588+
return this.ant.matches(request);
589+
}
590+
591+
@Override
592+
public MatchResult matcher(HttpServletRequest request) {
593+
String name = request.getHttpServletMapping().getServletName();
594+
ServletRegistration registration = this.servletContext.getServletRegistration(name);
595+
Assert.notNull(registration, "Failed to find servlet [" + name + "] in the servlet context");
596+
if (isDispatcherServlet(registration)) {
597+
return this.mvc.matcher(request);
598+
}
599+
return this.ant.matcher(request);
600+
}
601+
602+
private boolean isDispatcherServlet(ServletRegistration registration) {
603+
Class<?> dispatcherServlet = ClassUtils
604+
.resolveClassName("org.springframework.web.servlet.DispatcherServlet", null);
605+
try {
606+
Class<?> clazz = Class.forName(registration.getClassName());
607+
return dispatcherServlet.isAssignableFrom(clazz);
608+
}
609+
catch (ClassNotFoundException ex) {
610+
return false;
611+
}
612+
}
613+
614+
}
615+
501616
}

config/src/test/java/org/springframework/security/config/MockServletContext.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,11 @@ public ServletRegistration.Dynamic addServlet(@NonNull String servletName, Class
5555
return this.registrations;
5656
}
5757

58+
@Override
59+
public ServletRegistration getServletRegistration(String servletName) {
60+
return this.registrations.get(servletName);
61+
}
62+
5863
private static class MockServletRegistration implements ServletRegistration.Dynamic {
5964

6065
private final String name;
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
/*
2+
* Copyright 2002-2023 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.security.config;
18+
19+
import javax.servlet.http.HttpServletRequest;
20+
import javax.servlet.http.MappingMatch;
21+
22+
import org.springframework.mock.web.MockHttpServletMapping;
23+
24+
public final class TestMockHttpServletMappings {
25+
26+
private TestMockHttpServletMappings() {
27+
28+
}
29+
30+
public static MockHttpServletMapping extension(HttpServletRequest request, String extension) {
31+
String uri = request.getRequestURI();
32+
String matchValue = uri.substring(0, uri.lastIndexOf(extension));
33+
return new MockHttpServletMapping(matchValue, "*" + extension, "extension", MappingMatch.EXTENSION);
34+
}
35+
36+
public static MockHttpServletMapping path(HttpServletRequest request, String path) {
37+
String uri = request.getRequestURI();
38+
String matchValue = uri.substring(path.length());
39+
return new MockHttpServletMapping(matchValue, path + "/*", "path", MappingMatch.PATH);
40+
}
41+
42+
public static MockHttpServletMapping defaultMapping() {
43+
return new MockHttpServletMapping("", "/", "default", MappingMatch.DEFAULT);
44+
}
45+
46+
}

0 commit comments

Comments
 (0)