|
5 | 5 | import asyncio |
6 | 6 | import json |
7 | 7 | from contextlib import asynccontextmanager |
8 | | -from typing import Any, AsyncIterator, Dict, List, Optional, Sequence, Tuple, cast |
| 8 | +from functools import partial |
| 9 | +from types import TracebackType |
| 10 | +from typing import Any, AsyncIterator, Dict, List, Optional, Sequence, Tuple, Type, cast |
9 | 11 |
|
10 | 12 | from langchain_core.runnables import RunnableConfig |
11 | 13 | from redisvl.index import AsyncSearchIndex |
|
30 | 32 | ) |
31 | 33 | from langgraph.constants import TASKS |
32 | 34 | from redis.asyncio import Redis as AsyncRedis |
| 35 | +from redis.asyncio.client import Pipeline |
33 | 36 |
|
34 | 37 | SCHEMAS = [ |
35 | 38 | { |
|
77 | 80 | ] |
78 | 81 |
|
79 | 82 |
|
| 83 | +# func: Callable[["Pipeline"], Union[Any, Awaitable[Any]]], |
| 84 | +async def _write_obj_tx(pipe: Pipeline, key: str, write_obj: dict[str, Any]) -> None: |
| 85 | + exists: int = await pipe.exists(key) |
| 86 | + if exists: |
| 87 | + await pipe.json().set(key, "$.channel", write_obj["channel"]) |
| 88 | + await pipe.json().set(key, "$.type", write_obj["type"]) |
| 89 | + await pipe.json().set(key, "$.blob", write_obj["blob"]) |
| 90 | + else: |
| 91 | + await pipe.json().set(key, "$", write_obj) |
| 92 | + |
| 93 | + |
80 | 94 | class AsyncShallowRedisSaver(BaseRedisSaver[AsyncRedis, AsyncSearchIndex]): |
81 | 95 | """Async Redis implementation that only stores the most recent checkpoint.""" |
82 | 96 |
|
@@ -104,7 +118,12 @@ def __init__( |
104 | 118 | async def __aenter__(self) -> AsyncShallowRedisSaver: |
105 | 119 | return self |
106 | 120 |
|
107 | | - async def __aexit__(self, exc_type, exc, tb) -> None: |
| 121 | + async def __aexit__( |
| 122 | + self, |
| 123 | + exc_type: Optional[Type[BaseException]], |
| 124 | + exc: Optional[BaseException], |
| 125 | + tb: Optional[TracebackType], |
| 126 | + ) -> None: |
108 | 127 | if self._owns_its_client: |
109 | 128 | await self._redis.aclose() # type: ignore[attr-defined] |
110 | 129 | await self._redis.connection_pool.disconnect() |
@@ -403,18 +422,7 @@ async def aput_writes( |
403 | 422 | write_obj["idx"], |
404 | 423 | ) |
405 | 424 | if upsert_case: |
406 | | - |
407 | | - async def tx(pipe, key=key, write_obj=write_obj): |
408 | | - exists = await pipe.exists(key) |
409 | | - if exists: |
410 | | - await pipe.json().set( |
411 | | - key, "$.channel", write_obj["channel"] |
412 | | - ) |
413 | | - await pipe.json().set(key, "$.type", write_obj["type"]) |
414 | | - await pipe.json().set(key, "$.blob", write_obj["blob"]) |
415 | | - else: |
416 | | - await pipe.json().set(key, "$", write_obj) |
417 | | - |
| 425 | + tx = partial(_write_obj_tx, key=key, write_obj=write_obj) |
418 | 426 | await self._redis.transaction(tx, key) |
419 | 427 | else: |
420 | 428 | # Unlike AsyncRedisSaver, the shallow implementation always overwrites |
|
0 commit comments