1
1
/*
2
- * Copyright 2012-2016 the original author or authors.
2
+ * Copyright 2012-2017 the original author or authors.
3
3
*
4
4
* Licensed under the Apache License, Version 2.0 (the "License");
5
5
* you may not use this file except in compliance with the License.
17
17
package org .springframework .boot .web .support ;
18
18
19
19
import java .io .IOException ;
20
+ import java .util .Enumeration ;
21
+ import java .util .HashMap ;
22
+ import java .util .Map ;
20
23
21
24
import javax .servlet .RequestDispatcher ;
22
25
import javax .servlet .ServletException ;
35
38
import org .springframework .mock .web .MockFilterConfig ;
36
39
import org .springframework .mock .web .MockHttpServletRequest ;
37
40
import org .springframework .mock .web .MockHttpServletResponse ;
41
+ import org .springframework .mock .web .MockRequestDispatcher ;
38
42
import org .springframework .web .context .request .async .DeferredResult ;
39
43
import org .springframework .web .context .request .async .StandardServletAsyncWebRequest ;
40
44
import org .springframework .web .context .request .async .WebAsyncManager ;
@@ -57,8 +61,7 @@ public class ErrorPageFilterTests {
57
61
58
62
private ErrorPageFilter filter = new ErrorPageFilter ();
59
63
60
- private MockHttpServletRequest request = new MockHttpServletRequest ("GET" ,
61
- "/test/path" );
64
+ private DispatchRecordingMockHttpServletRequest request = new DispatchRecordingMockHttpServletRequest ();
62
65
63
66
private MockHttpServletResponse response = new MockHttpServletResponse ();
64
67
@@ -261,8 +264,14 @@ public void doFilter(ServletRequest request, ServletResponse response)
261
264
.isEqualTo (500 );
262
265
assertThat (this .request .getAttribute (RequestDispatcher .ERROR_MESSAGE ))
263
266
.isEqualTo ("BAD" );
264
- assertThat (this .request .getAttribute (RequestDispatcher .ERROR_EXCEPTION_TYPE ))
267
+ Map <String , Object > requestAttributes = getAttributesForDispatch ("/500" );
268
+ assertThat (requestAttributes .get (RequestDispatcher .ERROR_EXCEPTION_TYPE ))
265
269
.isEqualTo (RuntimeException .class );
270
+ assertThat (requestAttributes .get (RequestDispatcher .ERROR_EXCEPTION ))
271
+ .isInstanceOf (RuntimeException .class );
272
+ assertThat (this .request .getAttribute (RequestDispatcher .ERROR_EXCEPTION_TYPE ))
273
+ .isNull ();
274
+ assertThat (this .request .getAttribute (RequestDispatcher .ERROR_EXCEPTION )).isNull ();
266
275
assertThat (this .request .getAttribute (RequestDispatcher .ERROR_REQUEST_URI ))
267
276
.isEqualTo ("/test/path" );
268
277
assertThat (this .response .isCommitted ()).isTrue ();
@@ -318,8 +327,14 @@ public void doFilter(ServletRequest request, ServletResponse response)
318
327
.isEqualTo (500 );
319
328
assertThat (this .request .getAttribute (RequestDispatcher .ERROR_MESSAGE ))
320
329
.isEqualTo ("BAD" );
321
- assertThat (this .request .getAttribute (RequestDispatcher .ERROR_EXCEPTION_TYPE ))
330
+ Map <String , Object > requestAttributes = getAttributesForDispatch ("/500" );
331
+ assertThat (requestAttributes .get (RequestDispatcher .ERROR_EXCEPTION_TYPE ))
322
332
.isEqualTo (IllegalStateException .class );
333
+ assertThat (requestAttributes .get (RequestDispatcher .ERROR_EXCEPTION ))
334
+ .isInstanceOf (IllegalStateException .class );
335
+ assertThat (this .request .getAttribute (RequestDispatcher .ERROR_EXCEPTION_TYPE ))
336
+ .isNull ();
337
+ assertThat (this .request .getAttribute (RequestDispatcher .ERROR_EXCEPTION )).isNull ();
323
338
assertThat (this .request .getAttribute (RequestDispatcher .ERROR_REQUEST_URI ))
324
339
.isEqualTo ("/test/path" );
325
340
assertThat (this .response .isCommitted ()).isTrue ();
@@ -492,8 +507,14 @@ public void doFilter(ServletRequest request, ServletResponse response)
492
507
.isEqualTo (500 );
493
508
assertThat (this .request .getAttribute (RequestDispatcher .ERROR_MESSAGE ))
494
509
.isEqualTo ("BAD" );
495
- assertThat (this .request .getAttribute (RequestDispatcher .ERROR_EXCEPTION_TYPE ))
510
+ Map <String , Object > requestAttributes = getAttributesForDispatch ("/500" );
511
+ assertThat (requestAttributes .get (RequestDispatcher .ERROR_EXCEPTION_TYPE ))
496
512
.isEqualTo (RuntimeException .class );
513
+ assertThat (requestAttributes .get (RequestDispatcher .ERROR_EXCEPTION ))
514
+ .isInstanceOf (RuntimeException .class );
515
+ assertThat (this .request .getAttribute (RequestDispatcher .ERROR_EXCEPTION_TYPE ))
516
+ .isNull ();
517
+ assertThat (this .request .getAttribute (RequestDispatcher .ERROR_EXCEPTION )).isNull ();
497
518
assertThat (this .request .getAttribute (RequestDispatcher .ERROR_REQUEST_URI ))
498
519
.isEqualTo ("/test/path" );
499
520
assertThat (this .response .isCommitted ()).isTrue ();
@@ -510,4 +531,60 @@ private void setUpAsyncDispatch() throws Exception {
510
531
asyncManager .startDeferredResultProcessing (result );
511
532
}
512
533
534
+ private Map <String , Object > getAttributesForDispatch (String path ) {
535
+ return this .request .getDispatcher (path ).getRequestAttributes ();
536
+ }
537
+
538
+ private static final class DispatchRecordingMockHttpServletRequest
539
+ extends MockHttpServletRequest {
540
+
541
+ private final Map <String , AttributeCapturingRequestDispatcher > dispatchers = new HashMap <String , AttributeCapturingRequestDispatcher >();
542
+
543
+ private DispatchRecordingMockHttpServletRequest () {
544
+ super ("GET" , "/test/path" );
545
+ }
546
+
547
+ @ Override
548
+ public RequestDispatcher getRequestDispatcher (String path ) {
549
+ AttributeCapturingRequestDispatcher dispatcher = new AttributeCapturingRequestDispatcher (
550
+ path );
551
+ this .dispatchers .put (path , dispatcher );
552
+ return dispatcher ;
553
+ }
554
+
555
+ private AttributeCapturingRequestDispatcher getDispatcher (String path ) {
556
+ return this .dispatchers .get (path );
557
+ }
558
+
559
+ private static final class AttributeCapturingRequestDispatcher
560
+ extends MockRequestDispatcher {
561
+
562
+ private final Map <String , Object > requestAttributes = new HashMap <String , Object >();
563
+
564
+ private AttributeCapturingRequestDispatcher (String resource ) {
565
+ super (resource );
566
+ }
567
+
568
+ @ Override
569
+ public void forward (ServletRequest request , ServletResponse response ) {
570
+ captureAttributes (request );
571
+ super .forward (request , response );
572
+ }
573
+
574
+ private void captureAttributes (ServletRequest request ) {
575
+ Enumeration <String > names = request .getAttributeNames ();
576
+ while (names .hasMoreElements ()) {
577
+ String name = names .nextElement ();
578
+ this .requestAttributes .put (name , request .getAttribute (name ));
579
+ }
580
+ }
581
+
582
+ private Map <String , Object > getRequestAttributes () {
583
+ return this .requestAttributes ;
584
+ }
585
+
586
+ }
587
+
588
+ }
589
+
513
590
}
0 commit comments