Skip to content

Commit 3ab270e

Browse files
ShindongLeesdeleuze
authored andcommitted
Fix parameter bug of handler inside the filterFunction DSL
Co-authored-by: hojongs <[email protected]> Co-authored-by: bjh970913 <[email protected]> Closes gh-26921
1 parent ddb727b commit 3ab270e

File tree

4 files changed

+45
-4
lines changed

4 files changed

+45
-4
lines changed

spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDsl.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -531,8 +531,8 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
531531
fun filter(filterFunction: suspend (ServerRequest, suspend (ServerRequest) -> ServerResponse) -> ServerResponse) {
532532
builder.filter { serverRequest, handlerFunction ->
533533
mono(Dispatchers.Unconfined) {
534-
filterFunction(serverRequest) {
535-
handlerFunction.handle(serverRequest).awaitFirst()
534+
filterFunction(serverRequest) { handlerRequest ->
535+
handlerFunction.handle(handlerRequest).awaitFirst()
536536
}
537537
}
538538
}

spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDslTests.kt

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,16 @@ class CoRouterFunctionDslTests {
152152
}
153153
}
154154

155+
@Test
156+
fun filtering() {
157+
val mockRequest = get("https://example.com/filter").build()
158+
val request = DefaultServerRequest(MockServerWebExchange.from(mockRequest), emptyList())
159+
StepVerifier.create(sampleRouter().route(request).flatMap { it.handle(request) })
160+
.expectNextMatches { response ->
161+
response.headers().getFirst("foo") == "bar"
162+
}
163+
.verifyComplete()
164+
}
155165

156166
private fun sampleRouter() = coRouter {
157167
(GET("/foo/") or GET("/foos/")) { req -> handle(req) }
@@ -186,6 +196,18 @@ class CoRouterFunctionDslTests {
186196
path("/baz", ::handle)
187197
GET("/rendering") { RenderingResponse.create("index").buildAndAwait() }
188198
add(otherRouter)
199+
add(filterRouter)
200+
}
201+
202+
private val filterRouter = coRouter {
203+
"/filter" { request ->
204+
ok().header("foo", request.headers().firstHeader("foo")).buildAndAwait()
205+
}
206+
207+
filter { request, next ->
208+
val newRequest = ServerRequest.from(request).apply { header("foo", "bar") }.build()
209+
next(newRequest)
210+
}
189211
}
190212

191213
private val otherRouter = router {

spring-webmvc/src/main/kotlin/org/springframework/web/servlet/function/RouterFunctionDsl.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -523,8 +523,8 @@ class RouterFunctionDsl internal constructor (private val init: (RouterFunctionD
523523
*/
524524
fun filter(filterFunction: (ServerRequest, (ServerRequest) -> ServerResponse) -> ServerResponse) {
525525
builder.filter { request, next ->
526-
filterFunction(request) {
527-
next.handle(request)
526+
filterFunction(request) { handlerRequest ->
527+
next.handle(handlerRequest)
528528
}
529529
}
530530
}

spring-webmvc/src/test/kotlin/org/springframework/web/servlet/function/RouterFunctionDslTests.kt

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,13 @@ class RouterFunctionDslTests {
126126
}
127127
}
128128

129+
@Test
130+
fun filtering() {
131+
val servletRequest = MockHttpServletRequest("GET", "/filter")
132+
val request = DefaultServerRequest(servletRequest, emptyList())
133+
assertThat(sampleRouter().route(request).get().handle(request).headers().getFirst("foo")).isEqualTo("bar")
134+
}
135+
129136
private fun sampleRouter() = router {
130137
(GET("/foo/") or GET("/foos/")) { req -> handle(req) }
131138
"/api".nest {
@@ -159,6 +166,18 @@ class RouterFunctionDslTests {
159166
path("/baz", ::handle)
160167
GET("/rendering") { RenderingResponse.create("index").build() }
161168
add(otherRouter)
169+
add(filterRouter)
170+
}
171+
172+
private val filterRouter = router {
173+
"/filter" { request ->
174+
ok().header("foo", request.headers().firstHeader("foo")).build()
175+
}
176+
177+
filter { request, next ->
178+
val newRequest = ServerRequest.from(request).apply { header("foo", "bar") }.build()
179+
next(newRequest)
180+
}
162181
}
163182

164183
private val otherRouter = router {

0 commit comments

Comments
 (0)