diff --git a/tests/test_ystore.py b/tests/test_ystore.py index 901f39d..bdde180 100644 --- a/tests/test_ystore.py +++ b/tests/test_ystore.py @@ -31,7 +31,7 @@ class MyTempFileYStore(TempFileYStore): class MySQLiteYStore(SQLiteYStore): db_path = MY_SQLITE_YSTORE_DB_PATH - document_ttl = 1000 + document_ttl = 1 def __init__(self, *args, delete_db=False, **kwargs): if delete_db: @@ -61,29 +61,70 @@ async def test_ystore(YStore): assert i == len(data) +async def count_yupdates(db): + """Returns number of yupdates in a SQLite DB given a connection.""" + return (await (await db.execute("SELECT count(*) FROM yupdates")).fetchone())[0] + + @pytest.mark.asyncio async def test_document_ttl_sqlite_ystore(test_ydoc): + """Assert that document history is squashed after the document TTL.""" store_name = "my_store" ystore = MySQLiteYStore(store_name, delete_db=True) - now = time.time() for i in range(3): # assert that adding a record before document TTL doesn't delete document history - with patch("time.time") as mock_time: - mock_time.return_value = now - await ystore.write(test_ydoc.update()) - async with aiosqlite.connect(ystore.db_path) as db: - assert (await (await db.execute("SELECT count(*) FROM yupdates")).fetchone())[ - 0 - ] == i + 1 + await ystore.write(test_ydoc.update()) + async with aiosqlite.connect(ystore.db_path) as db: + assert (await count_yupdates(db)) == i + 1 - # assert that adding a record after document TTL deletes previous document history - with patch("time.time") as mock_time: - mock_time.return_value = now + ystore.document_ttl + 1 + await ystore._squash_task + + async with aiosqlite.connect(ystore.db_path) as db: + assert (await count_yupdates(db)) == 1 + + +@pytest.mark.asyncio +async def test_document_ttl_simultaneous_write_sqlite_ystore(test_ydoc): + """Assert that document history is squashed after the document TTL, and a + write that happens at the same time is also squashed.""" + store_name = "my_store" + ystore = MySQLiteYStore(store_name, delete_db=True) + + for i in range(3): await ystore.write(test_ydoc.update()) async with aiosqlite.connect(ystore.db_path) as db: - # two updates in DB: one squashed update and the new update - assert (await (await db.execute("SELECT count(*) FROM yupdates")).fetchone())[0] == 2 + assert (await count_yupdates(db)) == i + 1 + + await asyncio.sleep(ystore.document_ttl) + await ystore.write(test_ydoc.update()) + await ystore._squash_task + + async with aiosqlite.connect(ystore.db_path) as db: + assert (await count_yupdates(db)) == 1 + + +@pytest.mark.asyncio +async def test_document_ttl_init_sqlite_ystore(test_ydoc): + """Assert that document history is squashed on init if the document TTL has + already elapsed since last update.""" + store_name = "my_store" + ystore = MySQLiteYStore(store_name, delete_db=True) + now = time.time() + + with patch("time.time") as mock_time: + mock_time.return_value = now - ystore.document_ttl - 1 + for i in range(3): + await ystore.write(test_ydoc.update()) + async with aiosqlite.connect(ystore.db_path) as db: + assert (await count_yupdates(db)) == i + 1 + + del ystore + ystore = MySQLiteYStore(store_name) + await ystore.db_initialized + + async with aiosqlite.connect(ystore.db_path) as db: + assert (await count_yupdates(db)) == 1 @pytest.mark.asyncio diff --git a/ypy_websocket/ystore.py b/ypy_websocket/ystore.py index 251d23b..0a9f993 100644 --- a/ypy_websocket/ystore.py +++ b/ypy_websocket/ystore.py @@ -177,6 +177,7 @@ def __init__(self, path: str, metadata_callback: Optional[Callable] = None, log= self.metadata_callback = metadata_callback self.log = log or logging.getLogger(__name__) self.db_initialized = asyncio.create_task(self.init_db()) + self._squash_task: Optional[asyncio.Task] = None async def init_db(self): create_db = False @@ -212,6 +213,17 @@ async def init_db(self): await db.execute(f"PRAGMA user_version = {self.version}") await db.commit() + # squash updates if document TTL already elapsed + async with aiosqlite.connect(self.db_path) as db: + cursor = await db.execute( + "SELECT timestamp FROM yupdates WHERE path = ? ORDER BY timestamp DESC LIMIT 1", + (self.path,), + ) + row = await cursor.fetchone() + diff = (time.time() - row[0]) if row else 0 + if self.document_ttl is not None and diff > self.document_ttl: + await self._squash() + async def read(self) -> AsyncIterator[Tuple[bytes, bytes, float]]: # type: ignore await self.db_initialized try: @@ -231,36 +243,48 @@ async def read(self) -> AsyncIterator[Tuple[bytes, bytes, float]]: # type: igno async def write(self, data: bytes) -> None: await self.db_initialized async with aiosqlite.connect(self.db_path) as db: - # first, determine time elapsed since last update - cursor = await db.execute( - "SELECT timestamp FROM yupdates WHERE path = ? ORDER BY timestamp DESC LIMIT 1", - (self.path,), + # write this update to the DB + metadata = await self.get_metadata() + await db.execute( + "INSERT INTO yupdates VALUES (?, ?, ?, ?)", + (self.path, data, metadata, time.time()), ) - row = await cursor.fetchone() - diff = (time.time() - row[0]) if row else 0 - - if self.document_ttl is not None and diff > self.document_ttl: - # squash updates - ydoc = Y.YDoc() - async with db.execute( - "SELECT yupdate FROM yupdates WHERE path = ?", (self.path,) - ) as cursor: - async for update, in cursor: - Y.apply_update(ydoc, update) - # delete history - await db.execute("DELETE FROM yupdates WHERE path = ?", (self.path,)) - # insert squashed updates - squashed_update = Y.encode_state_as_update(ydoc) - metadata = await self.get_metadata() - await db.execute( - "INSERT INTO yupdates VALUES (?, ?, ?, ?)", - (self.path, squashed_update, metadata, time.time()), - ) + await db.commit() + # create task that squashes document history after document_ttl + self._create_squash_task() - # finally, write this update to the DB + async def _squash(self): + """Squashes document history into a single Y update.""" + async with aiosqlite.connect(self.db_path) as db: + # squash updates + ydoc = Y.YDoc() + async with db.execute( + "SELECT yupdate FROM yupdates WHERE path = ?", (self.path,) + ) as cursor: + async for update, in cursor: + Y.apply_update(ydoc, update) + # delete history + await db.execute("DELETE FROM yupdates WHERE path = ?", (self.path,)) + # insert squashed updates + squashed_update = Y.encode_state_as_update(ydoc) metadata = await self.get_metadata() await db.execute( "INSERT INTO yupdates VALUES (?, ?, ?, ?)", - (self.path, data, metadata, time.time()), + (self.path, squashed_update, metadata, time.time()), ) await db.commit() + + async def _squash_later(self): + await asyncio.sleep(self.document_ttl) + await self._squash() + + def _create_squash_task(self) -> None: + """Creates a task that squashes document history after self.document_ttl + and binds it to the _squash_task attribute. If a task already exists, + this cancels the existing task.""" + if self.document_ttl is None: + return + if self._squash_task is not None: + self._squash_task.cancel() + + self._squash_task = asyncio.create_task(self._squash_later())