diff --git a/docs/usage/client.md b/docs/usage/client.md index 58a0200..19839bf 100644 --- a/docs/usage/client.md +++ b/docs/usage/client.md @@ -11,12 +11,17 @@ async def client(): ydoc = Y.YDoc() async with ( connect("ws://localhost:1234/my-roomname") as websocket, - WebsocketProvider(ydoc, websocket), + WebsocketProvider(ydoc, websocket) as provider, ): + ymap = ydoc.get_map("map") + + # Wait until we've received the initial state from the server. + await provider.synced.wait() + print(ymap.to_json()) + # Changes to remote ydoc are applied to local ydoc. # Changes to local ydoc are sent over the WebSocket and # broadcast to all clients. - ymap = ydoc.get_map("map") with ydoc.begin_transaction() as t: ymap.set(t, "key", "value") diff --git a/ypy_websocket/websocket_provider.py b/ypy_websocket/websocket_provider.py index d3454d8..b9f2bfe 100644 --- a/ypy_websocket/websocket_provider.py +++ b/ypy_websocket/websocket_provider.py @@ -17,6 +17,7 @@ from .websocket import Websocket from .yutils import ( YMessageType, + YSyncMessageType, create_update_message, process_sync_message, put_updates, @@ -31,6 +32,7 @@ class WebsocketProvider: _update_send_stream: MemoryObjectSendStream _update_receive_stream: MemoryObjectReceiveStream _started: Event | None + _synced: Event | None _starting: bool _task_group: TaskGroup | None @@ -63,6 +65,7 @@ def __init__(self, ydoc: Y.YDoc, websocket: Websocket, log: Logger | None = None ) self._started = None self._starting = False + self._synced = None self._task_group = None ydoc.observe_after_transaction(partial(put_updates, self._update_send_stream)) @@ -73,6 +76,13 @@ def started(self) -> Event: self._started = Event() return self._started + @property + def synced(self) -> Event: + """An async event that is set when the WebSocket provider next syncs with the server.""" + if self._synced is None: + self._synced = Event() + return self._synced + async def __aenter__(self) -> WebsocketProvider: if self._task_group is not None: raise RuntimeError("WebsocketProvider already running") @@ -100,6 +110,9 @@ async def _run(self): async for message in self._websocket: if message[0] == YMessageType.SYNC: await process_sync_message(message[1:], self._ydoc, self._websocket, self.log) + if message[1] == YSyncMessageType.SYNC_STEP2 and self._synced is not None: + self._synced.set() + self._synced = None async def _send(self): async with self._update_receive_stream: