1
1
/*
2
- * Copyright 2002-2022 the original author or authors.
2
+ * Copyright 2002-2023 the original author or authors.
3
3
*
4
4
* Licensed under the Apache License, Version 2.0 (the "License");
5
5
* you may not use this file except in compliance with the License.
26
26
import javax .servlet .DispatcherType ;
27
27
import javax .servlet .ServletContext ;
28
28
import javax .servlet .ServletRegistration ;
29
+ import javax .servlet .http .HttpServletRequest ;
29
30
30
31
import org .springframework .beans .factory .NoSuchBeanDefinitionException ;
31
32
import org .springframework .context .ApplicationContext ;
@@ -321,11 +322,30 @@ public C requestMatchers(HttpMethod method, String... patterns) {
321
322
if (!hasDispatcherServlet (registrations )) {
322
323
return requestMatchers (RequestMatchers .antMatchersAsArray (method , patterns ));
323
324
}
324
- if (registrations .size () > 1 ) {
325
- String errorMessage = computeErrorMessage (registrations .values ());
326
- throw new IllegalArgumentException (errorMessage );
325
+ ServletRegistration dispatcherServlet = requireOneRootDispatcherServlet (registrations );
326
+ if (dispatcherServlet != null ) {
327
+ if (registrations .size () == 1 ) {
328
+ return requestMatchers (createMvcMatchers (method , patterns ).toArray (new RequestMatcher [0 ]));
329
+ }
330
+ List <RequestMatcher > matchers = new ArrayList <>();
331
+ for (String pattern : patterns ) {
332
+ AntPathRequestMatcher ant = new AntPathRequestMatcher (pattern , (method != null ) ? method .name () : null );
333
+ MvcRequestMatcher mvc = createMvcMatchers (method , pattern ).get (0 );
334
+ matchers .add (new DispatcherServletDelegatingRequestMatcher (ant , mvc , servletContext ));
335
+ }
336
+ return requestMatchers (matchers .toArray (new RequestMatcher [0 ]));
327
337
}
328
- return requestMatchers (createMvcMatchers (method , patterns ).toArray (new RequestMatcher [0 ]));
338
+ dispatcherServlet = requireOnlyPathMappedDispatcherServlet (registrations );
339
+ if (dispatcherServlet != null ) {
340
+ String mapping = dispatcherServlet .getMappings ().iterator ().next ();
341
+ List <MvcRequestMatcher > matchers = createMvcMatchers (method , patterns );
342
+ for (MvcRequestMatcher matcher : matchers ) {
343
+ matcher .setServletPath (mapping .substring (0 , mapping .length () - 2 ));
344
+ }
345
+ return requestMatchers (matchers .toArray (new RequestMatcher [0 ]));
346
+ }
347
+ String errorMessage = computeErrorMessage (registrations .values ());
348
+ throw new IllegalArgumentException (errorMessage );
329
349
}
330
350
331
351
private Map <String , ? extends ServletRegistration > mappableServletRegistrations (ServletContext servletContext ) {
@@ -343,22 +363,66 @@ private boolean hasDispatcherServlet(Map<String, ? extends ServletRegistration>
343
363
if (registrations == null ) {
344
364
return false ;
345
365
}
346
- Class <?> dispatcherServlet = ClassUtils .resolveClassName ("org.springframework.web.servlet.DispatcherServlet" ,
347
- null );
348
366
for (ServletRegistration registration : registrations .values ()) {
349
- try {
350
- Class <?> clazz = Class .forName (registration .getClassName ());
351
- if (dispatcherServlet .isAssignableFrom (clazz )) {
352
- return true ;
353
- }
354
- }
355
- catch (ClassNotFoundException ex ) {
356
- return false ;
367
+ if (isDispatcherServlet (registration )) {
368
+ return true ;
357
369
}
358
370
}
359
371
return false ;
360
372
}
361
373
374
+ private ServletRegistration requireOneRootDispatcherServlet (
375
+ Map <String , ? extends ServletRegistration > registrations ) {
376
+ ServletRegistration rootDispatcherServlet = null ;
377
+ for (ServletRegistration registration : registrations .values ()) {
378
+ if (!isDispatcherServlet (registration )) {
379
+ continue ;
380
+ }
381
+ if (registration .getMappings ().size () > 1 ) {
382
+ return null ;
383
+ }
384
+ if (!"/" .equals (registration .getMappings ().iterator ().next ())) {
385
+ return null ;
386
+ }
387
+ rootDispatcherServlet = registration ;
388
+ }
389
+ return rootDispatcherServlet ;
390
+ }
391
+
392
+ private ServletRegistration requireOnlyPathMappedDispatcherServlet (
393
+ Map <String , ? extends ServletRegistration > registrations ) {
394
+ ServletRegistration pathDispatcherServlet = null ;
395
+ for (ServletRegistration registration : registrations .values ()) {
396
+ if (!isDispatcherServlet (registration )) {
397
+ return null ;
398
+ }
399
+ if (registration .getMappings ().size () > 1 ) {
400
+ return null ;
401
+ }
402
+ String mapping = registration .getMappings ().iterator ().next ();
403
+ if (!mapping .startsWith ("/" ) || !mapping .endsWith ("/*" )) {
404
+ return null ;
405
+ }
406
+ if (pathDispatcherServlet != null ) {
407
+ return null ;
408
+ }
409
+ pathDispatcherServlet = registration ;
410
+ }
411
+ return pathDispatcherServlet ;
412
+ }
413
+
414
+ private boolean isDispatcherServlet (ServletRegistration registration ) {
415
+ Class <?> dispatcherServlet = ClassUtils .resolveClassName ("org.springframework.web.servlet.DispatcherServlet" ,
416
+ null );
417
+ try {
418
+ Class <?> clazz = Class .forName (registration .getClassName ());
419
+ return dispatcherServlet .isAssignableFrom (clazz );
420
+ }
421
+ catch (ClassNotFoundException ex ) {
422
+ return false ;
423
+ }
424
+ }
425
+
362
426
private String computeErrorMessage (Collection <? extends ServletRegistration > registrations ) {
363
427
String template = "This method cannot decide whether these patterns are Spring MVC patterns or not. "
364
428
+ "If this endpoint is a Spring MVC endpoint, please use requestMatchers(MvcRequestMatcher); "
@@ -498,4 +562,55 @@ static List<RequestMatcher> regexMatchers(String... regexPatterns) {
498
562
499
563
}
500
564
565
+ static class DispatcherServletDelegatingRequestMatcher implements RequestMatcher {
566
+
567
+ private final AntPathRequestMatcher ant ;
568
+
569
+ private final MvcRequestMatcher mvc ;
570
+
571
+ private final ServletContext servletContext ;
572
+
573
+ DispatcherServletDelegatingRequestMatcher (AntPathRequestMatcher ant , MvcRequestMatcher mvc ,
574
+ ServletContext servletContext ) {
575
+ this .ant = ant ;
576
+ this .mvc = mvc ;
577
+ this .servletContext = servletContext ;
578
+ }
579
+
580
+ @ Override
581
+ public boolean matches (HttpServletRequest request ) {
582
+ String name = request .getHttpServletMapping ().getServletName ();
583
+ ServletRegistration registration = this .servletContext .getServletRegistration (name );
584
+ Assert .notNull (registration , "Failed to find servlet [" + name + "] in the servlet context" );
585
+ if (isDispatcherServlet (registration )) {
586
+ return this .mvc .matches (request );
587
+ }
588
+ return this .ant .matches (request );
589
+ }
590
+
591
+ @ Override
592
+ public MatchResult matcher (HttpServletRequest request ) {
593
+ String name = request .getHttpServletMapping ().getServletName ();
594
+ ServletRegistration registration = this .servletContext .getServletRegistration (name );
595
+ Assert .notNull (registration , "Failed to find servlet [" + name + "] in the servlet context" );
596
+ if (isDispatcherServlet (registration )) {
597
+ return this .mvc .matcher (request );
598
+ }
599
+ return this .ant .matcher (request );
600
+ }
601
+
602
+ private boolean isDispatcherServlet (ServletRegistration registration ) {
603
+ Class <?> dispatcherServlet = ClassUtils
604
+ .resolveClassName ("org.springframework.web.servlet.DispatcherServlet" , null );
605
+ try {
606
+ Class <?> clazz = Class .forName (registration .getClassName ());
607
+ return dispatcherServlet .isAssignableFrom (clazz );
608
+ }
609
+ catch (ClassNotFoundException ex ) {
610
+ return false ;
611
+ }
612
+ }
613
+
614
+ }
615
+
501
616
}
0 commit comments