11/*
2- * Copyright 2002-2022 the original author or authors.
2+ * Copyright 2002-2023 the original author or authors.
33 *
44 * Licensed under the Apache License, Version 2.0 (the "License");
55 * you may not use this file except in compliance with the License.
2626import javax .servlet .DispatcherType ;
2727import javax .servlet .ServletContext ;
2828import javax .servlet .ServletRegistration ;
29+ import javax .servlet .http .HttpServletRequest ;
2930
3031import org .springframework .beans .factory .NoSuchBeanDefinitionException ;
3132import org .springframework .context .ApplicationContext ;
@@ -321,11 +322,30 @@ public C requestMatchers(HttpMethod method, String... patterns) {
321322 if (!hasDispatcherServlet (registrations )) {
322323 return requestMatchers (RequestMatchers .antMatchersAsArray (method , patterns ));
323324 }
324- if (registrations .size () > 1 ) {
325- String errorMessage = computeErrorMessage (registrations .values ());
326- throw new IllegalArgumentException (errorMessage );
325+ ServletRegistration dispatcherServlet = requireOneRootDispatcherServlet (registrations );
326+ if (dispatcherServlet != null ) {
327+ if (registrations .size () == 1 ) {
328+ return requestMatchers (createMvcMatchers (method , patterns ).toArray (new RequestMatcher [0 ]));
329+ }
330+ List <RequestMatcher > matchers = new ArrayList <>();
331+ for (String pattern : patterns ) {
332+ AntPathRequestMatcher ant = new AntPathRequestMatcher (pattern , (method != null ) ? method .name () : null );
333+ MvcRequestMatcher mvc = createMvcMatchers (method , pattern ).get (0 );
334+ matchers .add (new DispatcherServletDelegatingRequestMatcher (ant , mvc , servletContext ));
335+ }
336+ return requestMatchers (matchers .toArray (new RequestMatcher [0 ]));
327337 }
328- return requestMatchers (createMvcMatchers (method , patterns ).toArray (new RequestMatcher [0 ]));
338+ dispatcherServlet = requireOnlyPathMappedDispatcherServlet (registrations );
339+ if (dispatcherServlet != null ) {
340+ String mapping = dispatcherServlet .getMappings ().iterator ().next ();
341+ List <MvcRequestMatcher > matchers = createMvcMatchers (method , patterns );
342+ for (MvcRequestMatcher matcher : matchers ) {
343+ matcher .setServletPath (mapping .substring (0 , mapping .length () - 2 ));
344+ }
345+ return requestMatchers (matchers .toArray (new RequestMatcher [0 ]));
346+ }
347+ String errorMessage = computeErrorMessage (registrations .values ());
348+ throw new IllegalArgumentException (errorMessage );
329349 }
330350
331351 private Map <String , ? extends ServletRegistration > mappableServletRegistrations (ServletContext servletContext ) {
@@ -343,22 +363,66 @@ private boolean hasDispatcherServlet(Map<String, ? extends ServletRegistration>
343363 if (registrations == null ) {
344364 return false ;
345365 }
346- Class <?> dispatcherServlet = ClassUtils .resolveClassName ("org.springframework.web.servlet.DispatcherServlet" ,
347- null );
348366 for (ServletRegistration registration : registrations .values ()) {
349- try {
350- Class <?> clazz = Class .forName (registration .getClassName ());
351- if (dispatcherServlet .isAssignableFrom (clazz )) {
352- return true ;
353- }
354- }
355- catch (ClassNotFoundException ex ) {
356- return false ;
367+ if (isDispatcherServlet (registration )) {
368+ return true ;
357369 }
358370 }
359371 return false ;
360372 }
361373
374+ private ServletRegistration requireOneRootDispatcherServlet (
375+ Map <String , ? extends ServletRegistration > registrations ) {
376+ ServletRegistration rootDispatcherServlet = null ;
377+ for (ServletRegistration registration : registrations .values ()) {
378+ if (!isDispatcherServlet (registration )) {
379+ continue ;
380+ }
381+ if (registration .getMappings ().size () > 1 ) {
382+ return null ;
383+ }
384+ if (!"/" .equals (registration .getMappings ().iterator ().next ())) {
385+ return null ;
386+ }
387+ rootDispatcherServlet = registration ;
388+ }
389+ return rootDispatcherServlet ;
390+ }
391+
392+ private ServletRegistration requireOnlyPathMappedDispatcherServlet (
393+ Map <String , ? extends ServletRegistration > registrations ) {
394+ ServletRegistration pathDispatcherServlet = null ;
395+ for (ServletRegistration registration : registrations .values ()) {
396+ if (!isDispatcherServlet (registration )) {
397+ return null ;
398+ }
399+ if (registration .getMappings ().size () > 1 ) {
400+ return null ;
401+ }
402+ String mapping = registration .getMappings ().iterator ().next ();
403+ if (!mapping .startsWith ("/" ) || !mapping .endsWith ("/*" )) {
404+ return null ;
405+ }
406+ if (pathDispatcherServlet != null ) {
407+ return null ;
408+ }
409+ pathDispatcherServlet = registration ;
410+ }
411+ return pathDispatcherServlet ;
412+ }
413+
414+ private boolean isDispatcherServlet (ServletRegistration registration ) {
415+ Class <?> dispatcherServlet = ClassUtils .resolveClassName ("org.springframework.web.servlet.DispatcherServlet" ,
416+ null );
417+ try {
418+ Class <?> clazz = Class .forName (registration .getClassName ());
419+ return dispatcherServlet .isAssignableFrom (clazz );
420+ }
421+ catch (ClassNotFoundException ex ) {
422+ return false ;
423+ }
424+ }
425+
362426 private String computeErrorMessage (Collection <? extends ServletRegistration > registrations ) {
363427 String template = "This method cannot decide whether these patterns are Spring MVC patterns or not. "
364428 + "If this endpoint is a Spring MVC endpoint, please use requestMatchers(MvcRequestMatcher); "
@@ -498,4 +562,55 @@ static List<RequestMatcher> regexMatchers(String... regexPatterns) {
498562
499563 }
500564
565+ static class DispatcherServletDelegatingRequestMatcher implements RequestMatcher {
566+
567+ private final AntPathRequestMatcher ant ;
568+
569+ private final MvcRequestMatcher mvc ;
570+
571+ private final ServletContext servletContext ;
572+
573+ DispatcherServletDelegatingRequestMatcher (AntPathRequestMatcher ant , MvcRequestMatcher mvc ,
574+ ServletContext servletContext ) {
575+ this .ant = ant ;
576+ this .mvc = mvc ;
577+ this .servletContext = servletContext ;
578+ }
579+
580+ @ Override
581+ public boolean matches (HttpServletRequest request ) {
582+ String name = request .getHttpServletMapping ().getServletName ();
583+ ServletRegistration registration = this .servletContext .getServletRegistration (name );
584+ Assert .notNull (registration , "Failed to find servlet [" + name + "] in the servlet context" );
585+ if (isDispatcherServlet (registration )) {
586+ return this .mvc .matches (request );
587+ }
588+ return this .ant .matches (request );
589+ }
590+
591+ @ Override
592+ public MatchResult matcher (HttpServletRequest request ) {
593+ String name = request .getHttpServletMapping ().getServletName ();
594+ ServletRegistration registration = this .servletContext .getServletRegistration (name );
595+ Assert .notNull (registration , "Failed to find servlet [" + name + "] in the servlet context" );
596+ if (isDispatcherServlet (registration )) {
597+ return this .mvc .matcher (request );
598+ }
599+ return this .ant .matcher (request );
600+ }
601+
602+ private boolean isDispatcherServlet (ServletRegistration registration ) {
603+ Class <?> dispatcherServlet = ClassUtils
604+ .resolveClassName ("org.springframework.web.servlet.DispatcherServlet" , null );
605+ try {
606+ Class <?> clazz = Class .forName (registration .getClassName ());
607+ return dispatcherServlet .isAssignableFrom (clazz );
608+ }
609+ catch (ClassNotFoundException ex ) {
610+ return false ;
611+ }
612+ }
613+
614+ }
615+
501616}
0 commit comments