Skip to content

Commit 19f2f22

Browse files
committed
Move sentinel replacement to serializer
1 parent f614c5f commit 19f2f22

File tree

2 files changed

+89
-27
lines changed

2 files changed

+89
-27
lines changed

langgraph/checkpoint/redis/__init__.py

Lines changed: 41 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -78,13 +78,19 @@ def list(
7878
# Construct the filter expression
7979
filter_expression = []
8080
if config:
81-
filter_expression.append(
82-
Tag("thread_id") == config["configurable"]["thread_id"]
83-
)
84-
if checkpoint_ns := config["configurable"].get("checkpoint_ns"):
81+
thread_id = config["configurable"]["thread_id"]
82+
checkpoint_ns = config["configurable"].get("checkpoint_ns", "")
83+
checkpoint_id = get_checkpoint_id(config)
84+
filter_expression.append(Tag("thread_id") == thread_id)
85+
86+
# Following the Postgres implementation, we only want to filter by
87+
# checkpoint_ns if it's set. This is slightly different from the
88+
# get_tuple() logic, where we always query for checkpoint_ns.
89+
if checkpoint_ns:
8590
filter_expression.append(Tag("checkpoint_ns") == checkpoint_ns)
86-
if checkpoint_id := get_checkpoint_id(config):
87-
filter_expression.append(Tag("checkpoint_id") == checkpoint_id)
91+
# We want to find all checkpoints for the thread matching the other
92+
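# filters, with any checkpoint_id. NOTE(review): the checkpoint_id filter below is now appended unconditionally, so this relies on the Tag comparison matching everything when checkpoint_id is None — confirm.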
93+
filter_expression.append(Tag("checkpoint_id") == checkpoint_id)
8894

8995
if filter:
9096
for k, v in filter.items():
@@ -195,9 +201,9 @@ def put(
195201
"""Store a checkpoint to Redis."""
196202
configurable = config["configurable"].copy()
197203
thread_id = configurable.pop("thread_id")
198-
checkpoint_ns = configurable.pop("checkpoint_ns")
204+
checkpoint_ns = configurable.pop("checkpoint_ns", "")
199205
checkpoint_id = configurable.pop(
200-
"checkpoint_id", configurable.pop("thread_ts", None)
206+
"checkpoint_id", configurable.pop("thread_ts", "")
201207
)
202208

203209
copy = checkpoint.copy()
@@ -211,10 +217,10 @@ def put(
211217

212218
# Store checkpoint data
213219
checkpoint_data = {
214-
"thread_id": thread_id,
215-
"checkpoint_ns": checkpoint_ns,
216-
"checkpoint_id": checkpoint["id"],
217-
"parent_checkpoint_id": checkpoint_id,
220+
"thread_id": thread_id or "",
221+
"checkpoint_ns": checkpoint_ns or "",
222+
"checkpoint_id": checkpoint["id"] or "",
223+
"parent_checkpoint_id": checkpoint_id or "",
218224
"checkpoint": self._dump_checkpoint(copy),
219225
"metadata": self._dump_metadata(metadata),
220226
}
@@ -261,11 +267,16 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
261267
checkpoint_id = get_checkpoint_id(config)
262268
checkpoint_ns = config["configurable"].get("checkpoint_ns", "")
263269

264-
checkpoint_filter_expression = Tag("thread_id") == thread_id
265270
if checkpoint_id:
266-
checkpoint_filter_expression &= Tag("checkpoint_id") == str(checkpoint_id)
267-
if checkpoint_ns:
268-
checkpoint_filter_expression &= Tag("checkpoint_ns") == checkpoint_ns
271+
checkpoint_filter_expression = (
272+
(Tag("thread_id") == thread_id)
273+
& (Tag("checkpoint_ns") == checkpoint_ns)
274+
& (Tag("checkpoint_id") == str(checkpoint_id))
275+
)
276+
else:
277+
checkpoint_filter_expression = (Tag("thread_id") == thread_id) & (
278+
Tag("checkpoint_ns") == checkpoint_ns
279+
)
269280

270281
# Construct the query
271282
checkpoints_query = FilterQuery(
@@ -289,20 +300,25 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
289300

290301
doc = results.docs[0]
291302

303+
doc_thread_id = doc["thread_id"]
304+
doc_checkpoint_ns = doc["checkpoint_ns"]
305+
doc_checkpoint_id = doc["checkpoint_id"]
306+
doc_parent_checkpoint_id = doc["parent_checkpoint_id"]
307+
292308
# Fetch channel_values
293309
channel_values = self.get_channel_values(
294-
thread_id=doc["thread_id"],
295-
checkpoint_ns=doc["checkpoint_ns"],
296-
checkpoint_id=doc["checkpoint_id"],
310+
thread_id=doc_thread_id,
311+
checkpoint_ns=doc_checkpoint_ns,
312+
checkpoint_id=doc_checkpoint_id,
297313
)
298314

299315
# Fetch pending_sends from parent checkpoint
300316
pending_sends = []
301-
if doc["parent_checkpoint_id"]:
317+
if doc_parent_checkpoint_id:
302318
pending_sends = self._load_pending_sends(
303-
thread_id=thread_id,
304-
checkpoint_ns=checkpoint_ns,
305-
parent_checkpoint_id=doc["parent_checkpoint_id"],
319+
thread_id=doc_thread_id,
320+
checkpoint_ns=doc_checkpoint_ns,
321+
parent_checkpoint_id=doc_parent_checkpoint_id,
306322
)
307323

308324
# Fetch and parse metadata
@@ -324,7 +340,7 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
324340
"configurable": {
325341
"thread_id": thread_id,
326342
"checkpoint_ns": checkpoint_ns,
327-
"checkpoint_id": doc["checkpoint_id"],
343+
"checkpoint_id": doc_checkpoint_id,
328344
}
329345
}
330346

@@ -335,7 +351,7 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
335351
)
336352

337353
pending_writes = self._load_pending_writes(
338-
thread_id, checkpoint_ns, checkpoint_id
354+
thread_id, checkpoint_ns, doc_checkpoint_id
339355
)
340356

341357
return CheckpointTuple(

langgraph/checkpoint/redis/jsonplus_redis.py

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,46 @@
11
import base64
2+
import logging
23
from typing import Any, Union
34

45
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
56

67

8+
logger = logging.getLogger(__name__)
9+
10+
11+
# RediSearch versions below 2.10 don't support indexing and querying
12+
# empty strings, so we use a sentinel value to represent empty strings.
13+
EMPTY_STRING_SENTINEL = "__empty__"
14+
15+
716
class JsonPlusRedisSerializer(JsonPlusSerializer):
817
"""Redis-optimized serializer that stores strings directly."""
918

19+
SENTINEL_FIELDS = [
20+
"thread_id",
21+
"checkpoint_id",
22+
"checkpoint_ns",
23+
"parent_checkpoint_id",
24+
]
25+
1026
def dumps_typed(self, obj: Any) -> tuple[str, str]: # type: ignore[override]
1127
if isinstance(obj, (bytes, bytearray)):
1228
return "base64", base64.b64encode(obj).decode("utf-8")
1329
else:
14-
return "json", self.dumps(obj).decode("utf-8")
30+
for field in self.SENTINEL_FIELDS:
31+
try:
32+
if field in obj and not obj[field]:
33+
obj[field] = EMPTY_STRING_SENTINEL
34+
except (KeyError, AttributeError):
35+
try:
36+
if hasattr(obj, field) and not getattr(obj, field, None):
37+
setattr(obj, field, EMPTY_STRING_SENTINEL)
38+
except Exception as e:
39+
logger.debug(
40+
f"Error setting {field} from empty string to sentinel: {e}"
41+
)
42+
results = self.dumps(obj).decode("utf-8")
43+
return "json", results
1544

1645
def loads_typed(self, data: tuple[str, Union[str, bytes]]) -> Any:
1746
type_, data_ = data
@@ -22,4 +51,21 @@ def loads_typed(self, data: tuple[str, Union[str, bytes]]) -> Any:
2251
return decoded
2352
elif type_ == "json":
2453
data_bytes = data_ if isinstance(data_, bytes) else data_.encode()
25-
return self.loads(data_bytes)
54+
results = self.loads(data_bytes)
55+
for field in self.SENTINEL_FIELDS:
56+
try:
57+
if field in results and results[field] == EMPTY_STRING_SENTINEL:
58+
results[field] = ""
59+
except (KeyError, AttributeError):
60+
try:
61+
if (
62+
hasattr(results, field)
63+
and getattr(results, field) == EMPTY_STRING_SENTINEL
64+
):
65+
setattr(results, field, "")
66+
except Exception as e:
67+
logger.debug(
68+
f"Error setting {field} from sentinel to empty string: {e}"
69+
)
70+
pass
71+
return results

0 commit comments

Comments
 (0)