Skip to content

Commit 6432b13

Browse files
committed
Add state and response wrapping to StandardServletAsyncWebRequest
The wrapped response prevents use after AsyncListener onError or completion to ensure compliance with Servlet Spec 2.3.3.4. See gh-32342
1 parent 3478a70 commit 6432b13

File tree

6 files changed

+254
-21
lines changed

6 files changed

+254
-21
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
/*
2+
* Copyright 2002-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.web.context.request.async;
18+
19+
import java.io.IOException;
20+
21+
/**
22+
* Raised when the response for an asynchronous request becomes unusable as
23+
* indicated by a write failure, or a Servlet container error notification, or
24+
* after the async request has completed.
25+
*
26+
* <p>The exception relies on response wrapping, and on {@code AsyncListener}
27+
* notifications, managed by {@link StandardServletAsyncWebRequest}.
28+
*
29+
* @author Rossen Stoyanchev
30+
* @since 5.3.33
31+
*/
32+
@SuppressWarnings("serial")
33+
public class AsyncRequestNotUsableException extends IOException {
34+
35+
36+
public AsyncRequestNotUsableException(String message) {
37+
super(message);
38+
}
39+
40+
public AsyncRequestNotUsableException(String message, Throwable cause) {
41+
super(message, cause);
42+
}
43+
44+
}

spring-web/src/main/java/org/springframework/web/context/request/async/StandardServletAsyncWebRequest.java

Lines changed: 179 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,17 @@
1919
import java.io.IOException;
2020
import java.util.ArrayList;
2121
import java.util.List;
22-
import java.util.concurrent.atomic.AtomicBoolean;
22+
import java.util.concurrent.atomic.AtomicReference;
2323
import java.util.function.Consumer;
2424

2525
import javax.servlet.AsyncContext;
2626
import javax.servlet.AsyncEvent;
2727
import javax.servlet.AsyncListener;
28+
import javax.servlet.ServletOutputStream;
29+
import javax.servlet.WriteListener;
2830
import javax.servlet.http.HttpServletRequest;
2931
import javax.servlet.http.HttpServletResponse;
32+
import javax.servlet.http.HttpServletResponseWrapper;
3033

