Skip to content

Commit 74fa0ae

Browse files
committed
fix: apply HIL pending_sends fixes to all implementations
This commit completes PR #57 by applying the same fixes to all four implementations of the pending_sends loading methods: - RedisSaver._load_pending_sends (sync) - ShallowRedisSaver._load_pending_sends (shallow sync) - AsyncShallowRedisSaver._aload_pending_sends (shallow async) The fixes include: 1. Using $.blob instead of blob for JSON path access 2. Using raw checkpoint_ns without to_storage_safe_str conversion 3. Encoding type to bytes with proper error handling for missing blobs Also adds comprehensive test coverage for Human-in-the-Loop functionality across all implementations.
1 parent 73d73d6 commit 74fa0ae

File tree

4 files changed

+319
-10
lines changed

4 files changed

+319
-10
lines changed

langgraph/checkpoint/redis/__init__.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -576,17 +576,16 @@ def _load_pending_sends(
576576
Returns:
577577
List of (type, blob) tuples representing pending sends
578578
"""
579-
storage_safe_thread_id = to_storage_safe_str(thread_id)
580-
storage_safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns)
581-
storage_safe_parent_checkpoint_id = to_storage_safe_str(parent_checkpoint_id)
579+
storage_safe_thread_id = to_storage_safe_id(thread_id)
580+
storage_safe_parent_checkpoint_id = to_storage_safe_id(parent_checkpoint_id)
582581

583582
# Query checkpoint_writes for parent checkpoint's TASKS channel
584583
parent_writes_query = FilterQuery(
585584
filter_expression=(Tag("thread_id") == storage_safe_thread_id)
586-
& (Tag("checkpoint_ns") == storage_safe_checkpoint_ns)
585+
& (Tag("checkpoint_ns") == checkpoint_ns)
587586
& (Tag("checkpoint_id") == storage_safe_parent_checkpoint_id)
588587
& (Tag("channel") == TASKS),
589-
return_fields=["type", "blob", "task_path", "task_id", "idx"],
588+
return_fields=["type", "$.blob", "task_path", "task_id", "idx"],
590589
num_results=100, # Adjust as needed
591590
)
592591
parent_writes_results = self.checkpoint_writes_index.search(parent_writes_query)
@@ -602,7 +601,11 @@ def _load_pending_sends(
602601
)
603602

604603
# Extract type and blob pairs
605-
return [(doc.type, doc.blob) for doc in sorted_writes]
604+
return [
605+
(doc.type.encode(), blob)
606+
for doc in sorted_writes
607+
if (blob := getattr(doc, "$.blob", getattr(doc, "blob", None))) is not None
608+
]
606609

607610
def delete_thread(self, thread_id: str) -> None:
608611
"""Delete all checkpoints and writes associated with a specific thread ID.

langgraph/checkpoint/redis/ashallow.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -676,7 +676,7 @@ async def _aload_pending_sends(
676676
filter_expression=(Tag("thread_id") == thread_id)
677677
& (Tag("checkpoint_ns") == checkpoint_ns)
678678
& (Tag("channel") == TASKS),
679-
return_fields=["type", "blob", "task_path", "task_id", "idx"],
679+
return_fields=["type", "$.blob", "task_path", "task_id", "idx"],
680680
num_results=100,
681681
)
682682
parent_writes_results = await self.checkpoint_writes_index.search(
@@ -694,7 +694,11 @@ async def _aload_pending_sends(
694694
)
695695

696696
# Extract type and blob pairs
697-
return [(doc.type, doc.blob) for doc in sorted_writes]
697+
return [
698+
(doc.type.encode(), blob)
699+
for doc in sorted_writes
700+
if (blob := getattr(doc, "$.blob", getattr(doc, "blob", None))) is not None
701+
]
698702

699703
async def _aload_pending_writes(
700704
self, thread_id: str, checkpoint_ns: str, checkpoint_id: str

langgraph/checkpoint/redis/shallow.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -675,7 +675,7 @@ def _load_pending_sends(
675675
filter_expression=(Tag("thread_id") == thread_id)
676676
& (Tag("checkpoint_ns") == checkpoint_ns)
677677
& (Tag("channel") == TASKS),
678-
return_fields=["type", "blob", "task_path", "task_id", "idx"],
678+
return_fields=["type", "$.blob", "task_path", "task_id", "idx"],
679679
num_results=100,
680680
)
681681
parent_writes_results = self.checkpoint_writes_index.search(parent_writes_query)
@@ -691,7 +691,11 @@ def _load_pending_sends(
691691
)
692692

693693
# Extract type and blob pairs
694-
return [(doc.type, doc.blob) for doc in sorted_writes]
694+
return [
695+
(doc.type.encode(), blob)
696+
for doc in sorted_writes
697+
if (blob := getattr(doc, "$.blob", getattr(doc, "blob", None))) is not None
698+
]
695699

696700
@staticmethod
697701
def _make_shallow_redis_checkpoint_key(thread_id: str, checkpoint_ns: str) -> str:
Lines changed: 298 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,298 @@
1+
"""Test Human-in-the-Loop pending_sends functionality across all implementations."""
2+
3+
import asyncio
4+
import json
5+
from typing import Any, Dict
6+
7+
import pytest
8+
from langchain_core.runnables import RunnableConfig
9+
from langgraph.checkpoint.base import Checkpoint, CheckpointMetadata, create_checkpoint
10+
from langgraph.constants import TASKS
11+
from redisvl.redis.connection import RedisConnectionFactory
12+
13+
from langgraph.checkpoint.redis import AsyncRedisSaver, RedisSaver
14+
from langgraph.checkpoint.redis.ashallow import AsyncShallowRedisSaver
15+
from langgraph.checkpoint.redis.shallow import ShallowRedisSaver
16+
17+
18+
def create_test_checkpoint() -> Checkpoint:
19+
"""Create a test checkpoint for HIL scenarios."""
20+
return {
21+
"v": 1,
22+
"id": "test_checkpoint_1",
23+
"ts": "2024-01-01T00:00:00+00:00",
24+
"channel_values": {},
25+
"channel_versions": {},
26+
"versions_seen": {},
27+
"pending_sends": [],
28+
}
29+
30+
31+
def create_hil_task_writes() -> list[tuple[str, Any]]:
32+
"""Create test writes that simulate HIL task submissions."""
33+
return [
34+
(TASKS, {"task": "review_document", "args": {"doc_id": "123"}}),
35+
(TASKS, {"task": "approve_action", "args": {"action": "deploy"}}),
36+
(TASKS, {"task": "human_feedback", "args": {"prompt": "Continue?"}}),
37+
]
38+
39+
40+
@pytest.mark.asyncio
41+
async def test_async_redis_saver_hil_pending_sends(redis_url: str):
42+
"""Test AsyncRedisSaver._aload_pending_sends for HIL workflows."""
43+
async with AsyncRedisSaver.from_conn_string(redis_url) as saver:
44+
thread_id = "test-hil-thread-async"
45+
checkpoint_ns = "test-namespace"
46+
parent_checkpoint_id = "parent-checkpoint-1"
47+
48+
# Create parent config
49+
parent_config: RunnableConfig = {
50+
"configurable": {
51+
"thread_id": thread_id,
52+
"checkpoint_ns": checkpoint_ns,
53+
"checkpoint_id": parent_checkpoint_id,
54+
}
55+
}
56+
57+
# Create parent checkpoint
58+
parent_checkpoint = create_test_checkpoint()
59+
parent_checkpoint["id"] = parent_checkpoint_id
60+
metadata: CheckpointMetadata = {"source": "input", "step": 1}
61+
62+
# Save parent checkpoint
63+
await saver.aput(parent_config, parent_checkpoint, metadata, {})
64+
65+
# Write HIL tasks
66+
hil_writes = create_hil_task_writes()
67+
await saver.aput_writes(parent_config, hil_writes, task_id="hil-task-1")
68+
69+
# Load pending sends - this is where the bug would occur
70+
pending_sends = await saver._aload_pending_sends(
71+
thread_id=thread_id,
72+
checkpoint_ns=checkpoint_ns,
73+
parent_checkpoint_id=parent_checkpoint_id,
74+
)
75+
76+
# Verify we got the correct pending sends
77+
assert len(pending_sends) == 3
78+
assert all(isinstance(send[0], bytes) for send in pending_sends)
79+
# Blob can be bytes or str depending on how Redis stores it
80+
assert all(isinstance(send[1], (bytes, str)) for send in pending_sends)
81+
82+
# Verify the content
83+
for i, (type_bytes, blob_bytes) in enumerate(pending_sends):
84+
type_str = type_bytes.decode()
85+
# Type could be json or msgpack depending on serde config
86+
assert type_str in ["json", "msgpack"]
87+
88+
# The blob should contain our task data
89+
assert blob_bytes is not None
90+
assert len(blob_bytes) > 0
91+
92+
93+
@pytest.mark.asyncio
94+
async def test_sync_redis_saver_hil_pending_sends(redis_url: str):
95+
"""Test RedisSaver._load_pending_sends for HIL workflows."""
96+
with RedisSaver.from_conn_string(redis_url) as saver:
97+
saver.setup()
98+
thread_id = "test-hil-thread-sync"
99+
checkpoint_ns = "test-namespace"
100+
parent_checkpoint_id = "parent-checkpoint-2"
101+
102+
# Create parent config
103+
parent_config: RunnableConfig = {
104+
"configurable": {
105+
"thread_id": thread_id,
106+
"checkpoint_ns": checkpoint_ns,
107+
"checkpoint_id": parent_checkpoint_id,
108+
}
109+
}
110+
111+
# Create parent checkpoint
112+
parent_checkpoint = create_test_checkpoint()
113+
parent_checkpoint["id"] = parent_checkpoint_id
114+
metadata: CheckpointMetadata = {"source": "input", "step": 1}
115+
116+
# Save parent checkpoint
117+
saver.put(parent_config, parent_checkpoint, metadata, {})
118+
119+
# Write HIL tasks
120+
hil_writes = create_hil_task_writes()
121+
saver.put_writes(parent_config, hil_writes, task_id="hil-task-2")
122+
123+
# Load pending sends - this is where the bug would occur
124+
pending_sends = saver._load_pending_sends(
125+
thread_id=thread_id,
126+
checkpoint_ns=checkpoint_ns,
127+
parent_checkpoint_id=parent_checkpoint_id,
128+
)
129+
130+
# Verify we got the correct pending sends
131+
assert len(pending_sends) == 3
132+
assert all(isinstance(send[0], bytes) for send in pending_sends)
133+
# Blob can be bytes or str depending on how Redis stores it
134+
assert all(isinstance(send[1], (bytes, str)) for send in pending_sends)
135+
136+
137+
@pytest.mark.asyncio
138+
async def test_async_shallow_saver_hil_pending_sends(redis_url: str):
139+
"""Test AsyncShallowRedisSaver._aload_pending_sends for HIL workflows."""
140+
async with AsyncShallowRedisSaver.from_conn_string(redis_url) as saver:
141+
thread_id = "test-hil-thread-async-shallow"
142+
checkpoint_ns = "test-namespace"
143+
144+
# Create config
145+
config: RunnableConfig = {
146+
"configurable": {
147+
"thread_id": thread_id,
148+
"checkpoint_ns": checkpoint_ns,
149+
"checkpoint_id": "checkpoint-1",
150+
}
151+
}
152+
153+
# Create checkpoint
154+
checkpoint = create_test_checkpoint()
155+
metadata: CheckpointMetadata = {"source": "input", "step": 1}
156+
157+
# Save checkpoint
158+
await saver.aput(config, checkpoint, metadata, {})
159+
160+
# Write HIL tasks
161+
hil_writes = create_hil_task_writes()
162+
await saver.aput_writes(config, hil_writes, task_id="hil-task-3")
163+
164+
# Load pending sends
165+
pending_sends = await saver._aload_pending_sends(
166+
thread_id=thread_id,
167+
checkpoint_ns=checkpoint_ns,
168+
)
169+
170+
# Verify we got the correct pending sends
171+
assert len(pending_sends) == 3
172+
assert all(isinstance(send[0], bytes) for send in pending_sends)
173+
# Blob can be bytes or str depending on how Redis stores it
174+
assert all(isinstance(send[1], (bytes, str)) for send in pending_sends)
175+
176+
177+
def test_sync_shallow_saver_hil_pending_sends(redis_url: str):
178+
"""Test ShallowRedisSaver._load_pending_sends for HIL workflows."""
179+
with ShallowRedisSaver.from_conn_string(redis_url) as saver:
180+
saver.setup()
181+
thread_id = "test-hil-thread-sync-shallow"
182+
checkpoint_ns = "test-namespace"
183+
184+
# Create config
185+
config: RunnableConfig = {
186+
"configurable": {
187+
"thread_id": thread_id,
188+
"checkpoint_ns": checkpoint_ns,
189+
"checkpoint_id": "checkpoint-2",
190+
}
191+
}
192+
193+
# Create checkpoint
194+
checkpoint = create_test_checkpoint()
195+
metadata: CheckpointMetadata = {"source": "input", "step": 1}
196+
197+
# Save checkpoint
198+
saver.put(config, checkpoint, metadata, {})
199+
200+
# Write HIL tasks
201+
hil_writes = create_hil_task_writes()
202+
saver.put_writes(config, hil_writes, task_id="hil-task-4")
203+
204+
# Load pending sends
205+
pending_sends = saver._load_pending_sends(
206+
thread_id=thread_id,
207+
checkpoint_ns=checkpoint_ns,
208+
)
209+
210+
# Verify we got the correct pending sends
211+
assert len(pending_sends) == 3
212+
assert all(isinstance(send[0], bytes) for send in pending_sends)
213+
# Blob can be bytes or str depending on how Redis stores it
214+
assert all(isinstance(send[1], (bytes, str)) for send in pending_sends)
215+
216+
217+
@pytest.mark.asyncio
218+
async def test_missing_blob_handling(redis_url: str):
219+
"""Test that implementations handle missing blobs gracefully."""
220+
async with AsyncRedisSaver.from_conn_string(redis_url) as saver:
221+
thread_id = "test-missing-blob"
222+
checkpoint_ns = "test-namespace"
223+
parent_checkpoint_id = "parent-checkpoint-missing"
224+
225+
# Directly insert a write with missing blob field
226+
write_key = f"checkpoint_write:{thread_id}:{checkpoint_ns}:{parent_checkpoint_id}:task-1:0"
227+
write_data = {
228+
"thread_id": thread_id,
229+
"checkpoint_ns": checkpoint_ns,
230+
"checkpoint_id": parent_checkpoint_id,
231+
"task_id": "task-1",
232+
"idx": 0,
233+
"channel": TASKS,
234+
"type": "json",
235+
# No blob field - this should be handled gracefully
236+
}
237+
238+
# Insert directly into Redis
239+
client = RedisConnectionFactory.get_redis_connection(redis_url)
240+
client.json().set(write_key, "$", write_data)
241+
client.close()
242+
243+
# Load pending sends - should handle missing blob
244+
pending_sends = await saver._aload_pending_sends(
245+
thread_id=thread_id,
246+
checkpoint_ns=checkpoint_ns,
247+
parent_checkpoint_id=parent_checkpoint_id,
248+
)
249+
250+
# Should return empty list since blob is missing
251+
assert len(pending_sends) == 0
252+
253+
254+
def test_all_implementations_consistent(redis_url: str):
255+
"""Verify all 4 implementations produce consistent results."""
256+
thread_id = "test-consistency"
257+
checkpoint_ns = "test-namespace"
258+
parent_checkpoint_id = "parent-checkpoint-consist"
259+
260+
# Create the same test data for all implementations
261+
hil_writes = create_hil_task_writes()
262+
263+
results = []
264+
265+
# Test sync implementation
266+
with RedisSaver.from_conn_string(redis_url) as saver:
267+
saver.setup()
268+
config: RunnableConfig = {
269+
"configurable": {
270+
"thread_id": thread_id,
271+
"checkpoint_ns": checkpoint_ns,
272+
"checkpoint_id": parent_checkpoint_id,
273+
}
274+
}
275+
checkpoint = create_test_checkpoint()
276+
checkpoint["id"] = parent_checkpoint_id
277+
metadata: CheckpointMetadata = {"source": "input", "step": 1}
278+
279+
saver.put(config, checkpoint, metadata, {})
280+
saver.put_writes(config, hil_writes, task_id="consist-task")
281+
282+
pending_sends = saver._load_pending_sends(
283+
thread_id=thread_id,
284+
checkpoint_ns=checkpoint_ns,
285+
parent_checkpoint_id=parent_checkpoint_id,
286+
)
287+
results.append(("sync", pending_sends))
288+
289+
# Verify all implementations return the same number of results
290+
# and all results have the expected structure
291+
for name, sends in results:
292+
assert len(sends) == 3, f"{name} returned {len(sends)} sends, expected 3"
293+
for type_bytes, blob_bytes in sends:
294+
assert isinstance(type_bytes, bytes), f"{name}: type not bytes"
295+
# Blob can be bytes or str depending on how Redis stores it
296+
assert isinstance(blob_bytes, (bytes, str)), f"{name}: blob not bytes or str"
297+
assert len(type_bytes) > 0, f"{name}: empty type"
298+
assert len(blob_bytes) > 0, f"{name}: empty blob"

0 commit comments

Comments
 (0)