Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pinned
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ serialization;https://github.com/status-im/nim-serialization@#548d0adc9797a10b2d
stew;https://github.com/status-im/nim-stew@#b66168735d6f3841c5239c3169d3fe5fe98b1257
testutils;https://github.com/status-im/nim-testutils@#9e842bd58420d23044bc55e16088e8abbe93ce51
unittest2;https://github.com/status-im/nim-unittest2@#8b51e99b4a57fcfb31689230e75595f024543024
websock;https://github.com/status-im/nim-websock@#35ae76f1559e835c80f9c1a3943bf995d3dd9eb5
websock;https://github.com/status-im/nim-websock@#f30d4633a761c6615e679de5fa0c0e63460a9ce3
zlib;https://github.com/status-im/nim-zlib@#daa8723fd32299d4ca621c837430c29a5a11e19a
jwt;https://github.com/vacp2p/nim-jwt@#18f8378de52b241f321c1f9ea905456e89b95c6f
bearssl_pkey_decoder;https://github.com/vacp2p/bearssl_pkey_decoder@#21dd3710df9345ed2ad8bf8f882761e07863b8e0
Expand Down
232 changes: 156 additions & 76 deletions libp2p/transports/wstransport.nim
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import
../utility,
../stream/connection,
../upgrademngrs/upgrade,
../utils/semaphore,
websock/websock

logScope:
Expand All @@ -34,9 +35,10 @@ logScope:
export transport, websock, results

const
DefaultHeadersTimeout = 3.seconds
DefaultHandshakeTimeout = 3.seconds
DefaultAutotlsWaitTimeout = 3.seconds
DefaultAutotlsRetries = 3
DefaultConcurrentAccepts = 200

type
WsStream = ref object of Connection
Expand Down Expand Up @@ -111,11 +113,17 @@ method closeImpl*(s: WsStream): Future[void] {.async: (raises: []).} =
method getWrapped*(s: WsStream): Connection =
nil

type AcceptResult = Result[Connection, ref CatchableError]

type WsTransport* = ref object of Transport
httpservers: seq[HttpServer]
wsserver: WSServer
connections: array[Direction, seq[WsStream]]
acceptFuts: seq[Future[HttpRequest]]
handshakeFuts: seq[Future[void]]
acceptResults: AsyncQueue[AcceptResult]
acceptLoop: Future[void]
acceptSem: AsyncSemaphore
concurrentAccepts: int

tlsPrivateKey*: TLSPrivateKey
tlsCertificate*: TLSCertificate
Expand All @@ -129,6 +137,117 @@ type WsTransport* = ref object of Transport
proc secure*(self: WsTransport): bool =
not (isNil(self.tlsPrivateKey) or isNil(self.tlsCertificate))

proc connHandler(
self: WsTransport, stream: WSSession, secure: bool, dir: Direction
): Future[Connection] {.async: (raises: [CatchableError]).} =
let (observedAddr, localAddr) =
try:
let
codec =
if secure:
MultiAddress.init("/wss")
else:
MultiAddress.init("/ws")
remoteAddr = stream.stream.reader.tsource.remoteAddress
localAddr = stream.stream.reader.tsource.localAddress

(
MultiAddress.init(remoteAddr).tryGet() & codec.tryGet(),
MultiAddress.init(localAddr).tryGet() & codec.tryGet(),
)
except CatchableError as exc:
trace "Failed to create observedAddr or listenAddr", description = exc.msg
if not (isNil(stream) and stream.stream.reader.closed):
safeClose(stream)
raise exc

let conn = WsStream.new(stream, dir, Opt.some(observedAddr), Opt.some(localAddr))

self.connections[dir].add(conn)
proc onClose() {.async: (raises: []).} =
await noCancel conn.session.stream.reader.join()
self.connections[dir].keepItIf(it != conn)
trace "Cleaned up client"

asyncSpawn onClose()
return conn

proc addHandshakeResult(self: WsTransport, ares: AcceptResult) =
try:
self.acceptResults.addLastNoWait(ares)
except AsyncQueueFullError: # never happens but need to catch
discard

proc handshakeWorker(
self: WsTransport, server: HttpServer, clientStream: AsyncStream
) {.async: (raises: []).} =
try:
let conn = await (
proc(): Future[Connection] {.async.} =
let req = await server.processHttpRequest(clientStream)
let wstransp = await self.wsserver.handleRequest(req)
return await self.connHandler(wstransp, server.secure, Direction.In)
)()
.wait(self.handshakeTimeout)
self.addHandshakeResult(AcceptResult.ok(conn))
except CatchableError as exc:
await noCancel clientStream.closeWait()
self.addHandshakeResult(AcceptResult.err(exc))
finally:
self.acceptSem.release()

