Skip to content

Commit f96a4e4

Browse files
authored
Implement ServerRequest.connectionInfo for Netty servers (#4853)
1 parent 3e0f3ad commit f96a4e4

File tree

3 files changed

+43
-5
lines changed

3 files changed

+43
-5
lines changed

server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyServerRequest.scala

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,42 @@ package sttp.tapir.server.netty
22

33
import scala.collection.JavaConverters._
44
import scala.collection.immutable.Seq
5+
import io.netty.channel.ChannelHandlerContext
56
import io.netty.handler.codec.http.{HttpRequest, QueryStringDecoder}
7+
import io.netty.handler.ssl.SslHandler
68
import sttp.model.{Header, Method, QueryParams, Uri}
79
import sttp.tapir.{AttributeKey, AttributeMap}
810
import sttp.tapir.model.{ConnectionInfo, ServerRequest}
911
import sttp.tapir.server.netty.internal.RichNettyHttpHeaders
1012
import io.netty.handler.codec.http.FullHttpRequest
13+
import java.net.InetSocketAddress
1114

12-
case class NettyServerRequest(req: HttpRequest, attributes: AttributeMap = AttributeMap.Empty) extends ServerRequest {
15+
case class NettyServerRequest(req: HttpRequest, ctx: ChannelHandlerContext, attributes: AttributeMap = AttributeMap.Empty)
16+
extends ServerRequest {
1317
// non-lazy, so that we can validate that the URI parses upfront
1418
override val uri: Uri = Uri.unsafeParse(req.uri())
1519

1620
override lazy val protocol: String = req.protocolVersion().text()
17-
override lazy val connectionInfo: ConnectionInfo = ConnectionInfo.NoInfo
21+
22+
override lazy val connectionInfo: ConnectionInfo = {
23+
val local = ctx.channel().localAddress() match {
24+
case inet: InetSocketAddress => Some(inet)
25+
case _ => None
26+
}
27+
28+
val remote = ctx.channel().remoteAddress() match {
29+
case inet: InetSocketAddress => Some(inet)
30+
case _ => None
31+
}
32+
33+
val secure = uri.scheme match {
34+
case Some("https") | Some("wss") => Some(true)
35+
case Some("http") | Some("ws") => Some(false)
36+
case _ => None
37+
}
38+
39+
ConnectionInfo(local, remote, secure)
40+
}
1841
override lazy val underlying: Any = req
1942
override lazy val queryParameters: QueryParams = {
2043
val decoder = new QueryStringDecoder(req.uri())
@@ -37,5 +60,5 @@ case class NettyServerRequest(req: HttpRequest, attributes: AttributeMap = Attri
3760
override def attribute[T](k: AttributeKey[T]): Option[T] = attributes.get(k)
3861
override def attribute[T](k: AttributeKey[T], v: T): NettyServerRequest = copy(attributes = attributes.put(k, v))
3962
override def withUnderlying(underlying: Any): ServerRequest =
40-
NettyServerRequest(req = underlying.asInstanceOf[HttpRequest], attributes)
63+
NettyServerRequest(req = underlying.asInstanceOf[HttpRequest], ctx = ctx, attributes = attributes)
4164
}

server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ class NettyServerHandler[F[_]](
143143
requestTimeoutHandler.foreach(h => ctx.pipeline().addFirst(h))
144144
val (runningFuture, cancellationSwitch) = unsafeRunAsync { () =>
145145
try {
146-
route(NettyServerRequest(req))
146+
route(NettyServerRequest(req, ctx))
147147
.map {
148148
case Some(response) => response
149149
case None => ServerResponse.notFound

server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyServerRequestSpec.scala

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package sttp.tapir.server.netty
33
import java.net.{URI => JavaUri}
44

55
import io.netty.buffer.Unpooled
6+
import io.netty.channel.embedded.EmbeddedChannel
67
import io.netty.handler.codec.http._
78
import org.scalatest.freespec.AnyFreeSpec
89
import org.scalatest.matchers.should.Matchers
@@ -28,7 +29,21 @@ class NettyServerRequestSpec extends AnyFreeSpec with Matchers {
2829
trailingHeaders
2930
)
3031

31-
val nettyServerRequest: NettyServerRequest = NettyServerRequest(emptyPostRequest)
32+
// Use EmbeddedChannel for testing - it provides a ChannelHandlerContext
33+
private val embeddedChannel = new EmbeddedChannel()
34+
private var capturedCtx: io.netty.channel.ChannelHandlerContext = null
35+
36+
embeddedChannel
37+
.pipeline()
38+
.addLast(new io.netty.channel.ChannelInboundHandlerAdapter {
39+
override def channelActive(ctx: io.netty.channel.ChannelHandlerContext): Unit = {
40+
capturedCtx = ctx
41+
super.channelActive(ctx)
42+
}
43+
})
44+
.fireChannelActive()
45+
46+
val nettyServerRequest: NettyServerRequest = NettyServerRequest(emptyPostRequest, capturedCtx)
3247

3348
"uri is the same as in request" in {
3449
nettyServerRequest.uri.toString should equal(uri.toString)

0 commit comments

Comments
 (0)