Skip to content

Commit 6deee3e

Browse files
committed
TestDispatcherServlet unwraps to find mock request
Issue: SPR-16695
1 parent d3acf45 commit 6deee3e

File tree

2 files changed

+143
-22
lines changed

2 files changed

+143
-22
lines changed

spring-test/src/main/java/org/springframework/test/web/servlet/TestDispatcherServlet.java

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2014 the original author or authors.
2+
* Copyright 2002-2018 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -25,6 +25,8 @@
2525
import javax.servlet.http.HttpServletResponse;
2626

2727
import org.springframework.mock.web.MockAsyncContext;
28+
import org.springframework.mock.web.MockHttpServletRequest;
29+
import org.springframework.util.Assert;
2830
import org.springframework.web.context.WebApplicationContext;
2931
import org.springframework.web.context.request.NativeWebRequest;
3032
import org.springframework.web.context.request.async.CallableProcessingInterceptorAdapter;
@@ -35,6 +37,7 @@
3537
import org.springframework.web.servlet.DispatcherServlet;
3638
import org.springframework.web.servlet.HandlerExecutionChain;
3739
import org.springframework.web.servlet.ModelAndView;
40+
import org.springframework.web.util.WebUtils;
3841

3942
/**
4043
* A sub-class of {@code DispatcherServlet} that saves the result in an
@@ -64,8 +67,24 @@ protected void service(HttpServletRequest request, HttpServletResponse response)
6467
throws ServletException, IOException {
6568

6669
registerAsyncResultInterceptors(request);
70+
6771
super.service(request, response);
68-
initAsyncDispatchLatch(request);
72+
73+
if (request.getAsyncContext() != null) {
74+
MockHttpServletRequest mockRequest = WebUtils.getNativeRequest(request, MockHttpServletRequest.class);
75+
Assert.notNull(mockRequest, "Expected MockHttpServletRequest");
76+
MockAsyncContext mockAsyncContext = ((MockAsyncContext) mockRequest.getAsyncContext());
77+
Assert.notNull(mockAsyncContext, "MockAsyncContext not found. Did request wrapper not delegate startAsync?");
78+
79+
final CountDownLatch dispatchLatch = new CountDownLatch(1);
80+
mockAsyncContext.addDispatchHandler(new Runnable() {
81+
@Override
82+
public void run() {
83+
dispatchLatch.countDown();
84+
}
85+
});
86+
getMvcResult(request).setAsyncDispatchLatch(dispatchLatch);
87+
}
6988
}
7089

7190
private void registerAsyncResultInterceptors(final HttpServletRequest request) {
@@ -84,19 +103,6 @@ public <T> void postProcess(NativeWebRequest r, DeferredResult<T> result, Object
84103
});
85104
}
86105

87-
private void initAsyncDispatchLatch(HttpServletRequest request) {
88-
if (request.getAsyncContext() != null) {
89-
final CountDownLatch dispatchLatch = new CountDownLatch(1);
90-
((MockAsyncContext) request.getAsyncContext()).addDispatchHandler(new Runnable() {
91-
@Override
92-
public void run() {
93-
dispatchLatch.countDown();
94-
}
95-
});
96-
getMvcResult(request).setAsyncDispatchLatch(dispatchLatch);
97-
}
98-
}
99-
100106
protected DefaultMvcResult getMvcResult(ServletRequest request) {
101107
return (DefaultMvcResult) request.getAttribute(MockMvc.MVC_RESULT_ATTRIBUTE);
102108
}

spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/FilterTests.java

Lines changed: 122 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2016 the original author or authors.
2+
* Copyright 2002-2018 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -18,9 +18,15 @@
1818

1919
import java.io.IOException;
2020
import java.security.Principal;
21+
import java.util.concurrent.CompletableFuture;
22+
import javax.servlet.AsyncContext;
23+
import javax.servlet.AsyncListener;
2124
import javax.servlet.Filter;
2225
import javax.servlet.FilterChain;
26+
import javax.servlet.ServletContext;
2327
import javax.servlet.ServletException;
28+
import javax.servlet.ServletRequest;
29+
import javax.servlet.ServletResponse;
2430
import javax.servlet.http.HttpServletRequest;
2531
import javax.servlet.http.HttpServletRequestWrapper;
2632
import javax.servlet.http.HttpServletResponse;
@@ -29,17 +35,23 @@
2935

3036
import org.junit.Test;
3137

38+
import org.springframework.http.MediaType;
3239
import org.springframework.stereotype.Controller;
3340
import org.springframework.test.web.Person;
41+
import org.springframework.test.web.servlet.MockMvc;
42+
import org.springframework.test.web.servlet.MvcResult;
3443
import org.springframework.validation.Errors;
44+
import org.springframework.web.bind.annotation.GetMapping;
3545
import org.springframework.web.bind.annotation.RequestMapping;
3646
import org.springframework.web.bind.annotation.RequestMethod;
47+
import org.springframework.web.bind.annotation.ResponseBody;
3748
import org.springframework.web.filter.OncePerRequestFilter;
3849
import org.springframework.web.servlet.ModelAndView;
3950
import org.springframework.web.servlet.mvc.support.RedirectAttributes;
4051

4152
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.*;
4253
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.*;
54+
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.request;
4355
import static org.springframework.test.web.servlet.setup.MockMvcBuilders.*;
4456

4557
/**
@@ -107,6 +119,22 @@ public void filterWrapsRequestResponse() throws Exception {
107119
.andExpect(model().attribute("principal", WrappingRequestResponseFilter.PRINCIPAL_NAME));
108120
}
109121

122+
@Test // SPR-16695
123+
public void filterWrapsRequestResponseAndPerformsAsyncDispatch() throws Exception {
124+
MockMvc mockMvc = standaloneSetup(new PersonController())
125+
.addFilters(new WrappingRequestResponseFilter())
126+
.build();
127+
128+
MvcResult mvcResult = mockMvc.perform(get("/persons/1").accept(MediaType.APPLICATION_JSON))
129+
.andExpect(request().asyncStarted())
130+
.andExpect(request().asyncResult(new Person("Lukas")))
131+
.andReturn();
132+
133+
mockMvc.perform(asyncDispatch(mvcResult))
134+
.andExpect(status().isOk())
135+
.andExpect(content().string("{\"name\":\"Lukas\",\"someDouble\":0.0,\"someBoolean\":false}"));
136+
}
137+
110138

111139
@Controller
112140
private static class PersonController {
@@ -129,6 +157,12 @@ public ModelAndView user(Principal principal) {
129157
public String forward() {
130158
return "forward:/persons";
131159
}
160+
161+
@GetMapping("persons/{id}")
162+
@ResponseBody
163+
public CompletableFuture<Person> getPerson() {
164+
return CompletableFuture.completedFuture(new Person("Lukas"));
165+
}
132166
}
133167

134168
private class ContinueFilter extends OncePerRequestFilter {
@@ -149,15 +183,20 @@ protected void doFilterInternal(HttpServletRequest request,
149183
HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {
150184

151185
filterChain.doFilter(new HttpServletRequestWrapper(request) {
186+
152187
@Override
153188
public Principal getUserPrincipal() {
154-
return new Principal() {
155-
@Override
156-
public String getName() {
157-
return PRINCIPAL_NAME;
158-
}
159-
};
189+
return () -> PRINCIPAL_NAME;
160190
}
191+
192+
// Like Spring Security does in HttpServlet3RequestFactory..
193+
194+
@Override
195+
public AsyncContext getAsyncContext() {
196+
return super.getAsyncContext() != null ?
197+
new AsyncContextWrapper(super.getAsyncContext()) : null;
198+
}
199+
161200
}, new HttpServletResponseWrapper(response));
162201
}
163202
}
@@ -170,4 +209,80 @@ protected void doFilterInternal(HttpServletRequest request,
170209
response.sendRedirect("/login");
171210
}
172211
}
212+
213+
214+
private static class AsyncContextWrapper implements AsyncContext {
215+
216+
private final AsyncContext delegate;
217+
218+
public AsyncContextWrapper(AsyncContext delegate) {
219+
this.delegate = delegate;
220+
}
221+
222+
@Override
223+
public ServletRequest getRequest() {
224+
return this.delegate.getRequest();
225+
}
226+
227+
@Override
228+
public ServletResponse getResponse() {
229+
return this.delegate.getResponse();
230+
}
231+
232+
@Override
233+
public boolean hasOriginalRequestAndResponse() {
234+
return this.delegate.hasOriginalRequestAndResponse();
235+
}
236+
237+
@Override
238+
public void dispatch() {
239+
this.delegate.dispatch();
240+
}
241+
242+
@Override
243+
public void dispatch(String path) {
244+
this.delegate.dispatch(path);
245+
}
246+
247+
@Override
248+
public void dispatch(ServletContext context, String path) {
249+
this.delegate.dispatch(context, path);
250+
}
251+
252+
@Override
253+
public void complete() {
254+
this.delegate.complete();
255+
}
256+
257+
@Override
258+
public void start(Runnable run) {
259+
this.delegate.start(run);
260+
}
261+
262+
@Override
263+
public void addListener(AsyncListener listener) {
264+
this.delegate.addListener(listener);
265+
}
266+
267+
@Override
268+
public void addListener(AsyncListener listener, ServletRequest req, ServletResponse res) {
269+
this.delegate.addListener(listener, req, res);
270+
}
271+
272+
@Override
273+
public <T extends AsyncListener> T createListener(Class<T> clazz) throws ServletException {
274+
return this.delegate.createListener(clazz);
275+
}
276+
277+
@Override
278+
public void setTimeout(long timeout) {
279+
this.delegate.setTimeout(timeout);
280+
}
281+
282+
@Override
283+
public long getTimeout() {
284+
return this.delegate.getTimeout();
285+
}
286+
}
287+
173288
}

0 commit comments

Comments
 (0)