Skip to content

Commit e9f9e17

Browse files
committed
Merge remote-tracking branch 'origin/main' into feat/RAAE-1287/checkpoint-custom-prefix
2 parents 827c6d9 + 5ce5acd commit e9f9e17

File tree

8 files changed

+897
-21
lines changed

8 files changed

+897
-21
lines changed

langgraph/store/redis/__init__.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,9 @@ def _batch_get_ops(
257257
for idx, key in items:
258258
if key in key_to_row:
259259
data, doc_id = key_to_row[key]
260-
results[idx] = _row_to_item(namespace, data)
260+
results[idx] = _row_to_item(
261+
namespace, data, deserialize_fn=self._deserialize_value
262+
)
261263

262264
# Find the corresponding operation by looking it up in the operation list
263265
# This is needed because idx is the index in the overall operation list
@@ -587,6 +589,7 @@ def _batch_search_ops(
587589
_decode_ns(store_doc["prefix"]),
588590
store_doc,
589591
score=score,
592+
deserialize_fn=self._deserialize_value,
590593
)
591594
)
592595

@@ -653,7 +656,13 @@ def _batch_search_ops(
653656
)
654657
refresh_keys.append(vector_key)
655658

656-
items.append(_row_to_search_item(_decode_ns(data["prefix"]), data))
659+
items.append(
660+
_row_to_search_item(
661+
_decode_ns(data["prefix"]),
662+
data,
663+
deserialize_fn=self._deserialize_value,
664+
)
665+
)
657666

658667
# Refresh TTL if requested
659668
if op.refresh_ttl and refresh_keys and self.ttl_config:

