diff --git a/src/pycrdt/websocket/asgi_server.py b/src/pycrdt/websocket/asgi_server.py index be6540f..a69efc5 100644 --- a/src/pycrdt/websocket/asgi_server.py +++ b/src/pycrdt/websocket/asgi_server.py @@ -3,6 +3,7 @@ from collections.abc import Awaitable from inspect import isawaitable from typing import Any, Callable +from urllib.parse import parse_qs from .websocket_server import WebsocketServer @@ -14,11 +15,13 @@ def __init__( send: Callable[[dict[str, Any]], Awaitable[None]], path: str, on_disconnect: Callable[[dict[str, Any]], Awaitable[None] | None] | None = None, + query_params: dict[str, list[str]] | None = None, ): self._receive = receive self._send = send self._path = path self._on_disconnect = on_disconnect + self.query_params = {} if query_params is None else query_params @property def path(self) -> str: @@ -98,5 +101,8 @@ async def __call__( return await send({"type": "websocket.accept"}) - websocket = ASGIWebsocket(receive, send, scope["path"], self._on_disconnect) + query_params = parse_qs(scope["query_string"]) + websocket = ASGIWebsocket( + receive, send, scope["path"], self._on_disconnect, query_params=query_params + ) await self._websocket_server.serve(websocket)