Skip to content

Commit 6f4fb08

Browse files
committed
Invoke WebMvc.fn error handlers for async errors
This commit makes sure that any error handlers registered on the route are also applied when an error occurs asynchronously. This commit applies to asynchronous bodies with both CompletableFuture and Reactive Streams, as well as completely asynchronous responses. Closes gh-26831
1 parent 4c7cc70 commit 6f4fb08

File tree

3 files changed

+47
-16
lines changed

3 files changed

+47
-16
lines changed

spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultAsyncServerResponse.java

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2020 the original author or authors.
2+
* Copyright 2002-2021 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -118,7 +118,7 @@ private <R> R delegate(Function<ServerResponse, R> function) {
118118
public ModelAndView writeTo(HttpServletRequest request, HttpServletResponse response, Context context)
119119
throws ServletException, IOException {
120120

121-
writeAsync(request, response, createDeferredResult());
121+
writeAsync(request, response, createDeferredResult(request));
122122
return null;
123123
}
124124

@@ -140,7 +140,7 @@ static void writeAsync(HttpServletRequest request, HttpServletResponse response,
140140

141141
}
142142

143-
private DeferredResult<ServerResponse> createDeferredResult() {
143+
private DeferredResult<ServerResponse> createDeferredResult(HttpServletRequest request) {
144144
DeferredResult<ServerResponse> result;
145145
if (this.timeout != null) {
146146
result = new DeferredResult<>(this.timeout.toMillis());
@@ -153,7 +153,13 @@ private DeferredResult<ServerResponse> createDeferredResult() {
153153
if (ex instanceof CompletionException && ex.getCause() != null) {
154154
ex = ex.getCause();
155155
}
156-
result.setErrorResult(ex);
156+
ServerResponse errorResponse = errorResponse(ex, request);
157+
if (errorResponse != null) {
158+
result.setResult(errorResponse);
159+
}
160+
else {
161+
result.setErrorResult(ex);
162+
}
157163
}
158164
else {
159165
result.setResult(value);

spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultEntityResponseBuilder.java

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -361,21 +361,27 @@ public CompletionStageEntityResponse(int statusCode, HttpHeaders headers,
361361
protected ModelAndView writeToInternal(HttpServletRequest servletRequest, HttpServletResponse servletResponse,
362362
Context context) throws ServletException, IOException {
363363

364-
DeferredResult<?> deferredResult = createDeferredResult(servletRequest, servletResponse, context);
364+
DeferredResult<ServerResponse> deferredResult = createDeferredResult(servletRequest, servletResponse, context);
365365
DefaultAsyncServerResponse.writeAsync(servletRequest, servletResponse, deferredResult);
366366
return null;
367367
}
368368

369-
private DeferredResult<?> createDeferredResult(HttpServletRequest request, HttpServletResponse response,
369+
private DeferredResult<ServerResponse> createDeferredResult(HttpServletRequest request, HttpServletResponse response,
370370
Context context) {
371371

372-
DeferredResult<?> result = new DeferredResult<>();
372+
DeferredResult<ServerResponse> result = new DeferredResult<>();
373373
entity().handle((value, ex) -> {
374374
if (ex != null) {
375375
if (ex instanceof CompletionException && ex.getCause() != null) {
376376
ex = ex.getCause();
377377
}
378-
result.setErrorResult(ex);
378+
ServerResponse errorResponse = errorResponse(ex, request);
379+
if (errorResponse != null) {
380+
result.setResult(errorResponse);
381+
}
382+
else {
383+
result.setErrorResult(ex);
384+
}
379385
}
380386
else {
381387
try {
@@ -468,7 +474,12 @@ public void onNext(T t) {
468474

469475
@Override
470476
public void onError(Throwable t) {
471-
this.deferredResult.setErrorResult(t);
477+
try {
478+
handleError(t, this.servletRequest, this.servletResponse, this.context);
479+
}
480+
catch (ServletException | IOException handlingThrowable) {
481+
this.deferredResult.setErrorResult(handlingThrowable);
482+
}
472483
}
473484

474485
@Override

spring-webmvc/src/main/java/org/springframework/web/servlet/function/ErrorHandlingServerResponse.java

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2020 the original author or authors.
2+
* Copyright 2002-2021 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -35,7 +35,6 @@
3535

3636
/**
3737
* Base class for {@link ServerResponse} implementations with error handling.
38-
*
3938
* @author Arjen Poutsma
4039
* @since 5.3
4140
*/
@@ -55,21 +54,36 @@ protected final <T extends ServerResponse> void addErrorHandler(Predicate<Throwa
5554
}
5655

5756
@Nullable
58-
protected ModelAndView handleError(Throwable t, HttpServletRequest servletRequest,
57+
protected final ModelAndView handleError(Throwable t, HttpServletRequest servletRequest,
5958
HttpServletResponse servletResponse, Context context) throws ServletException, IOException {
6059

60+
ServerResponse serverResponse = errorResponse(t, servletRequest);
61+
if (serverResponse != null) {
62+
return serverResponse.writeTo(servletRequest, servletResponse, context);
63+
}
64+
else if (t instanceof ServletException) {
65+
throw (ServletException) t;
66+
}
67+
else if (t instanceof IOException) {
68+
throw (IOException) t;
69+
}
70+
else {
71+
throw new ServletException(t);
72+
}
73+
}
74+
75+
@Nullable
76+
protected final ServerResponse errorResponse(Throwable t, HttpServletRequest servletRequest) {
6177
for (ErrorHandler<?> errorHandler : this.errorHandlers) {
6278
if (errorHandler.test(t)) {
6379
ServerRequest serverRequest = (ServerRequest)
6480
servletRequest.getAttribute(RouterFunctions.REQUEST_ATTRIBUTE);
65-
ServerResponse serverResponse = errorHandler.handle(t, serverRequest);
66-
return serverResponse.writeTo(servletRequest, servletResponse, context);
81+
return errorHandler.handle(t, serverRequest);
6782
}
6883
}
69-
throw new ServletException(t);
84+
return null;
7085
}
7186

72-
7387
private static class ErrorHandler<T extends ServerResponse> {
7488

7589
private final Predicate<Throwable> predicate;

0 commit comments

Comments
 (0)