proc acceptDispatcher(self: WsTransport) {.async: (raises: []).} =
trace "Started acceptDispatcher"

var acceptFuts: seq[Future[AsyncStream]] = @[]
for server in self.httpservers:
acceptFuts.add(server.acceptStream())
if acceptFuts.len == 0:
error "acceptDispatcher has no work; terminating"
return

while self.running:
try:
if self.handshakeFuts.len > 0:
self.handshakeFuts.keepItIf(not it.finished)
await self.acceptSem.acquire()
except CancelledError:
continue
try:
let streamFut = await one(acceptFuts)
let idx = acceptFuts.find(streamFut)
if idx < 0:
self.acceptSem.release()
continue

let httpServer = self.httpservers[idx]
acceptFuts[idx] = httpServer.acceptStream()

if streamFut.failed:
self.acceptSem.release()
self.addHandshakeResult(AcceptResult.err(streamFut.error))
continue

let hFut = self.handshakeWorker(httpServer, streamFut.read())
self.handshakeFuts.add(hFut)
except CatchableError as exc:
self.acceptSem.release()
if not self.running:
break
trace "Error in acceptDispatcher", msg = exc.msg
try:
await sleepAsync(100.milliseconds)
except CancelledError:
discard

trace "Exiting acceptDispatcher"
for fut in acceptFuts:
if not fut.finished:
await fut.cancelAndWait()
self.addHandshakeResult(
AcceptResult.err(newException(TransportClosedError, "Server is closed"))
)

