Skip to content

Commit 5161a75

Browse files
Fixed CORS preflight handling in Play servers (#4991)
This PR continues the work from #4239 (by @sergiuszkierat) which fixed the `ServerCORSTests` to test CORS preflight requests with POST endpoints instead of OPTIONS endpoints. ## Changes - Added CORS-aware OPTIONS request handling to `PlayServerInterpreter` (Play server) - Added CORS-aware OPTIONS request handling to `PlayServerInterpreter` (Play29 server) --------- Co-authored-by: Sergiusz Kierat <sergiusz.kierat@gmail.com>
1 parent f823246 commit 5161a75

File tree

3 files changed

+35
-11
lines changed

3 files changed

+35
-11
lines changed

server/play-server/src/main/scala/sttp/tapir/server/play/PlayServerInterpreter.scala

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import sttp.model.Method
1414
import sttp.monad.FutureMonad
1515
import sttp.tapir.server.ServerEndpoint
1616
import sttp.tapir.server.interceptor.RequestResult
17+
import sttp.tapir.server.interceptor.cors.CORSInterceptor
1718
import sttp.tapir.server.interpreter.{BodyListener, FilterServerEndpoints, ServerInterpreter}
1819
import sttp.tapir.server.model.ServerResponse
1920

@@ -53,6 +54,8 @@ trait PlayServerInterpreter {
5354
playServerOptions.deleteFile
5455
)
5556

57+
val isCORSInterceptorDefined = playServerOptions.interceptors.exists(_.isInstanceOf[CORSInterceptor[Future]])
58+
5659
new PartialFunction[RequestHeader, Handler] {
5760
override def isDefinedAt(request: RequestHeader): Boolean = {
5861
val filtered = filterServerEndpoints(PlayServerRequest(request, request))
@@ -62,7 +65,16 @@ trait PlayServerInterpreter {
6265
// doesn't match, this will be handled by the RejectInterceptor
6366
filtered.exists { e =>
6467
val m = e.endpoint.method
65-
m.isEmpty || m.contains(Method(request.method))
68+
val requestMethod = Method(request.method)
69+
val methodMatches = m.isEmpty || m.contains(requestMethod)
70+
71+
// When CORS interceptor is defined, also accept OPTIONS requests for non-OPTIONS endpoints
72+
// to handle CORS preflight requests
73+
val acceptOptionsForCORS = isCORSInterceptorDefined &&
74+
requestMethod == Method.OPTIONS &&
75+
m.exists(em => em != Method.OPTIONS)
76+
77+
methodMatches || acceptOptionsForCORS
6678
}
6779
} else {
6880
filtered.nonEmpty

server/play29-server/src/main/scala/sttp/tapir/server/play/PlayServerInterpreter.scala

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import sttp.model.Method
1414
import sttp.monad.FutureMonad
1515
import sttp.tapir.server.ServerEndpoint
1616
import sttp.tapir.server.interceptor.RequestResult
17+
import sttp.tapir.server.interceptor.cors.CORSInterceptor
1718
import sttp.tapir.server.interpreter.{BodyListener, FilterServerEndpoints, ServerInterpreter}
1819
import sttp.tapir.server.model.ServerResponse
1920

@@ -53,6 +54,8 @@ trait PlayServerInterpreter {
5354
playServerOptions.deleteFile
5455
)
5556

57+
val isCORSInterceptorDefined = playServerOptions.interceptors.exists(_.isInstanceOf[CORSInterceptor[Future]])
58+
5659
new PartialFunction[RequestHeader, Handler] {
5760
override def isDefinedAt(request: RequestHeader): Boolean = {
5861
val filtered = filterServerEndpoints(PlayServerRequest(request, request))
@@ -62,7 +65,16 @@ trait PlayServerInterpreter {
6265
// doesn't match, this will be handled by the RejectInterceptor
6366
filtered.exists { e =>
6467
val m = e.endpoint.method
65-
m.isEmpty || m.contains(Method(request.method))
68+
val requestMethod = Method(request.method)
69+
val methodMatches = m.isEmpty || m.contains(requestMethod)
70+
71+
// When CORS interceptor is defined, also accept OPTIONS requests for non-OPTIONS endpoints
72+
// to handle CORS preflight requests
73+
val acceptOptionsForCORS = isCORSInterceptorDefined &&
74+
requestMethod == Method.OPTIONS &&
75+
m.exists(em => em != Method.OPTIONS)
76+
77+
methodMatches || acceptOptionsForCORS
6678
}
6779
} else {
6880
filtered.nonEmpty

server/tests/src/main/scala/sttp/tapir/server/tests/ServerCORSTests.scala

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class ServerCORSTests[F[_], OPTIONS, ROUTE](createServerTest: CreateServerTest[F
2222

2323
val preflightTests = List(
2424
testServer(
25-
endpoint.options.in("path").out(stringBody),
25+
endpoint.post.in("path").out(stringBody),
2626
"CORS with default config; valid preflight request",
2727
_.corsInterceptor(CORSInterceptor.default[F])
2828
)(_ => pureResult("foo".asRight[Unit])) { (backend, baseUri) =>
@@ -45,7 +45,7 @@ class ServerCORSTests[F[_], OPTIONS, ROUTE](createServerTest: CreateServerTest[F
4545
}
4646
},
4747
testServer(
48-
endpoint.options.in("path"),
48+
endpoint.post.in("path"),
4949
"CORS with specific allowed origin, method, headers, allowed credentials and max age; preflight request with matching origin, method and headers",
5050
_.corsInterceptor(
5151
CORSInterceptor.customOrThrow[F](
@@ -74,7 +74,7 @@ class ServerCORSTests[F[_], OPTIONS, ROUTE](createServerTest: CreateServerTest[F
7474
}
7575
},
7676
testServer(
77-
endpoint.options.in("path"),
77+
endpoint.post.in("path"),
7878
"CORS with multiple allowed origins, method, headers, allowed credentials and max age; preflight request with matching origin, method and headers",
7979
_.corsInterceptor(
8080
CORSInterceptor.customOrThrow[F](
@@ -103,7 +103,7 @@ class ServerCORSTests[F[_], OPTIONS, ROUTE](createServerTest: CreateServerTest[F
103103
}
104104
},
105105
testServer(
106-
endpoint.options.in("path"),
106+
endpoint.post.in("path"),
107107
"CORS with specific allowed origin; preflight request with unsupported origin",
108108
_.corsInterceptor(CORSInterceptor.customOrThrow[F](CORSConfig.default.allowOrigin(Origin.Host("https", "example.com"))))
109109
)(noop) { (backend, baseUri) =>
@@ -115,7 +115,7 @@ class ServerCORSTests[F[_], OPTIONS, ROUTE](createServerTest: CreateServerTest[F
115115
}
116116
},
117117
testServer(
118-
endpoint.options.in("path"),
118+
endpoint.post.in("path"),
119119
"CORS with multiple allowed origins; preflight request with unsupported origin",
120120
_.corsInterceptor(
121121
CORSInterceptor.customOrThrow[F](CORSConfig.default.allowMatchingOrigins(Set("https://example1.com", "https://example2.com")))
@@ -129,7 +129,7 @@ class ServerCORSTests[F[_], OPTIONS, ROUTE](createServerTest: CreateServerTest[F
129129
}
130130
},
131131
testServer(
132-
endpoint.options.in("path"),
132+
endpoint.post.in("path"),
133133
"CORS with specific allowed method; preflight request with unsupported method",
134134
_.corsInterceptor(CORSInterceptor.customOrThrow[F](CORSConfig.default.allowMethods(Method.PUT)))
135135
)(noop) { (backend, baseUri) =>
@@ -141,7 +141,7 @@ class ServerCORSTests[F[_], OPTIONS, ROUTE](createServerTest: CreateServerTest[F
141141
}
142142
},
143143
testServer(
144-
endpoint.options.in("path"),
144+
endpoint.post.in("path"),
145145
"CORS with specific allowed headers; preflight request with unsupported header",
146146
_.corsInterceptor(CORSInterceptor.customOrThrow[F](CORSConfig.default.allowHeaders("X-Bar")))
147147
)(noop) { (backend, baseUri) =>
@@ -153,7 +153,7 @@ class ServerCORSTests[F[_], OPTIONS, ROUTE](createServerTest: CreateServerTest[F
153153
}
154154
},
155155
testServer(
156-
endpoint.options.in("path"),
156+
endpoint.post.in("path"),
157157
"CORS with reflected allowed headers; preflight request",
158158
_.corsInterceptor(CORSInterceptor.customOrThrow[F](CORSConfig.default.reflectHeaders))
159159
)(noop) { (backend, baseUri) =>
@@ -165,7 +165,7 @@ class ServerCORSTests[F[_], OPTIONS, ROUTE](createServerTest: CreateServerTest[F
165165
}
166166
},
167167
testServer(
168-
endpoint.options.in("path"),
168+
endpoint.post.in("path"),
169169
"CORS with custom response code for preflight requests; valid preflight request",
170170
_.corsInterceptor(CORSInterceptor.customOrThrow[F](CORSConfig.default.preflightResponseStatusCode(StatusCode.Ok)))
171171
)(noop) { (backend, baseUri) =>

0 commit comments

Comments
 (0)