@@ -16,18 +16,17 @@ import sttp.tapir.server.netty.{NettyConfig, NettyResponse, Route}
1616import java .net .{InetSocketAddress , SocketAddress }
1717import java .nio .file .Path
1818import java .util .concurrent .atomic .AtomicBoolean
19- import java .util .concurrent .{Executors , Future as JFuture }
2019import scala .concurrent .duration .FiniteDuration
2120import scala .concurrent .{Future , Promise }
2221import scala .util .control .NonFatal
22+ import java .util .concurrent .atomic .AtomicReference
2323
2424case class NettySyncServer (
2525 serverEndpoints : Vector [ServerEndpoint [OxStreams & WebSockets , Identity ]],
2626 otherRoutes : Vector [IdRoute ],
2727 options : NettySyncServerOptions ,
2828 config : NettyConfig
2929):
30- private val executor = Executors .newVirtualThreadPerTaskExecutor()
3130 private val logger = LoggerFactory .getLogger(getClass.getName)
3231
3332 def addEndpoint (se : ServerEndpoint [OxStreams & WebSockets , Identity ]): NettySyncServer = addEndpoints(List (se))
@@ -93,11 +92,6 @@ case class NettySyncServer(
9392 never
9493 }
9594
96- private [netty] def start (route : Route [Identity ]): NettySyncServerBinding =
97- startUsingSocketOverride[InetSocketAddress ](route, None ) match
98- case (socket, stop) =>
99- NettySyncServerBinding (socket, stop)
100-
10195 def startUsingDomainSocket (path : Path )(using Ox ): NettySyncDomainSocketBinding =
10296 startUsingSocketOverride(Some (new DomainSocketAddress (path.toFile)), inScopeRunner()) match
10397 case (socket, stop) =>
@@ -109,28 +103,52 @@ case class NettySyncServer(
109103 ): (SA , () => Unit ) =
110104 val endpointRoute = NettySyncServerInterpreter (options).toRoute(serverEndpoints.toList, inScopeRunner)
111105 val route = Route .combine(endpointRoute +: otherRoutes)
112- startUsingSocketOverride(route, socketOverride)
106+ startUsingSocketOverride(route, socketOverride, inScopeRunner )
113107
114- private def startUsingSocketOverride [SA <: SocketAddress ](route : Route [Identity ], socketOverride : Option [SA ]): (SA , () => Unit ) =
108+ private def startUsingSocketOverride [SA <: SocketAddress ](
109+ route : Route [Identity ],
110+ socketOverride : Option [SA ],
111+ inScopeRunner : InScopeRunner
112+ ): (SA , () => Unit ) =
115113 val eventLoopGroup = config.eventLoopConfig.initEventLoopGroup()
116114
117115 def unsafeRunF (
118116 callToExecute : () => Identity [ServerResponse [NettyResponse ]]
119117 ): (Future [ServerResponse [NettyResponse ]], () => Future [Unit ]) =
120118 val scalaPromise = Promise [ServerResponse [NettyResponse ]]()
121- val jFuture : JFuture [? ] = executor.submit(new Runnable {
122- override def run (): Unit = try {
123- val result = callToExecute()
124- scalaPromise.success(result)
125- } catch {
126- case NonFatal (e) => scalaPromise.failure(e)
127- }
128- })
119+
120+ // used to cancel the fork, possible states:
121+ // - Cancelled: needed if cancellation happens after the fork starts, but before it is set in the atomic reference
122+ // - CancellableFork: used when cancellation using interruption is enabled
123+ // - None: means that the fork was not set yet
124+ object Cancelled
125+ val runningFork = new AtomicReference [None .type | CancellableFork [Unit ] | Cancelled .type ](None )
126+ // #4747: we are creating forks using the concurrency scope within which the netty server is running
127+ // however, this is called on a Netty-managed thread, hence we need to use the inScopeRunner
128+ inScopeRunner.async {
129+ def run (): Unit =
130+ if runningFork.get() != Cancelled then
131+ try scalaPromise.success(callToExecute())
132+ catch case NonFatal (e) => scalaPromise.failure(e)
133+
134+ if options.interruptServerLogicWhenRequestCancelled then
135+ val forked = forkCancellable(run())
136+ // we only update the state if it's not cancelled already
137+ runningFork.getAndAccumulate(forked, (cur, giv) => if cur != Cancelled then giv else cur) match {
138+ case None => // common "happy path" case
139+ case _ : CancellableFork [Unit ] => throw new IllegalStateException (" Another fork was already set" )
140+ case Cancelled => forked.cancelNow() // cancellation happened before the fork was set
141+ }
142+ else forkDiscard(run())
143+ }
129144
130145 (
131146 scalaPromise.future,
132147 () => {
133- jFuture.cancel(options.interruptServerLogicWhenRequestCancelled)
148+ runningFork.getAndSet(Cancelled ) match {
149+ case fork : CancellableFork [Unit ] => fork.cancelNow()
150+ case _ => // skip - fork not yet set
151+ }
134152 Future .unit
135153 }
136154 )
0 commit comments