diff --git a/src/ws.nim b/src/ws.nim index d255ad3..69d7d7d 100644 --- a/src/ws.nim +++ b/src/ws.nim @@ -14,6 +14,7 @@ type key*: string protocol*: string readyState*: ReadyState + maxPacketSize*: int masked*: bool # send masked packets WebSocketError* = object of IOError @@ -23,6 +24,7 @@ type WebSocketProtocolMismatchError* = object of WebSocketError WebSocketFailedUpgradeError* = object of WebSocketError WebSocketHandshakeError* = object of WebSocketError + WebSocketOversizedPacketError* = object of WebSocketError func newWebSocketClosedError(): auto = newException(WebSocketClosedError, "Socket closed") @@ -97,7 +99,8 @@ proc handshake*(ws: WebSocket, headers: HttpHeaders) {.async.} = proc newWebSocket*( req: Request, - protocol: string = "" + protocol: string = "", + maxPacketSize: int = 1024*1024 ): Future[WebSocket] {.async.} = ## Creates a new socket from a request. try: @@ -108,6 +111,7 @@ proc newWebSocket*( ws.masked = false ws.tcpSocket = req.client ws.protocol = protocol + ws.maxPacketSize = maxPacketSize await ws.handshake(req.headers) return ws @@ -120,13 +124,15 @@ proc newWebSocket*( proc newWebSocket*( url: string, - protocols: seq[string] = @[] + protocols: seq[string] = @[], + maxPacketSize: int = 1024*1024 ): Future[WebSocket] {.async.} = ## Creates a new WebSocket connection, ## protocol is optional, "" means no protocol. var ws = WebSocket() ws.masked = true ws.tcpSocket = newAsyncSocket() + ws.maxPacketSize = maxPacketSize var uri = parseUri(url) var port = Port(9001) @@ -390,6 +396,8 @@ proc recvFrame(ws: WebSocket): Future[Frame] {.async.} = raise newWebSocketClosedError() # Read the data. + if int finalLen > ws.maxPacketSize: + raise newException(WebSocketOversizedPacketError, "Socket attempted to receive a packet larger than maxPacketSize, attempted size: " + $finalLen.int) result.data = await ws.tcpSocket.recv(int finalLen) if result.data.len != int finalLen: raise newWebSocketClosedError()