19
19
import java .io .IOException ;
20
20
import java .util .ArrayList ;
21
21
import java .util .List ;
22
- import java .util .concurrent .atomic .AtomicBoolean ;
22
+ import java .util .concurrent .atomic .AtomicReference ;
23
23
import java .util .function .Consumer ;
24
24
25
25
import javax .servlet .AsyncContext ;
26
26
import javax .servlet .AsyncEvent ;
27
27
import javax .servlet .AsyncListener ;
28
+ import javax .servlet .ServletOutputStream ;
29
+ import javax .servlet .WriteListener ;
28
30
import javax .servlet .http .HttpServletRequest ;
29
31
import javax .servlet .http .HttpServletResponse ;
32
+ import javax .servlet .http .HttpServletResponseWrapper ;
30
33
31
34
import org .springframework .lang .Nullable ;
32
35
import org .springframework .util .Assert ;
45
48
*/
46
49
public class StandardServletAsyncWebRequest extends ServletWebRequest implements AsyncWebRequest , AsyncListener {
47
50
48
- private final AtomicBoolean asyncCompleted = new AtomicBoolean ();
49
-
50
51
private final List <Runnable > timeoutHandlers = new ArrayList <>();
51
52
52
53
private final List <Consumer <Throwable >> exceptionHandlers = new ArrayList <>();
@@ -59,14 +60,43 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements
59
60
@ Nullable
60
61
private AsyncContext asyncContext ;
61
62
63
+ private final AtomicReference <State > state ;
64
+
65
+ private volatile boolean hasError ;
66
+
62
67
63
68
/**
64
69
* Create a new instance for the given request/response pair.
65
70
* @param request current HTTP request
66
71
* @param response current HTTP response
67
72
*/
68
73
public StandardServletAsyncWebRequest (HttpServletRequest request , HttpServletResponse response ) {
69
- super (request , response );
74
+ this (request , response , null );
75
+ }
76
+
77
+ /**
78
+ * Constructor to wrap the request and response for the current dispatch that
79
+ * also picks up the state of the last (probably the REQUEST) dispatch.
80
+ * @param request current HTTP request
81
+ * @param response current HTTP response
82
+ * @param previousRequest the existing request from the last dispatch
83
+ * @since 5.3.33
84
+ */
85
+ StandardServletAsyncWebRequest (HttpServletRequest request , HttpServletResponse response ,
86
+ @ Nullable StandardServletAsyncWebRequest previousRequest ) {
87
+
88
+ super (request , new LifecycleHttpServletResponse (response ));
89
+
90
+ if (previousRequest != null ) {
91
+ this .state = previousRequest .state ;
92
+ this .hasError = previousRequest .hasError ;
93
+ }
94
+ else {
95
+ this .state = new AtomicReference <>(State .ACTIVE );
96
+ }
97
+
98
+ //noinspection DataFlowIssue
99
+ ((LifecycleHttpServletResponse ) getResponse ()).setParent (this );
70
100
}
71
101
72
102
@@ -107,7 +137,7 @@ public boolean isAsyncStarted() {
107
137
*/
108
138
@ Override
109
139
public boolean isAsyncComplete () {
110
- return this .asyncCompleted .get ();
140
+ return ( this .state .get () == State . COMPLETED );
111
141
}
112
142
113
143
@ Override
@@ -117,6 +147,7 @@ public void startAsync() {
117
147
"in async request processing. This is done in Java code using the Servlet API " +
118
148
"or by adding \" <async-supported>true</async-supported>\" to servlet and " +
119
149
"filter declarations in web.xml." );
150
+
120
151
Assert .state (!isAsyncComplete (), "Async processing has already completed" );
121
152
122
153
if (isAsyncStarted ()) {
@@ -131,8 +162,10 @@ public void startAsync() {
131
162
132
163
@ Override
133
164
public void dispatch () {
134
- Assert .state (this .asyncContext != null , "Cannot dispatch without an AsyncContext" );
135
- this .asyncContext .dispatch ();
165
+ Assert .state (this .asyncContext != null , "AsyncContext not yet initialized" );
166
+ if (!this .isAsyncComplete ()) {
167
+ this .asyncContext .dispatch ();
168
+ }
136
169
}
137
170
138
171
@@ -151,14 +184,152 @@ public void onTimeout(AsyncEvent event) throws IOException {
151
184
152
185
@ Override
153
186
public void onError (AsyncEvent event ) throws IOException {
187
+ transitionToErrorState ();
154
188
this .exceptionHandlers .forEach (consumer -> consumer .accept (event .getThrowable ()));
155
189
}
156
190
191
+ private void transitionToErrorState () {
192
+ this .hasError = true ;
193
+ this .state .compareAndSet (State .ACTIVE , State .ERROR );
194
+ }
195
+
157
196
@ Override
158
197
public void onComplete (AsyncEvent event ) throws IOException {
159
198
this .completionHandlers .forEach (Runnable ::run );
160
199
this .asyncContext = null ;
161
- this .asyncCompleted .set (true );
200
+ this .state .set (State .COMPLETED );
201
+ }
202
+
203
+
204
+ /**
205
+ * Response wrapper to wrap the output stream with {@link LifecycleServletOutputStream}.
206
+ */
207
+ private static final class LifecycleHttpServletResponse extends HttpServletResponseWrapper {
208
+
209
+ @ Nullable
210
+ private StandardServletAsyncWebRequest parent ;
211
+
212
+ private ServletOutputStream outputStream ;
213
+
214
+ public LifecycleHttpServletResponse (HttpServletResponse response ) {
215
+ super (response );
216
+ }
217
+
218
+ public void setParent (StandardServletAsyncWebRequest parent ) {
219
+ this .parent = parent ;
220
+ }
221
+
222
+ @ Override
223
+ public ServletOutputStream getOutputStream () {
224
+ if (this .outputStream == null ) {
225
+ Assert .notNull (this .parent , "Not initialized" );
226
+ this .outputStream = new LifecycleServletOutputStream ((HttpServletResponse ) getResponse (), this .parent );
227
+ }
228
+ return this .outputStream ;
229
+ }
230
+ }
231
+
232
+
233
+ /**
234
+ * Wraps a ServletOutputStream to prevent use after Servlet container onError
235
+ * notifications, and after async request completion.
236
+ */
237
+ private static final class LifecycleServletOutputStream extends ServletOutputStream {
238
+
239
+ private final HttpServletResponse response ;
240
+
241
+ private final StandardServletAsyncWebRequest parent ;
242
+
243
+ private LifecycleServletOutputStream (HttpServletResponse response , StandardServletAsyncWebRequest parent ) {
244
+ this .response = response ;
245
+ this .parent = parent ;
246
+ }
247
+
248
+ @ Override
249
+ public boolean isReady () {
250
+ return false ;
251
+ }
252
+
253
+ @ Override
254
+ public void setWriteListener (WriteListener writeListener ) {
255
+ }
256
+
257
+ @ Override
258
+ public void write (int b ) throws IOException {
259
+ checkState ();
260
+ try {
261
+ this .response .getOutputStream ().write (b );
262
+ }
263
+ catch (IOException ex ) {
264
+ handleIOException (ex , "ServletOutputStream failed to write" );
265
+ }
266
+ }
267
+
268
+ public void write (byte [] buf , int offset , int len ) throws IOException {
269
+ checkState ();
270
+ try {
271
+ this .response .getOutputStream ().write (buf , offset , len );
272
+ }
273
+ catch (IOException ex ) {
274
+ handleIOException (ex , "ServletOutputStream failed to write" );
275
+ }
276
+ }
277
+
278
+ @ Override
279
+ public void flush () throws IOException {
280
+ checkState ();
281
+ try {
282
+ this .response .getOutputStream ().flush ();
283
+ }
284
+ catch (IOException ex ) {
285
+ handleIOException (ex , "ServletOutputStream failed to flush" );
286
+ }
287
+ }
288
+
289
+ @ Override
290
+ public void close () throws IOException {
291
+ checkState ();
292
+ try {
293
+ this .response .getOutputStream ().close ();
294
+ }
295
+ catch (IOException ex ) {
296
+ handleIOException (ex , "ServletOutputStream failed to close" );
297
+ }
298
+ }
299
+
300
+ private void checkState () throws AsyncRequestNotUsableException {
301
+ if (this .parent .state .get () != State .ACTIVE ) {
302
+ String reason = this .parent .state .get () == State .COMPLETED ?
303
+ "async request completion" : "Servlet container onError notification" ;
304
+ throw new AsyncRequestNotUsableException ("Response not usable after " + reason + "." );
305
+ }
306
+ }
307
+
308
+ private void handleIOException (IOException ex , String msg ) throws AsyncRequestNotUsableException {
309
+ this .parent .transitionToErrorState ();
310
+ throw new AsyncRequestNotUsableException (msg , ex );
311
+ }
312
+
313
+ }
314
+
315
+
316
+ /**
317
+ * Represents a state for {@link StandardServletAsyncWebRequest} to be in.
318
+ * <p><pre>
319
+ * ACTIVE ----+
320
+ * | |
321
+ * v |
322
+ * ERROR |
323
+ * | |
324
+ * v |
325
+ * COMPLETED <--+
326
+ * </pre>
327
+ * @since 5.3.33
328
+ */
329
+ private enum State {
330
+
331
+ ACTIVE , ERROR , COMPLETED
332
+
162
333
}
163
334
164
335
}
0 commit comments