1
1
/*
2
- * Copyright 2002-2016 the original author or authors.
2
+ * Copyright 2002-2018 the original author or authors.
3
3
*
4
4
* Licensed under the Apache License, Version 2.0 (the "License");
5
5
* you may not use this file except in compliance with the License.
18
18
19
19
import java .io .IOException ;
20
20
import java .security .Principal ;
21
+ import java .util .concurrent .CompletableFuture ;
22
+ import javax .servlet .AsyncContext ;
23
+ import javax .servlet .AsyncListener ;
21
24
import javax .servlet .Filter ;
22
25
import javax .servlet .FilterChain ;
26
+ import javax .servlet .ServletContext ;
23
27
import javax .servlet .ServletException ;
28
+ import javax .servlet .ServletRequest ;
29
+ import javax .servlet .ServletResponse ;
24
30
import javax .servlet .http .HttpServletRequest ;
25
31
import javax .servlet .http .HttpServletRequestWrapper ;
26
32
import javax .servlet .http .HttpServletResponse ;
29
35
30
36
import org .junit .Test ;
31
37
38
+ import org .springframework .http .MediaType ;
32
39
import org .springframework .stereotype .Controller ;
33
40
import org .springframework .test .web .Person ;
41
+ import org .springframework .test .web .servlet .MockMvc ;
42
+ import org .springframework .test .web .servlet .MvcResult ;
34
43
import org .springframework .validation .Errors ;
44
+ import org .springframework .web .bind .annotation .GetMapping ;
35
45
import org .springframework .web .bind .annotation .RequestMapping ;
36
46
import org .springframework .web .bind .annotation .RequestMethod ;
47
+ import org .springframework .web .bind .annotation .ResponseBody ;
37
48
import org .springframework .web .filter .OncePerRequestFilter ;
38
49
import org .springframework .web .servlet .ModelAndView ;
39
50
import org .springframework .web .servlet .mvc .support .RedirectAttributes ;
40
51
41
52
import static org .springframework .test .web .servlet .request .MockMvcRequestBuilders .*;
42
53
import static org .springframework .test .web .servlet .result .MockMvcResultMatchers .*;
54
+ import static org .springframework .test .web .servlet .result .MockMvcResultMatchers .request ;
43
55
import static org .springframework .test .web .servlet .setup .MockMvcBuilders .*;
44
56
45
57
/**
@@ -107,6 +119,22 @@ public void filterWrapsRequestResponse() throws Exception {
107
119
.andExpect (model ().attribute ("principal" , WrappingRequestResponseFilter .PRINCIPAL_NAME ));
108
120
}
109
121
122
+ @ Test // SPR-16695
123
+ public void filterWrapsRequestResponseAndPerformsAsyncDispatch () throws Exception {
124
+ MockMvc mockMvc = standaloneSetup (new PersonController ())
125
+ .addFilters (new WrappingRequestResponseFilter ())
126
+ .build ();
127
+
128
+ MvcResult mvcResult = mockMvc .perform (get ("/persons/1" ).accept (MediaType .APPLICATION_JSON ))
129
+ .andExpect (request ().asyncStarted ())
130
+ .andExpect (request ().asyncResult (new Person ("Lukas" )))
131
+ .andReturn ();
132
+
133
+ mockMvc .perform (asyncDispatch (mvcResult ))
134
+ .andExpect (status ().isOk ())
135
+ .andExpect (content ().string ("{\" name\" :\" Lukas\" ,\" someDouble\" :0.0,\" someBoolean\" :false}" ));
136
+ }
137
+
110
138
111
139
@ Controller
112
140
private static class PersonController {
@@ -129,6 +157,12 @@ public ModelAndView user(Principal principal) {
129
157
public String forward () {
130
158
return "forward:/persons" ;
131
159
}
160
+
161
+ @ GetMapping ("persons/{id}" )
162
+ @ ResponseBody
163
+ public CompletableFuture <Person > getPerson () {
164
+ return CompletableFuture .completedFuture (new Person ("Lukas" ));
165
+ }
132
166
}
133
167
134
168
private class ContinueFilter extends OncePerRequestFilter {
@@ -149,15 +183,20 @@ protected void doFilterInternal(HttpServletRequest request,
149
183
HttpServletResponse response , FilterChain filterChain ) throws ServletException , IOException {
150
184
151
185
filterChain .doFilter (new HttpServletRequestWrapper (request ) {
186
+
152
187
@ Override
153
188
public Principal getUserPrincipal () {
154
- return new Principal () {
155
- @ Override
156
- public String getName () {
157
- return PRINCIPAL_NAME ;
158
- }
159
- };
189
+ return () -> PRINCIPAL_NAME ;
160
190
}
191
+
192
+ // Like Spring Security does in HttpServlet3RequestFactory..
193
+
194
+ @ Override
195
+ public AsyncContext getAsyncContext () {
196
+ return super .getAsyncContext () != null ?
197
+ new AsyncContextWrapper (super .getAsyncContext ()) : null ;
198
+ }
199
+
161
200
}, new HttpServletResponseWrapper (response ));
162
201
}
163
202
}
@@ -170,4 +209,80 @@ protected void doFilterInternal(HttpServletRequest request,
170
209
response .sendRedirect ("/login" );
171
210
}
172
211
}
212
+
213
+
214
+ private static class AsyncContextWrapper implements AsyncContext {
215
+
216
+ private final AsyncContext delegate ;
217
+
218
+ public AsyncContextWrapper (AsyncContext delegate ) {
219
+ this .delegate = delegate ;
220
+ }
221
+
222
+ @ Override
223
+ public ServletRequest getRequest () {
224
+ return this .delegate .getRequest ();
225
+ }
226
+
227
+ @ Override
228
+ public ServletResponse getResponse () {
229
+ return this .delegate .getResponse ();
230
+ }
231
+
232
+ @ Override
233
+ public boolean hasOriginalRequestAndResponse () {
234
+ return this .delegate .hasOriginalRequestAndResponse ();
235
+ }
236
+
237
+ @ Override
238
+ public void dispatch () {
239
+ this .delegate .dispatch ();
240
+ }
241
+
242
+ @ Override
243
+ public void dispatch (String path ) {
244
+ this .delegate .dispatch (path );
245
+ }
246
+
247
+ @ Override
248
+ public void dispatch (ServletContext context , String path ) {
249
+ this .delegate .dispatch (context , path );
250
+ }
251
+
252
+ @ Override
253
+ public void complete () {
254
+ this .delegate .complete ();
255
+ }
256
+
257
+ @ Override
258
+ public void start (Runnable run ) {
259
+ this .delegate .start (run );
260
+ }
261
+
262
+ @ Override
263
+ public void addListener (AsyncListener listener ) {
264
+ this .delegate .addListener (listener );
265
+ }
266
+
267
+ @ Override
268
+ public void addListener (AsyncListener listener , ServletRequest req , ServletResponse res ) {
269
+ this .delegate .addListener (listener , req , res );
270
+ }
271
+
272
+ @ Override
273
+ public <T extends AsyncListener > T createListener (Class <T > clazz ) throws ServletException {
274
+ return this .delegate .createListener (clazz );
275
+ }
276
+
277
+ @ Override
278
+ public void setTimeout (long timeout ) {
279
+ this .delegate .setTimeout (timeout );
280
+ }
281
+
282
+ @ Override
283
+ public long getTimeout () {
284
+ return this .delegate .getTimeout ();
285
+ }
286
+ }
287
+
173
288
}
0 commit comments