method start*(
self: WsTransport, addrs: seq[MultiAddress]
) {.async: (raises: [LPError, transport.TransportError, CancelledError]).} =
Expand Down Expand Up @@ -175,6 +294,9 @@ method start*(

let address = ma.initTAddress().tryGet()

# allow HTTP headers to take up to 90% of the WS handshake's total time budget
let headerProcessingTimeout = self.handshakeTimeout * 9 div 10

let httpserver =
try:
if isWss:
Expand All @@ -183,10 +305,10 @@ method start*(
tlsPrivateKey = self.tlsPrivateKey,
tlsCertificate = self.tlsCertificate,
flags = self.flags,
handshakeTimeout = self.handshakeTimeout,
headersTimeout = headerProcessingTimeout,
)
else:
HttpServer.create(address, handshakeTimeout = self.handshakeTimeout)
HttpServer.create(address, headersTimeout = headerProcessingTimeout)
except CatchableError as exc:
raise (ref WsTransportError)(
msg: "error in WsTransport start: " & exc.msg, parent: exc
Expand All @@ -209,6 +331,10 @@ method start*(

trace "Listening on", addresses = self.addrs

self.acceptSem = newAsyncSemaphore(self.concurrentAccepts)
self.acceptResults = newAsyncQueue[AcceptResult]()
self.acceptLoop = self.acceptDispatcher()

method stop*(self: WsTransport) {.async: (raises: []).} =
## stop the transport
##
Expand All @@ -224,61 +350,27 @@ method stop*(self: WsTransport) {.async: (raises: []).} =
self.connections[Direction.Out].mapIt(it.close())
)

if not isNil(self.acceptLoop):
await self.acceptLoop.cancelAndWait()

var toWait: seq[Future[void]]
for fut in self.acceptFuts:
if not fut.finished:
toWait.add(fut.cancelAndWait())
elif fut.completed:
toWait.add(fut.read().stream.closeWait())

for server in self.httpservers:
server.stop()
toWait.add(server.closeWait())

for fut in self.handshakeFuts:
if not fut.finished:
fut.cancel()
toWait.add(self.handshakeFuts)

await allFutures(toWait)

self.httpservers = @[]
trace "Transport stopped"
except CatchableError as exc:
trace "Error shutting down ws transport", description = exc.msg

proc connHandler(
self: WsTransport, stream: WSSession, secure: bool, dir: Direction
): Future[Connection] {.async: (raises: [CatchableError]).} =
## Returning CatchableError is fine because we later handle different exceptions.

let (observedAddr, localAddr) =
try:
let
codec =
if secure:
MultiAddress.init("/wss")
else:
MultiAddress.init("/ws")
remoteAddr = stream.stream.reader.tsource.remoteAddress
localAddr = stream.stream.reader.tsource.localAddress

(
MultiAddress.init(remoteAddr).tryGet() & codec.tryGet(),
MultiAddress.init(localAddr).tryGet() & codec.tryGet(),
)
except CatchableError as exc:
trace "Failed to create observedAddr or listenAddr", description = exc.msg
if not (isNil(stream) and stream.stream.reader.closed):
safeClose(stream)
raise exc

let conn = WsStream.new(stream, dir, Opt.some(observedAddr), Opt.some(localAddr))

self.connections[dir].add(conn)
proc onClose() {.async: (raises: []).} =
await noCancel conn.session.stream.reader.join()
self.connections[dir].keepItIf(it != conn)
trace "Cleaned up client"

asyncSpawn onClose()
return conn

method accept*(
self: WsTransport
): Future[Connection] {.async: (raises: [transport.TransportError, CancelledError]).} =
Expand All @@ -294,34 +386,12 @@ method accept*(
if not self.running:
raise newTransportClosedError()

if self.acceptFuts.len <= 0:
self.acceptFuts = self.httpservers.mapIt(it.accept())

if self.acceptFuts.len <= 0:
return

let finished =
try:
await one(self.acceptFuts)
except ValueError:
raiseAssert("already checked with if")
except CancelledError as e:
raise e

let index = self.acceptFuts.find(finished)
self.acceptFuts[index] = self.httpservers[index].accept()
let res = await self.acceptResults.popFirst()
res.isErrOr:
return value

try:
let req = await finished

try:
let wstransp = await self.wsserver.handleRequest(req).wait(self.handshakeTimeout)
let isSecure = self.httpservers[index].secure

return await self.connHandler(wstransp, isSecure, Direction.In)
except CatchableError as exc:
await noCancel req.stream.closeWait()
raise exc
raise res.error
except WebSocketError as exc:
debug "Websocket Error", description = exc.msg
except HttpError as exc:
Expand All @@ -334,13 +404,17 @@ method accept*(
debug "Connection aborted", description = exc.msg
except AsyncTimeoutError as exc:
debug "Timed out", description = exc.msg
except TransportOsError as exc:
debug "OS Error", description = exc.msg
except TransportUseClosedError as exc:
debug "Server was closed", description = exc.msg
raise newTransportClosedError(exc)
except TransportClosedError as exc:
self.addHandshakeResult(res)
debug "Server was closed", description = exc.msg
raise newTransportClosedError(exc)
except CancelledError as exc:
raise exc
except TransportOsError as exc:
debug "OS Error", description = exc.msg
except CatchableError as exc:
info "Unexpected error accepting connection", description = exc.msg
raise newException(
Expand Down Expand Up @@ -392,9 +466,11 @@ proc new*(
flags: set[ServerFlags] = {},
factories: openArray[ExtFactory] = [],
rng: ref HmacDrbgContext = nil,
handshakeTimeout = DefaultHeadersTimeout,
handshakeTimeout = DefaultHandshakeTimeout,
concurrentAccepts = DefaultConcurrentAccepts,
): T {.raises: [].} =
## Creates a secure WebSocket transport
doAssert concurrentAccepts > 0, "must accept connections"

let self = T(
upgrader: upgrade,
Expand All @@ -406,6 +482,7 @@ proc new*(
factories: @factories,
rng: rng,
handshakeTimeout: handshakeTimeout,
concurrentAccepts: concurrentAccepts,
)
procCall Transport(self).initialize()
self
Expand All @@ -416,9 +493,11 @@ proc new*(
flags: set[ServerFlags] = {},
factories: openArray[ExtFactory] = [],
rng: ref HmacDrbgContext = nil,
handshakeTimeout = DefaultHeadersTimeout,
handshakeTimeout = DefaultHandshakeTimeout,
concurrentAccepts = DefaultConcurrentAccepts,
): T {.raises: [].} =
## Creates a clear-text WebSocket transport
doAssert concurrentAccepts > 0, "must accept connections"

T.new(
upgrade = upgrade,
Expand All @@ -429,4 +508,5 @@ proc new*(
factories = @factories,
rng = rng,
handshakeTimeout = handshakeTimeout,
concurrentAccepts = concurrentAccepts,
)
14 changes: 12 additions & 2 deletions tests/libp2p/transports/stream_tests.nim
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,9 @@ template streamTransportTest*(
const chunkSize = 64
const chunkCount = 32
const messageSize = chunkSize * chunkCount
const errorClientId: byte = 0xff
const numConnections = 5
doAssert numConnections < errorClientId
var serverReadOrder: seq[byte] = @[]

# Track when stream handlers complete
Expand Down Expand Up @@ -479,10 +481,18 @@ template streamTransportTest*(
# Doing this improves likelihood of parallel data transition on the connections.
await sleepAsync(rand(20 .. 100).milliseconds)

check receivedData == newData(messageSize, byte(handlerIndex))
let
# Get the client ID from any byte of the data; can't depend on accept/dial order.
clientId =
if receivedData.len > 0:
receivedData[0]
else:
errorClientId

check receivedData == newData(messageSize, clientId)

# Send back ID
await stream.write(@[byte(receivedData[0])])
await stream.write(@[clientId])

# Signal that this stream handler is done
serverStreamHandlerFuts[handlerIndex].complete()
Expand Down
Loading