3134
import org.springframework.lang.Nullable;
3235
import org.springframework.util.Assert;
@@ -45,8 +48,6 @@
4548
*/
4649
public class StandardServletAsyncWebRequest extends ServletWebRequest implements AsyncWebRequest, AsyncListener {
4750

48-
private final AtomicBoolean asyncCompleted = new AtomicBoolean();
49-
5051
private final List<Runnable> timeoutHandlers = new ArrayList<>();
5152

5253
private final List<Consumer<Throwable>> exceptionHandlers = new ArrayList<>();
@@ -59,14 +60,43 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements
5960
@Nullable
6061
private AsyncContext asyncContext;
6162

63+
private final AtomicReference<State> state;
64+
65+
private volatile boolean hasError;
66+
6267

6368
/**
6469
* Create a new instance for the given request/response pair.
6570
* @param request current HTTP request
6671
* @param response current HTTP response
6772
*/
6873
public StandardServletAsyncWebRequest(HttpServletRequest request, HttpServletResponse response) {
69-
super(request, response);
74+
this(request, response, null);
75+
}
76+
77+
/**
78+
* Constructor to wrap the request and response for the current dispatch that
79+
* also picks up the state of the last (probably the REQUEST) dispatch.
80+
* @param request current HTTP request
81+
* @param response current HTTP response
82+
* @param previousRequest the existing request from the last dispatch
83+
* @since 5.3.33
84+
*/
85+
StandardServletAsyncWebRequest(HttpServletRequest request, HttpServletResponse response,
86+
@Nullable StandardServletAsyncWebRequest previousRequest) {
87+
88+
super(request, new LifecycleHttpServletResponse(response));
89+
90+
if (previousRequest != null) {
91+
this.state = previousRequest.state;
92+
this.hasError = previousRequest.hasError;
93+
}
94+
else {
95+
this.state = new AtomicReference<>(State.ACTIVE);
96+
}
97+
98+
//noinspection DataFlowIssue
99+
((LifecycleHttpServletResponse) getResponse()).setParent(this);
70100
}
71101

72102

@@ -107,7 +137,7 @@ public boolean isAsyncStarted() {
107137
*/
108138
@Override
109139
public boolean isAsyncComplete() {
110-
return this.asyncCompleted.get();
140+
return (this.state.get() == State.COMPLETED);
111141
}
112142

113143
@Override
@@ -117,6 +147,7 @@ public void startAsync() {
117147
"in async request processing. This is done in Java code using the Servlet API " +
118148
"or by adding \"<async-supported>true</async-supported>\" to servlet and " +
119149
"filter declarations in web.xml.");
150+
120151
Assert.state(!isAsyncComplete(), "Async processing has already completed");
121152

122153
if (isAsyncStarted()) {
@@ -131,8 +162,10 @@ public void startAsync() {
131162

132163
@Override
133164
public void dispatch() {
134-
Assert.state(this.asyncContext != null, "Cannot dispatch without an AsyncContext");
135-
this.asyncContext.dispatch();
165+
Assert.state(this.asyncContext != null, "AsyncContext not yet initialized");
166+
if (!this.isAsyncComplete()) {
167+
this.asyncContext.dispatch();
168+
}
136169
}
137170

138171

@@ -151,14 +184,152 @@ public void onTimeout(AsyncEvent event) throws IOException {
151184

152185
@Override
153186
public void onError(AsyncEvent event) throws IOException {
187+
transitionToErrorState();
154188
this.exceptionHandlers.forEach(consumer -> consumer.accept(event.getThrowable()));
155189
}
156190

191+
private void transitionToErrorState() {
192+
this.hasError = true;
193+
this.state.compareAndSet(State.ACTIVE, State.ERROR);
194+
}
195+
157196
@Override
158197
public void onComplete(AsyncEvent event) throws IOException {
159198
this.completionHandlers.forEach(Runnable::run);
160199
this.asyncContext = null;
161-
this.asyncCompleted.set(true);
200+
this.state.set(State.COMPLETED);
201+
}
202+
203+
204+
/**
205+
* Response wrapper to wrap the output stream with {@link LifecycleServletOutputStream}.
206+
*/
207+
private static final class LifecycleHttpServletResponse extends HttpServletResponseWrapper {
208+
209+
@Nullable
210+
private StandardServletAsyncWebRequest parent;
211+
212+
private ServletOutputStream outputStream;
213+
214+
public LifecycleHttpServletResponse(HttpServletResponse response) {
215+
super(response);
216+
}
217+
218+
public void setParent(StandardServletAsyncWebRequest parent) {
219+
this.parent = parent;
220+
}
221+
222+
@Override
223+
public ServletOutputStream getOutputStream() {
224+
if (this.outputStream == null) {
225+
Assert.notNull(this.parent, "Not initialized");
226+
this.outputStream = new LifecycleServletOutputStream((HttpServletResponse) getResponse(), this.parent);
227+
}
228+
return this.outputStream;
229+
}
230+
}
231+
232+
233+
/**
234+
* Wraps a ServletOutputStream to prevent use after Servlet container onError
235+
* notifications, and after async request completion.
236+
*/
237+
private static final class LifecycleServletOutputStream extends ServletOutputStream {
238+
239+
private final HttpServletResponse response;
240+
241+
private final StandardServletAsyncWebRequest parent;
242+
243+
private LifecycleServletOutputStream(HttpServletResponse response, StandardServletAsyncWebRequest parent) {
244+
this.response = response;
245+
this.parent = parent;
246+
}
247+
248+
@Override
249+
public boolean isReady() {
250+
return false;
251+
}
252+
253+
@Override
254+
public void setWriteListener(WriteListener writeListener) {
255+
}
256+
257+
@Override
258+
public void write(int b) throws IOException {
259+
checkState();
260+
try {
261+
this.response.getOutputStream().write(b);
262+
}
263+
catch (IOException ex) {
264+
handleIOException(ex, "ServletOutputStream failed to write");
265+
}
266+
}
267+
268+
public void write(byte[] buf, int offset, int len) throws IOException {
269+
checkState();
270+
try {
271+
this.response.getOutputStream().write(buf, offset, len);
272+
}
273+
catch (IOException ex) {
274+
handleIOException(ex, "ServletOutputStream failed to write");
275+
}
276+
}
277+
278+
@Override
279+
public void flush() throws IOException {
280+
checkState();
281+
try {
282+
this.response.getOutputStream().flush();
283+
}
284+
catch (IOException ex) {
285+
handleIOException(ex, "ServletOutputStream failed to flush");
286+
}
287+
}
288+
289+
@Override
290+
public void close() throws IOException {
291+
checkState();
292+
try {
293+
this.response.getOutputStream().close();
294+
}
295+
catch (IOException ex) {
296+
handleIOException(ex, "ServletOutputStream failed to close");
297+
}
298+
}
299+
300+
private void checkState() throws AsyncRequestNotUsableException {
301+
if (this.parent.state.get() != State.ACTIVE) {
302+
String reason = this.parent.state.get() == State.COMPLETED ?
303+
"async request completion" : "Servlet container onError notification";
304+
throw new AsyncRequestNotUsableException("Response not usable after " + reason + ".");
305+
}
306+
}
307+
308+
private void handleIOException(IOException ex, String msg) throws AsyncRequestNotUsableException {
309+
this.parent.transitionToErrorState();
310+
throw new AsyncRequestNotUsableException(msg, ex);
311+
}
312+
313+
}
314+
315+
316+
/**
317+
* Represents a state for {@link StandardServletAsyncWebRequest} to be in.
318+
* <p><pre>
319+
* ACTIVE ----+
320+
* | |
321+
* v |
322+
* ERROR |
323+
* | |
324+
* v |
325+
* COMPLETED <--+
326+
* </pre>
327+
* @since 5.3.33
328+
*/
329+
private enum State {
330+
331+
ACTIVE, ERROR, COMPLETED
332+
162333
}
163334

164335
}

spring-web/src/main/java/org/springframework/web/context/request/async/WebAsyncManager.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,15 @@ public void setAsyncWebRequest(AsyncWebRequest asyncWebRequest) {
132132
WebAsyncUtils.WEB_ASYNC_MANAGER_ATTRIBUTE, RequestAttributes.SCOPE_REQUEST));
133133
}
134134

135+
/**
136+
* Return the current {@link AsyncWebRequest}.
137+
* @since 5.3.33
138+
*/
139+
@Nullable
140+
public AsyncWebRequest getAsyncWebRequest() {
141+
return this.asyncWebRequest;
142+
}
143+
135144
/**
136145
* Configure an AsyncTaskExecutor for use with concurrent processing via
137146
* {@link #startCallableProcessing(Callable, Object...)}.

spring-web/src/main/java/org/springframework/web/context/request/async/WebAsyncUtils.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2018 the original author or authors.
2+
* Copyright 2002-2024 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -82,7 +82,10 @@ public static WebAsyncManager getAsyncManager(WebRequest webRequest) {
8282
* @return an AsyncWebRequest instance (never {@code null})
8383
*/
8484
public static AsyncWebRequest createAsyncWebRequest(HttpServletRequest request, HttpServletResponse response) {
85-
return new StandardServletAsyncWebRequest(request, response);
85+
AsyncWebRequest prev = getAsyncManager(request).getAsyncWebRequest();
86+
return (prev instanceof StandardServletAsyncWebRequest ?
87+
new StandardServletAsyncWebRequest(request, response, (StandardServletAsyncWebRequest) prev) :
88+
new StandardServletAsyncWebRequest(request, response));
8689
}
8790

8891
}

spring-web/src/main/java/org/springframework/web/util/DisconnectedClientHelper.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ public class DisconnectedClientHelper {
4141
new HashSet<>(Arrays.asList("broken pipe", "connection reset by peer"));
4242

4343
private static final Set<String> EXCEPTION_TYPE_NAMES =
44-
new HashSet<>(Arrays.asList("AbortedException", "ClientAbortException", "EOFException", "EofException"));
44+
new HashSet<>(Arrays.asList("AbortedException", "ClientAbortException",
45+
"EOFException", "EofException", "AsyncRequestNotUsableException"));
4546

4647
private final Log logger;
4748

spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/RequestMappingHandlerAdapter.java

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -853,7 +853,21 @@ private SessionAttributesHandler getSessionAttributesHandler(HandlerMethod handl
853853
protected ModelAndView invokeHandlerMethod(HttpServletRequest request,
854854
HttpServletResponse response, HandlerMethod handlerMethod) throws Exception {
855855

856-
ServletWebRequest webRequest = new ServletWebRequest(request, response);
856+
WebAsyncManager asyncManager = WebAsyncUtils.getAsyncManager(request);
857+
AsyncWebRequest asyncWebRequest = WebAsyncUtils.createAsyncWebRequest(request, response);
858+
asyncWebRequest.setTimeout(this.asyncRequestTimeout);
859+
860+
asyncManager.setTaskExecutor(this.taskExecutor);
861+
asyncManager.setAsyncWebRequest(asyncWebRequest);
862+
asyncManager.registerCallableInterceptors(this.callableInterceptors);
863+
asyncManager.registerDeferredResultInterceptors(this.deferredResultInterceptors);
864+
865+
// Obtain wrapped response to enforce lifecycle rule from Servlet spec, section 2.3.3.4
866+
response = asyncWebRequest.getNativeResponse(HttpServletResponse.class);
867+
868+
ServletWebRequest webRequest = (asyncWebRequest instanceof ServletWebRequest ?
869+
(ServletWebRequest) asyncWebRequest : new ServletWebRequest(request, response));
870+
857871
try {
858872
WebDataBinderFactory binderFactory = getDataBinderFactory(handlerMethod);
859873
ModelFactory modelFactory = getModelFactory(handlerMethod, binderFactory);
@@ -873,15 +887,6 @@ protected ModelAndView invokeHandlerMethod(HttpServletRequest request,
873887
modelFactory.initModel(webRequest, mavContainer, invocableMethod);
874888
mavContainer.setIgnoreDefaultModelOnRedirect(this.ignoreDefaultModelOnRedirect);
875889

876-
AsyncWebRequest asyncWebRequest = WebAsyncUtils.createAsyncWebRequest(request, response);
877-
asyncWebRequest.setTimeout(this.asyncRequestTimeout);
878-
879-
WebAsyncManager asyncManager = WebAsyncUtils.getAsyncManager(request);
880-
asyncManager.setTaskExecutor(this.taskExecutor);
881-
asyncManager.setAsyncWebRequest(asyncWebRequest);
882-
asyncManager.registerCallableInterceptors(this.callableInterceptors);
883-
asyncManager.registerDeferredResultInterceptors(this.deferredResultInterceptors);
884-
885890
if (asyncManager.hasConcurrentResult()) {
886891
Object result = asyncManager.getConcurrentResult();
887892
Object[] resultContext = asyncManager.getConcurrentResultContext();

0 commit comments

Comments
 (0)