11/*
2- * Copyright 2002-2022 the original author or authors.
2+ * Copyright 2002-2023 the original author or authors.
33 *
44 * Licensed under the Apache License, Version 2.0 (the "License");
55 * you may not use this file except in compliance with the License.
2626import jakarta .servlet .DispatcherType ;
2727import jakarta .servlet .ServletContext ;
2828import jakarta .servlet .ServletRegistration ;
29+ import jakarta .servlet .http .HttpServletRequest ;
2930
3031import org .springframework .beans .factory .NoSuchBeanDefinitionException ;
3132import org .springframework .context .ApplicationContext ;
@@ -203,11 +204,30 @@ public C requestMatchers(HttpMethod method, String... patterns) {
203204 if (!hasDispatcherServlet (registrations )) {
204205 return requestMatchers (RequestMatchers .antMatchersAsArray (method , patterns ));
205206 }
206- if (registrations .size () > 1 ) {
207- String errorMessage = computeErrorMessage (registrations .values ());
208- throw new IllegalArgumentException (errorMessage );
207+ ServletRegistration dispatcherServlet = requireOneRootDispatcherServlet (registrations );
208+ if (dispatcherServlet != null ) {
209+ if (registrations .size () == 1 ) {
210+ return requestMatchers (createMvcMatchers (method , patterns ).toArray (new RequestMatcher [0 ]));
211+ }
212+ List <RequestMatcher > matchers = new ArrayList <>();
213+ for (String pattern : patterns ) {
214+ AntPathRequestMatcher ant = new AntPathRequestMatcher (pattern , (method != null ) ? method .name () : null );
215+ MvcRequestMatcher mvc = createMvcMatchers (method , pattern ).get (0 );
216+ matchers .add (new DispatcherServletDelegatingRequestMatcher (ant , mvc , servletContext ));
217+ }
218+ return requestMatchers (matchers .toArray (new RequestMatcher [0 ]));
209219 }
210- return requestMatchers (createMvcMatchers (method , patterns ).toArray (new RequestMatcher [0 ]));
220+ dispatcherServlet = requireOnlyPathMappedDispatcherServlet (registrations );
221+ if (dispatcherServlet != null ) {
222+ String mapping = dispatcherServlet .getMappings ().iterator ().next ();
223+ List <MvcRequestMatcher > matchers = createMvcMatchers (method , patterns );
224+ for (MvcRequestMatcher matcher : matchers ) {
225+ matcher .setServletPath (mapping .substring (0 , mapping .length () - 2 ));
226+ }
227+ return requestMatchers (matchers .toArray (new RequestMatcher [0 ]));
228+ }
229+ String errorMessage = computeErrorMessage (registrations .values ());
230+ throw new IllegalArgumentException (errorMessage );
211231 }
212232
213233 private Map <String , ? extends ServletRegistration > mappableServletRegistrations (ServletContext servletContext ) {
@@ -225,22 +245,66 @@ private boolean hasDispatcherServlet(Map<String, ? extends ServletRegistration>
225245 if (registrations == null ) {
226246 return false ;
227247 }
228- Class <?> dispatcherServlet = ClassUtils .resolveClassName ("org.springframework.web.servlet.DispatcherServlet" ,
229- null );
230248 for (ServletRegistration registration : registrations .values ()) {
231- try {
232- Class <?> clazz = Class .forName (registration .getClassName ());
233- if (dispatcherServlet .isAssignableFrom (clazz )) {
234- return true ;
235- }
236- }
237- catch (ClassNotFoundException ex ) {
238- return false ;
249+ if (isDispatcherServlet (registration )) {
250+ return true ;
239251 }
240252 }
241253 return false ;
242254 }
243255
256+ private ServletRegistration requireOneRootDispatcherServlet (
257+ Map <String , ? extends ServletRegistration > registrations ) {
258+ ServletRegistration rootDispatcherServlet = null ;
259+ for (ServletRegistration registration : registrations .values ()) {
260+ if (!isDispatcherServlet (registration )) {
261+ continue ;
262+ }
263+ if (registration .getMappings ().size () > 1 ) {
264+ return null ;
265+ }
266+ if (!"/" .equals (registration .getMappings ().iterator ().next ())) {
267+ return null ;
268+ }
269+ rootDispatcherServlet = registration ;
270+ }
271+ return rootDispatcherServlet ;
272+ }
273+
274+ private ServletRegistration requireOnlyPathMappedDispatcherServlet (
275+ Map <String , ? extends ServletRegistration > registrations ) {
276+ ServletRegistration pathDispatcherServlet = null ;
277+ for (ServletRegistration registration : registrations .values ()) {
278+ if (!isDispatcherServlet (registration )) {
279+ return null ;
280+ }
281+ if (registration .getMappings ().size () > 1 ) {
282+ return null ;
283+ }
284+ String mapping = registration .getMappings ().iterator ().next ();
285+ if (!mapping .startsWith ("/" ) || !mapping .endsWith ("/*" )) {
286+ return null ;
287+ }
288+ if (pathDispatcherServlet != null ) {
289+ return null ;
290+ }
291+ pathDispatcherServlet = registration ;
292+ }
293+ return pathDispatcherServlet ;
294+ }
295+
296+ private boolean isDispatcherServlet (ServletRegistration registration ) {
297+ Class <?> dispatcherServlet = ClassUtils .resolveClassName ("org.springframework.web.servlet.DispatcherServlet" ,
298+ null );
299+ try {
300+ Class <?> clazz = Class .forName (registration .getClassName ());
301+ return dispatcherServlet .isAssignableFrom (clazz );
302+ }
303+ catch (ClassNotFoundException ex ) {
304+ return false ;
305+ }
306+ }
307+
244308 private String computeErrorMessage (Collection <? extends ServletRegistration > registrations ) {
245309 String template = "This method cannot decide whether these patterns are Spring MVC patterns or not. "
246310 + "If this endpoint is a Spring MVC endpoint, please use requestMatchers(MvcRequestMatcher); "
@@ -380,4 +444,55 @@ static List<RequestMatcher> regexMatchers(String... regexPatterns) {
380444
381445 }
382446
447+ static class DispatcherServletDelegatingRequestMatcher implements RequestMatcher {
448+
449+ private final AntPathRequestMatcher ant ;
450+
451+ private final MvcRequestMatcher mvc ;
452+
453+ private final ServletContext servletContext ;
454+
455+ DispatcherServletDelegatingRequestMatcher (AntPathRequestMatcher ant , MvcRequestMatcher mvc ,
456+ ServletContext servletContext ) {
457+ this .ant = ant ;
458+ this .mvc = mvc ;
459+ this .servletContext = servletContext ;
460+ }
461+
462+ @ Override
463+ public boolean matches (HttpServletRequest request ) {
464+ String name = request .getHttpServletMapping ().getServletName ();
465+ ServletRegistration registration = this .servletContext .getServletRegistration (name );
466+ Assert .notNull (registration , "Failed to find servlet [" + name + "] in the servlet context" );
467+ if (isDispatcherServlet (registration )) {
468+ return this .mvc .matches (request );
469+ }
470+ return this .ant .matches (request );
471+ }
472+
473+ @ Override
474+ public MatchResult matcher (HttpServletRequest request ) {
475+ String name = request .getHttpServletMapping ().getServletName ();
476+ ServletRegistration registration = this .servletContext .getServletRegistration (name );
477+ Assert .notNull (registration , "Failed to find servlet [" + name + "] in the servlet context" );
478+ if (isDispatcherServlet (registration )) {
479+ return this .mvc .matcher (request );
480+ }
481+ return this .ant .matcher (request );
482+ }
483+
484+ private boolean isDispatcherServlet (ServletRegistration registration ) {
485+ Class <?> dispatcherServlet = ClassUtils
486+ .resolveClassName ("org.springframework.web.servlet.DispatcherServlet" , null );
487+ try {
488+ Class <?> clazz = Class .forName (registration .getClassName ());
489+ return dispatcherServlet .isAssignableFrom (clazz );
490+ }
491+ catch (ClassNotFoundException ex ) {
492+ return false ;
493+ }
494+ }
495+
496+ }
497+
383498}
0 commit comments