From db3dc39ce1aa1cb15b57f80871ad91cbb8c34eac Mon Sep 17 00:00:00 2001
From: Brian Sam-Bodden
Date: Sun, 1 Jun 2025 18:12:40 +0200
Subject: [PATCH] feat(checkpoint-redis): implement adelete_thread and
 delete_thread methods (#51)

- Add adelete_thread method to AsyncRedisSaver to delete all checkpoints,
  blobs, and writes for a thread
- Add delete_thread method to RedisSaver for sync operations
- Use Redis search indices instead of the keys() command for better performance
- Batch deletions using a Redis pipeline for efficiency
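
Example usage (an illustrative sketch; the connection URL and thread id below
are placeholders, not values required by the library):

    import asyncio

    from langgraph.checkpoint.redis import RedisSaver
    from langgraph.checkpoint.redis.aio import AsyncRedisSaver

    # Sync: drop every checkpoint, blob, and pending write for a thread
    # ("redis://localhost:6379" and "example-thread" are placeholders)
    with RedisSaver.from_conn_string("redis://localhost:6379") as checkpointer:
        checkpointer.setup()
        checkpointer.delete_thread("example-thread")

    # Async equivalent
    async def purge_thread() -> None:
        async with AsyncRedisSaver.from_conn_string("redis://localhost:6379") as saver:
            await saver.adelete_thread("example-thread")

    asyncio.run(purge_thread())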
---
 langgraph/checkpoint/redis/__init__.py |  72 ++++++++
 langgraph/checkpoint/redis/aio.py      |  72 ++++++++
 tests/test_issue_51_adelete_thread.py  | 240 +++++++++++++++++++++++++
 3 files changed, 384 insertions(+)
 create mode 100644 tests/test_issue_51_adelete_thread.py

diff --git a/langgraph/checkpoint/redis/__init__.py b/langgraph/checkpoint/redis/__init__.py
index 2c0f123..5851172 100644
--- a/langgraph/checkpoint/redis/__init__.py
+++ b/langgraph/checkpoint/redis/__init__.py
@@ -575,6 +575,78 @@ def _load_pending_sends(
         # Extract type and blob pairs
         return [(doc.type, doc.blob) for doc in sorted_writes]
+
+    def delete_thread(self, thread_id: str) -> None:
+        """Delete all checkpoints and writes associated with a specific thread ID.
+
+        Args:
+            thread_id: The thread ID whose checkpoints should be deleted.
+        """
+        storage_safe_thread_id = to_storage_safe_id(thread_id)
+
+        # Delete all checkpoints for this thread
+        checkpoint_query = FilterQuery(
+            filter_expression=Tag("thread_id") == storage_safe_thread_id,
+            return_fields=["checkpoint_ns", "checkpoint_id"],
+            num_results=10000,  # Get all checkpoints for this thread
+        )
+
+        checkpoint_results = self.checkpoints_index.search(checkpoint_query)
+
+        # Delete all checkpoint-related keys
+        pipeline = self._redis.pipeline()
+
+        for doc in checkpoint_results.docs:
+            checkpoint_ns = getattr(doc, "checkpoint_ns", "")
+            checkpoint_id = getattr(doc, "checkpoint_id", "")
+
+            # Delete checkpoint key
+            checkpoint_key = BaseRedisSaver._make_redis_checkpoint_key(
+                storage_safe_thread_id, checkpoint_ns, checkpoint_id
+            )
+            pipeline.delete(checkpoint_key)
+
+        # Delete all blobs for this thread
+        blob_query = FilterQuery(
+            filter_expression=Tag("thread_id") == storage_safe_thread_id,
+            return_fields=["checkpoint_ns", "channel", "version"],
+            num_results=10000,
+        )
+
+        blob_results = self.checkpoint_blobs_index.search(blob_query)
+
+        for doc in blob_results.docs:
+            checkpoint_ns = getattr(doc, "checkpoint_ns", "")
+            channel = getattr(doc, "channel", "")
+            version = getattr(doc, "version", "")
+
+            blob_key = BaseRedisSaver._make_redis_checkpoint_blob_key(
+                storage_safe_thread_id, checkpoint_ns, channel, version
+            )
+            pipeline.delete(blob_key)
+
+        # Delete all writes for this thread
+        writes_query = FilterQuery(
+            filter_expression=Tag("thread_id") == storage_safe_thread_id,
+            return_fields=["checkpoint_ns", "checkpoint_id", "task_id", "idx"],
+            num_results=10000,
+        )
+
+        writes_results = self.checkpoint_writes_index.search(writes_query)
+
+        for doc in writes_results.docs:
+            checkpoint_ns = getattr(doc, "checkpoint_ns", "")
+            checkpoint_id = getattr(doc, "checkpoint_id", "")
+            task_id = getattr(doc, "task_id", "")
+            idx = getattr(doc, "idx", 0)
+
+            write_key = BaseRedisSaver._make_redis_checkpoint_writes_key(
+                storage_safe_thread_id, checkpoint_ns, checkpoint_id, task_id, idx
+            )
+            pipeline.delete(write_key)
+
+        # Execute all deletions
+        pipeline.execute()


 __all__ = [
     "__version__",
diff --git a/langgraph/checkpoint/redis/aio.py b/langgraph/checkpoint/redis/aio.py
index e157ee1..46a4d25 100644
--- a/langgraph/checkpoint/redis/aio.py
+++ b/langgraph/checkpoint/redis/aio.py
@@ -925,3 +925,75 @@ async def _aload_pending_writes(
         pending_writes = BaseRedisSaver._load_writes(self.serde, writes_dict)
         return pending_writes
+
+    async def adelete_thread(self, thread_id: str) -> None:
+        """Delete all checkpoints and writes associated with a specific thread ID.
+
+        Args:
+            thread_id: The thread ID whose checkpoints should be deleted.
+        """
+        storage_safe_thread_id = to_storage_safe_id(thread_id)
+
+        # Delete all checkpoints for this thread
+        checkpoint_query = FilterQuery(
+            filter_expression=Tag("thread_id") == storage_safe_thread_id,
+            return_fields=["checkpoint_ns", "checkpoint_id"],
+            num_results=10000,  # Get all checkpoints for this thread
+        )
+
+        checkpoint_results = await self.checkpoints_index.search(checkpoint_query)
+
+        # Delete all checkpoint-related keys
+        pipeline = self._redis.pipeline()
+
+        for doc in checkpoint_results.docs:
+            checkpoint_ns = getattr(doc, "checkpoint_ns", "")
+            checkpoint_id = getattr(doc, "checkpoint_id", "")
+
+            # Delete checkpoint key
+            checkpoint_key = BaseRedisSaver._make_redis_checkpoint_key(
+                storage_safe_thread_id, checkpoint_ns, checkpoint_id
+            )
+            pipeline.delete(checkpoint_key)
+
+        # Delete all blobs for this thread
+        blob_query = FilterQuery(
+            filter_expression=Tag("thread_id") == storage_safe_thread_id,
+            return_fields=["checkpoint_ns", "channel", "version"],
+            num_results=10000,
+        )
+
+        blob_results = await self.checkpoint_blobs_index.search(blob_query)
+
+        for doc in blob_results.docs:
+            checkpoint_ns = getattr(doc, "checkpoint_ns", "")
+            channel = getattr(doc, "channel", "")
+            version = getattr(doc, "version", "")
+
+            blob_key = BaseRedisSaver._make_redis_checkpoint_blob_key(
+                storage_safe_thread_id, checkpoint_ns, channel, version
+            )
+            pipeline.delete(blob_key)
+
+        # Delete all writes for this thread
+        writes_query = FilterQuery(
+            filter_expression=Tag("thread_id") == storage_safe_thread_id,
+            return_fields=["checkpoint_ns", "checkpoint_id", "task_id", "idx"],
+            num_results=10000,
+        )
+
+        writes_results = await self.checkpoint_writes_index.search(writes_query)
+
+        for doc in writes_results.docs:
+            checkpoint_ns = getattr(doc, "checkpoint_ns", "")
+            checkpoint_id = getattr(doc, "checkpoint_id", "")
+            task_id = getattr(doc, "task_id", "")
+            idx = getattr(doc, "idx", 0)
+
+            write_key = BaseRedisSaver._make_redis_checkpoint_writes_key(
+                storage_safe_thread_id, checkpoint_ns, checkpoint_id, task_id, idx
+            )
+            pipeline.delete(write_key)
+
+        # Execute all deletions
+        await pipeline.execute()
diff --git a/tests/test_issue_51_adelete_thread.py b/tests/test_issue_51_adelete_thread.py
new file mode 100644
index 0000000..f367630
--- /dev/null
+++ b/tests/test_issue_51_adelete_thread.py
@@ -0,0 +1,240 @@
+"""Test for issue #51 - adelete_thread implementation."""
+
+import pytest
+from langchain_core.runnables import RunnableConfig
+from langgraph.checkpoint.base import Checkpoint, CheckpointMetadata
+
+from langgraph.checkpoint.redis import RedisSaver
+from langgraph.checkpoint.redis.aio import AsyncRedisSaver
+
+
+@pytest.mark.asyncio
+async def test_adelete_thread_implemented(redis_url):
+    """Test that adelete_thread method is now implemented in AsyncRedisSaver."""
+    async with AsyncRedisSaver.from_conn_string(redis_url) as checkpointer:
+        # Create a checkpoint
+        thread_id = "test-thread-to-delete"
+        config: RunnableConfig = {
+            "configurable": {
+                "thread_id": thread_id,
"checkpoint_ns": "", + "checkpoint_id": "1", + } + } + + checkpoint = Checkpoint( + v=1, + id="1", + ts="2024-01-01T00:00:00Z", + channel_values={"messages": ["Test"]}, + channel_versions={"messages": "1"}, + versions_seen={"agent": {"messages": "1"}}, + pending_sends=[], + tasks=[], + ) + + # Store checkpoint + await checkpointer.aput( + config=config, + checkpoint=checkpoint, + metadata=CheckpointMetadata(source="input", step=0, writes={}), + new_versions={"messages": "1"}, + ) + + # Verify checkpoint exists + result = await checkpointer.aget_tuple(config) + assert result is not None + assert result.checkpoint["id"] == "1" + + # Delete the thread + await checkpointer.adelete_thread(thread_id) + + # Verify checkpoint is deleted + result = await checkpointer.aget_tuple(config) + assert result is None + + +def test_delete_thread_implemented(redis_url): + """Test that delete_thread method is now implemented in RedisSaver.""" + with RedisSaver.from_conn_string(redis_url) as checkpointer: + checkpointer.setup() # Initialize Redis indices + + # Create a checkpoint + thread_id = "test-thread-to-delete-sync" + config: RunnableConfig = { + "configurable": { + "thread_id": thread_id, + "checkpoint_ns": "", + "checkpoint_id": "1", + } + } + + checkpoint = Checkpoint( + v=1, + id="1", + ts="2024-01-01T00:00:00Z", + channel_values={"messages": ["Test"]}, + channel_versions={"messages": "1"}, + versions_seen={"agent": {"messages": "1"}}, + pending_sends=[], + tasks=[], + ) + + # Store checkpoint + checkpointer.put( + config=config, + checkpoint=checkpoint, + metadata=CheckpointMetadata(source="input", step=0, writes={}), + new_versions={"messages": "1"}, + ) + + # Verify checkpoint exists + result = checkpointer.get_tuple(config) + assert result is not None + assert result.checkpoint["id"] == "1" + + # Delete the thread + checkpointer.delete_thread(thread_id) + + # Verify checkpoint is deleted + result = checkpointer.get_tuple(config) + assert result is None + + +@pytest.mark.asyncio +async def test_adelete_thread_comprehensive(redis_url): + """Comprehensive test for adelete_thread with multiple checkpoints and namespaces.""" + async with AsyncRedisSaver.from_conn_string(redis_url) as checkpointer: + thread_id = "test-thread-comprehensive" + other_thread_id = "other-thread" + + # Create multiple checkpoints for the thread + checkpoints_data = [ + ("", "1", {"messages": ["First"]}, "input", 0), + ("", "2", {"messages": ["Second"]}, "output", 1), + ("ns1", "3", {"messages": ["Third"]}, "input", 0), + ("ns2", "4", {"messages": ["Fourth"]}, "output", 1), + ] + + # Also create checkpoints for another thread that should not be deleted + other_checkpoints_data = [ + ("", "5", {"messages": ["Other1"]}, "input", 0), + ("ns1", "6", {"messages": ["Other2"]}, "output", 1), + ] + + # Store all checkpoints + for ns, cp_id, channel_values, source, step in checkpoints_data: + config: RunnableConfig = { + "configurable": { + "thread_id": thread_id, + "checkpoint_ns": ns, + "checkpoint_id": cp_id, + } + } + + checkpoint = Checkpoint( + v=1, + id=cp_id, + ts=f"2024-01-01T00:00:0{cp_id}Z", + channel_values=channel_values, + channel_versions={"messages": "1"}, + versions_seen={"agent": {"messages": "1"}}, + pending_sends=[], + tasks=[], + ) + + await checkpointer.aput( + config=config, + checkpoint=checkpoint, + metadata=CheckpointMetadata(source=source, step=step, writes={}), + new_versions={"messages": "1"}, + ) + + # Also add some writes + await checkpointer.aput_writes( + config=config, + writes=[("messages", f"Write 
for {cp_id}")], + task_id=f"task-{cp_id}", + ) + + # Store checkpoints for other thread + for ns, cp_id, channel_values, source, step in other_checkpoints_data: + config: RunnableConfig = { + "configurable": { + "thread_id": other_thread_id, + "checkpoint_ns": ns, + "checkpoint_id": cp_id, + } + } + + checkpoint = Checkpoint( + v=1, + id=cp_id, + ts=f"2024-01-01T00:00:0{cp_id}Z", + channel_values=channel_values, + channel_versions={"messages": "1"}, + versions_seen={"agent": {"messages": "1"}}, + pending_sends=[], + tasks=[], + ) + + await checkpointer.aput( + config=config, + checkpoint=checkpoint, + metadata=CheckpointMetadata(source=source, step=step, writes={}), + new_versions={"messages": "1"}, + ) + + # Verify all checkpoints exist + for ns, cp_id, _, _, _ in checkpoints_data: + config: RunnableConfig = { + "configurable": { + "thread_id": thread_id, + "checkpoint_ns": ns, + "checkpoint_id": cp_id, + } + } + result = await checkpointer.aget_tuple(config) + assert result is not None + assert result.checkpoint["id"] == cp_id + + # Verify other thread checkpoints exist + for ns, cp_id, _, _, _ in other_checkpoints_data: + config: RunnableConfig = { + "configurable": { + "thread_id": other_thread_id, + "checkpoint_ns": ns, + "checkpoint_id": cp_id, + } + } + result = await checkpointer.aget_tuple(config) + assert result is not None + assert result.checkpoint["id"] == cp_id + + # Delete the thread + await checkpointer.adelete_thread(thread_id) + + # Verify all checkpoints for the thread are deleted + for ns, cp_id, _, _, _ in checkpoints_data: + config: RunnableConfig = { + "configurable": { + "thread_id": thread_id, + "checkpoint_ns": ns, + "checkpoint_id": cp_id, + } + } + result = await checkpointer.aget_tuple(config) + assert result is None + + # Verify other thread checkpoints still exist + for ns, cp_id, _, _, _ in other_checkpoints_data: + config: RunnableConfig = { + "configurable": { + "thread_id": other_thread_id, + "checkpoint_ns": ns, + "checkpoint_id": cp_id, + } + } + result = await checkpointer.aget_tuple(config) + assert result is not None + assert result.checkpoint["id"] == cp_id