langgraph/store/redis/aio.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -472,7 +472,9 @@ async def _batch_get_ops(
472472
for idx, key in items:
473473
if key in key_to_row:
474474
data, doc_id = key_to_row[key]
475-
results[idx] = _row_to_item(namespace, data)
475+
results[idx] = _row_to_item(
476+
namespace, data, deserialize_fn=self._deserialize_value
477+
)
476478

477479
# Find the corresponding operation by looking it up in the operation list
478480
# This is needed because idx is the index in the overall operation list
@@ -578,7 +580,7 @@ async def _aprepare_batch_PUT_queries(
578580
doc = RedisDocument(
579581
prefix=_namespace_to_text(op.namespace),
580582
key=op.key,
581-
value=op.value,
583+
value=self._serialize_value(op.value),
582584
created_at=now,
583585
updated_at=now,
584586
ttl_minutes=ttl_minutes,
@@ -872,6 +874,7 @@ async def _batch_search_ops(
872874
_decode_ns(store_doc["prefix"]),
873875
store_doc,
874876
score=score,
877+
deserialize_fn=self._deserialize_value,
875878
)
876879
)
877880

@@ -939,7 +942,13 @@ async def _batch_search_ops(
939942
)
940943
refresh_keys.append(vector_key)
941944

942-
items.append(_row_to_search_item(_decode_ns(data["prefix"]), data))
945+
items.append(
946+
_row_to_search_item(
947+
_decode_ns(data["prefix"]),
948+
data,
949+
deserialize_fn=self._deserialize_value,
950+
)
951+
)
943952

944953
# Refresh TTL if requested
945954
if op.refresh_ttl and refresh_keys and self.ttl_config:

langgraph/store/redis/base.py

Lines changed: 150 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@
33
from __future__ import annotations
44

55
import copy
6+
import json
67
import logging
78
import threading
89
from collections import defaultdict
910
from datetime import datetime, timedelta, timezone
1011
from typing import (
1112
Any,
13+
Callable,
1214
Dict,
1315
Generic,
1416
Iterable,
@@ -40,6 +42,8 @@
4042
from redisvl.query.filter import Tag, Text
4143
from redisvl.utils.token_escaper import TokenEscaper
4244

45+
from langgraph.checkpoint.redis.jsonplus_redis import JsonPlusRedisSerializer
46+
4347
from .token_unescaper import TokenUnescaper
4448
from .types import IndexType, RedisClientType
4549

@@ -124,6 +128,9 @@ class BaseRedisStore(Generic[RedisClientType, IndexType]):
124128
supports_ttl: bool = True
125129
ttl_config: Optional[TTLConfig] = None
126130

131+
# Serializer for handling complex objects like LangChain messages
132+
_serde: JsonPlusRedisSerializer
133+
127134
def _apply_ttl_to_keys(
128135
self,
129136
main_key: str,
@@ -223,6 +230,8 @@ def __init__(
223230
self._redis = conn
224231
# Store cluster_mode; None means auto-detect in RedisStore or AsyncRedisStore
225232
self.cluster_mode = cluster_mode
233+
# Initialize the serializer for handling complex objects like LangChain messages
234+
self._serde = JsonPlusRedisSerializer()
226235

227236
# Store custom prefixes
228237
self.store_prefix = store_prefix
@@ -357,6 +366,109 @@ async def aset_client_info(self) -> None:
357366
# Silently fail if even echo doesn't work
358367
pass
359368

369+
def _serialize_value(self, value: Any) -> Any:
    """Serialize a value for storage in Redis.

    Plain JSON-serializable values are returned unchanged, which keeps
    previously stored data readable and preserves filtering on nested
    fields. Values that ``json.dumps`` rejects (for example values that
    contain LangChain messages such as HumanMessage or AIMessage) are
    passed through ``self._serde.dumps_typed`` and stored as a two-key
    wrapper::

        {"__serde_type__": <type tag>, "__serde_data__": <encoded payload>}

    The payload is UTF-8 text when the type tag is ``"json"`` and a hex
    string for every other tag.

    A plain dict whose keys are exactly the two wrapper keys is wrapped as
    well. Stored as-is, it would be indistinguishable from a real wrapper
    on read and would come back as a different value.

    Note: wrapped values hide their nested fields, so filters on nested
    fields won't work for them.

    Args:
        value: The value to serialize (can contain HumanMessage,
            AIMessage, etc.)

    Returns:
        ``None`` for ``None``, the original object when it is plain JSON,
        otherwise the serde wrapper dict described above.
    """
    if value is None:
        return None

    # User data that happens to look exactly like the wrapper must not be
    # stored verbatim, or the read path would misinterpret it.
    collides_with_wrapper = isinstance(value, dict) and set(value.keys()) == {
        "__serde_type__",
        "__serde_data__",
    }

    if not collides_with_wrapper:
        # Probe with the standard encoder: if it succeeds the value needs
        # no special handling and is stored as-is for backward
        # compatibility and to preserve filter functionality.
        try:
            json.dumps(value)
        except TypeError:
            # Value contains non-JSON-serializable objects; fall through
            # to the serde wrapper below.
            pass
        else:
            return value

    # Use the serializer to handle complex (or colliding) values.
    type_str, data_bytes = self._serde.dumps_typed(value)

    if type_str == "json":
        # JSON payloads are text already; keep them human-readable.
        data_encoded = data_bytes.decode("utf-8")
    else:
        # bytes, bytearray, msgpack, and other types are hex-encoded so
        # the payload remains a valid JSON string.
        data_encoded = data_bytes.hex()

    return {
        "__serde_type__": type_str,
        "__serde_data__": data_encoded,
    }
416+
417+
def _deserialize_value(self, value: Any) -> Any:
    """Deserialize a value from Redis storage.

    Handles both the serde wrapper format and legacy plain values. A value
    is treated as a wrapper only when it is a dict whose keys are exactly
    ``__serde_type__`` and ``__serde_data__`` and both of those values are
    strings. Everything else is returned unchanged, including dicts that
    merely reuse the wrapper key names with non-string values.

    Args:
        value: The value from Redis (may be serialized or plain)

    Returns:
        The object rebuilt by ``self._serde.loads_typed`` for wrapped
        values, the input itself for plain values, or ``None`` when the
        input is ``None`` or the wrapped payload cannot be decoded (the
        failure is logged).
    """
    if value is None:
        return None

    # Exact key match plus string-typed fields: a dict with these key names
    # but other value types cannot be a wrapper and is returned as-is.
    is_wrapper = (
        isinstance(value, dict)
        and set(value.keys()) == {"__serde_type__", "__serde_data__"}
        and isinstance(value["__serde_type__"], str)
        and isinstance(value["__serde_data__"], str)
    )
    if not is_wrapper:
        # Legacy format: value is stored as-is (plain JSON-serializable
        # data); return it untouched for backward compatibility.
        return value

    type_str = value["__serde_type__"]
    data_str = value["__serde_data__"]

    try:
        # Convert back to bytes based on type.
        if type_str == "json":
            data_bytes = data_str.encode("utf-8")
        else:
            # bytes, bytearray, msgpack types are hex-encoded
            data_bytes = bytes.fromhex(data_str)

        return self._serde.loads_typed((type_str, data_bytes))
    except (ValueError, TypeError) as e:
        # Handle hex decoding errors or deserialization failures.
        logger.error(
            "Failed to deserialize value from Redis: type=%r, error=%s",
            type_str,
            e,
        )
        # Return None to indicate deserialization failure.
        return None
    except Exception as e:
        # Best-effort read path: any other failure is logged and reported
        # as None rather than raised.
        logger.error(
            "Unexpected error deserializing value from Redis: type=%r, error=%s",
            type_str,
            e,
        )
        return None
471+
360472
def _get_batch_GET_ops_queries(
361473
self,
362474
get_ops: Sequence[tuple[int, GetOp]],
@@ -433,7 +545,7 @@ def _prepare_batch_PUT_queries(
433545
doc = RedisDocument(
434546
prefix=_namespace_to_text(op.namespace),
435547
key=op.key,
436-
value=op.value,
548+
value=self._serialize_value(op.value),
437549
created_at=now,
438550
updated_at=now,
439551
ttl_minutes=ttl_minutes,
@@ -568,10 +680,27 @@ def _decode_ns(ns: str) -> tuple[str, ...]:
568680
return tuple(_token_unescaper.unescape(ns).split("."))
569681

570682

571-
def _row_to_item(namespace: tuple[str, ...], row: dict[str, Any]) -> Item:
572-
"""Convert a row from Redis to an Item."""
683+
def _row_to_item(
684+
namespace: tuple[str, ...],
685+
row: dict[str, Any],
686+
deserialize_fn: Optional[Callable[[Any], Any]] = None,
687+
) -> Item:
688+
"""Convert a row from Redis to an Item.
689+
690+
Args:
691+
namespace: The namespace tuple for this item
692+
row: The raw row data from Redis
693+
deserialize_fn: Optional function to deserialize the value (handles
694+
LangChain messages, etc.)
695+
696+
Returns:
697+
An Item with properly deserialized value
698+
"""
699+
value = row["value"]
700+
if deserialize_fn is not None:
701+
value = deserialize_fn(value)
573702
return Item(
574-
value=row["value"],
703+
value=value,
575704
key=row["key"],
576705
namespace=namespace,
577706
created_at=datetime.fromtimestamp(row["created_at"] / 1_000_000, timezone.utc),
@@ -583,10 +712,25 @@ def _row_to_search_item(
583712
namespace: tuple[str, ...],
584713
row: dict[str, Any],
585714
score: Optional[float] = None,
715+
deserialize_fn: Optional[Callable[[Any], Any]] = None,
586716
) -> SearchItem:
587-
"""Convert a row from Redis to a SearchItem."""
717+
"""Convert a row from Redis to a SearchItem.
718+
719+
Args:
720+
namespace: The namespace tuple for this item
721+
row: The raw row data from Redis
722+
score: Optional similarity score from vector search
723+
deserialize_fn: Optional function to deserialize the value (handles
724+
LangChain messages, etc.)
725+
726+
Returns:
727+
A SearchItem with properly deserialized value
728+
"""
729+
value = row["value"]
730+
if deserialize_fn is not None:
731+
value = deserialize_fn(value)
588732
return SearchItem(
589-
value=row["value"],
733+
value=value,
590734
key=row["key"],
591735
namespace=namespace,
592736
created_at=datetime.fromtimestamp(row["created_at"] / 1_000_000, timezone.utc),

poetry.lock

Lines changed: 42 additions & 9 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ packages = [{ include = "langgraph" }]
2020
python = ">=3.10,<3.14"
2121
langgraph-checkpoint = ">=3.0.0,<4.0.0"
2222
redisvl = ">=0.11.0,<1.0.0"
23-
redis = ">=5.2.1,<7.0.0"
23+
redis = ">=5.2.1"
2424
orjson = "^3.9.0"
2525
tomli = { version = "^2.0.1", python = "<3.11" }
2626

0 commit comments

Comments
 (0)