45 changes: 45 additions & 0 deletions .github/workflows/lint.yml
@@ -0,0 +1,45 @@

name: Lint

on:
  pull_request:
  push:
    branches:
      - main

env:
  POETRY_VERSION: "1.8.3"

jobs:
  check:
    name: Style-check ${{ matrix.python-version }}
    runs-on: ubuntu-latest
    strategy:
      matrix:
        # Only lint on the min and max supported Python versions.
        # It's extremely unlikely that there's a lint issue on any version in between
        # that doesn't show up on the min or max versions.
        #
        # GitHub rate-limits how many jobs can be running at any one time.
        # Starting new jobs is also relatively slow,
        # so linting on fewer versions makes CI faster.
        python-version:
          - "3.9"
          - "3.12"

    steps:
      - uses: actions/checkout@v2
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install Poetry
        uses: snok/install-poetry@v1
        with:
          version: ${{ env.POETRY_VERSION }}
      - name: Install dependencies
        run: |
          poetry install --all-extras
      - name: Run lint
        run: |
          make lint
59 changes: 59 additions & 0 deletions .github/workflows/test.yml
@@ -0,0 +1,59 @@
name: Test Suite

on:
  pull_request:

  push:
    branches:
      - main

env:
  POETRY_VERSION: "1.8.3"

jobs:
  test:
    name: Python ${{ matrix.python-version }} - ${{ matrix.connection }} [redis-stack ${{ matrix.redis-stack-version }}]
    runs-on: ubuntu-latest

    strategy:
      fail-fast: false
      matrix:
        python-version: ['3.9', '3.10', '3.11', '3.12']
        connection: ['hiredis', 'plain']
        redis-stack-version: ['6.2.6-v9', 'latest', 'edge']

    services:
      redis:
        image: redis/redis-stack-server:${{ matrix.redis-stack-version }}
        ports:
          - 6379:6379

    steps:
      - uses: actions/checkout@v2
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}
          cache: 'pip'

      - name: Install Poetry
        uses: snok/install-poetry@v1
        with:
          version: ${{ env.POETRY_VERSION }}

      - name: Install dependencies
        run: |
          poetry install --all-extras

      - name: Install hiredis if needed
        if: matrix.connection == 'hiredis'
        run: |
          poetry add hiredis

      - name: Set Redis version
        run: |
          echo "REDIS_VERSION=${{ matrix.redis-stack-version }}" >> $GITHUB_ENV

      - name: Run tests
        run: |
          make ci_test
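Two details here are easy to miss: the `Set Redis version` step exports `REDIS_VERSION` through `$GITHUB_ENV` so that later steps in the same job see it as an ordinary environment variable, and the `connection` axis works because redis-py switches to the hiredis C parser automatically whenever `hiredis` is importable; the "plain" leg simply never installs it. A small illustrative sketch (the logic below is ours, not from the PR):

```python
import os

# REDIS_VERSION is written to $GITHUB_ENV by the "Set Redis version" step,
# so any later step in the same job reads it as a normal environment variable.
redis_version = os.environ.get("REDIS_VERSION", "latest")

# redis-py uses the hiredis C parser automatically when hiredis is installed;
# nothing in the application code changes between the two matrix legs.
try:
    import hiredis  # noqa: F401
    parser = "hiredis (C parser)"
except ImportError:
    parser = "pure-Python parser"

print(f"testing against redis-stack {redis_version} with the {parser}")
```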
7 changes: 5 additions & 2 deletions Makefile
@@ -5,10 +5,13 @@
 ######################
 
 test:
-	poetry run pytest tests
+	poetry run pytest tests --run-api-tests
 
 test_watch:
 	poetry run ptw .
 
+ci_test:
+	poetry run pytest tests
+
 ######################
 # LINTING AND FORMATTING
@@ -32,4 +35,4 @@ lint lint_diff lint_package lint_tests:
 
 format format_diff:
 	poetry run ruff format $(PYTHON_FILES)
-	poetry run ruff check --select I --fix $(PYTHON_FILES)
\ No newline at end of file
+	poetry run ruff check --select I --fix $(PYTHON_FILES)
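Note the split this creates: local `make test` now passes `--run-api-tests`, while the new `ci_test` target runs pytest without the flag; presumably this is why test.yml above invokes `make ci_test`, so CI skips tests that need live API access.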
101 changes: 49 additions & 52 deletions langgraph/checkpoint/redis/aio.py
@@ -6,6 +6,7 @@
 import json
 from collections.abc import AsyncIterator
 from contextlib import asynccontextmanager
+from functools import partial
 from types import TracebackType
 from typing import Any, List, Optional, Sequence, Tuple, Type, cast
 
Expand All @@ -27,6 +28,26 @@
from langgraph.checkpoint.redis.base import BaseRedisSaver
from langgraph.constants import TASKS
from redis.asyncio import Redis as AsyncRedis
from redis.asyncio.client import Pipeline


async def _write_obj_tx(
pipe: Pipeline,
key: str,
write_obj: dict[str, Any],
upsert_case: bool,
) -> None:
exists: int = await pipe.exists(key)
if upsert_case:
if exists:
await pipe.json().set(key, "$.channel", write_obj["channel"])
await pipe.json().set(key, "$.type", write_obj["type"])
await pipe.json().set(key, "$.blob", write_obj["blob"])
else:
await pipe.json().set(key, "$", write_obj)
else:
if not exists:
await pipe.json().set(key, "$", write_obj)


class AsyncRedisSaver(BaseRedisSaver[AsyncRedis, AsyncSearchIndex]):
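`_write_obj_tx` is designed to be driven by redis-py's `transaction()` helper (see `aput_writes` below), which wraps the callable in a WATCH/EXEC loop and retries on `WatchError`. Below is a self-contained sketch of the same check-then-set pattern, assuming redis-py 5.x and a local Redis Stack server (for the RedisJSON commands); the explicit `pipe.multi()` call is the canonical optimistic-locking step that buffers the writes so they execute atomically only if the watched key is unchanged:

```python
import asyncio
from typing import Any

from redis.asyncio import Redis
from redis.asyncio.client import Pipeline


async def check_then_set(pipe: Pipeline, key: str, doc: dict[str, Any]) -> None:
    # transaction() has already issued WATCH on `key`, so this read runs
    # immediately on the watched connection.
    exists: int = await pipe.exists(key)
    # Switch the pipeline to buffered mode: the commands below are queued
    # and run atomically at EXEC time, or the whole attempt fails with
    # WatchError (and is retried) if another client touched `key` meanwhile.
    pipe.multi()
    if exists:
        await pipe.json().set(key, "$.blob", doc["blob"])  # update one field
    else:
        await pipe.json().set(key, "$", doc)  # insert the whole document


async def main() -> None:
    r = Redis.from_url("redis://localhost:6379")
    doc = {"channel": "c1", "type": "json", "blob": "payload"}
    # Retries check_then_set until EXEC succeeds without a conflicting write.
    await r.transaction(lambda pipe: check_then_set(pipe, "demo:key", doc), "demo:key")
    await r.aclose()


asyncio.run(main())
```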
@@ -73,6 +94,7 @@ def create_indexes(self) -> None:
 
     async def __aenter__(self) -> AsyncRedisSaver:
         """Async context manager enter."""
+        await self.asetup()
         return self
 
     async def __aexit__(
@@ -82,15 +104,15 @@ async def __aexit__(
         exc_tb: Optional[TracebackType],
     ) -> None:
         """Async context manager exit."""
-        # Close client connections
-        if hasattr(self, "checkpoint_index") and hasattr(
-            self.checkpoint_index, "client"
-        ):
-            await self.checkpoint_index.client.aclose()
-        if hasattr(self, "channel_index") and hasattr(self.channel_index, "client"):
-            await self.channel_index.client.aclose()
-        if hasattr(self, "writes_index") and hasattr(self.writes_index, "client"):
-            await self.writes_index.client.aclose()
+        if self._owns_its_client:
+            await self._redis.aclose()  # type: ignore[attr-defined]
+            await self._redis.connection_pool.disconnect()
+
+        # Prevent RedisVL from attempting to close the client
+        # on an event loop in a separate thread.
+        self.checkpoints_index._redis_client = None
+        self.checkpoint_blobs_index._redis_client = None
+        self.checkpoint_writes_index._redis_client = None
 
     async def asetup(self) -> None:
         """Initialize Redis indexes asynchronously."""
@@ -418,38 +440,19 @@ async def aput_writes(
             }
             writes_objects.append(write_obj)
 
-        # For each write, check existence and then perform appropriate operation
-        async with self.checkpoints_index.client.pipeline(
-            transaction=False
-        ) as pipeline:
-            for write_obj in writes_objects:
-                key = self._make_redis_checkpoint_writes_key(
-                    thread_id,
-                    checkpoint_ns,
-                    checkpoint_id,
-                    task_id,
-                    write_obj["idx"],
-                )
-
-                # First check if key exists
-                key_exists = await self._redis.exists(key) == 1
-
-                if all(w[0] in WRITES_IDX_MAP for w in writes):
-                    # UPSERT case - only update specific fields
-                    if key_exists:
-                        # Update only channel, type, and blob fields
-                        pipeline.json().set(key, "$.channel", write_obj["channel"])
-                        pipeline.json().set(key, "$.type", write_obj["type"])
-                        pipeline.json().set(key, "$.blob", write_obj["blob"])
-                    else:
-                        # For new records, set the complete object
-                        pipeline.json().set(key, "$", write_obj)
-                else:
-                    # INSERT case - only insert if doesn't exist
-                    if not key_exists:
-                        pipeline.json().set(key, "$", write_obj)
-
-            await pipeline.execute()
+        upsert_case = all(w[0] in WRITES_IDX_MAP for w in writes)
+        for write_obj in writes_objects:
+            key = self._make_redis_checkpoint_writes_key(
+                thread_id,
+                checkpoint_ns,
+                checkpoint_id,
+                task_id,
+                write_obj["idx"],
+            )
+            tx = partial(
+                _write_obj_tx, key=key, write_obj=write_obj, upsert_case=upsert_case
+            )
+            await self._redis.transaction(tx, key)
 
     def put_writes(
         self,
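Compared with the replaced block, the existence check now happens inside the transaction callable, and each write key gets its own WATCHed transaction: `transaction(tx, key)` re-runs `_write_obj_tx` if `key` changes between the `EXISTS` read and the JSON writes, turning the old read-then-write race into a per-key optimistic lock, at the cost of one round trip per write instead of one batched pipeline.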
@@ -542,18 +545,12 @@ async def from_conn_string(
         redis_client: Optional[AsyncRedis] = None,
         connection_args: Optional[dict[str, Any]] = None,
     ) -> AsyncIterator[AsyncRedisSaver]:
-        saver: Optional[AsyncRedisSaver] = None
-        try:
-            saver = cls(
-                redis_url=redis_url,
-                redis_client=redis_client,
-                connection_args=connection_args,
-            )
+        async with cls(
+            redis_url=redis_url,
+            redis_client=redis_client,
+            connection_args=connection_args,
+        ) as saver:
             yield saver
-        finally:
-            if saver and saver._owns_its_client:  # Ensure saver is not None
-                await saver._redis.aclose()  # type: ignore[attr-defined]
-                await saver._redis.connection_pool.disconnect()
 
     async def aget_channel_values(
         self, thread_id: str, checkpoint_ns: str = "", checkpoint_id: str = ""
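Since `from_conn_string` now yields from `async with cls(...)`, setup and teardown ride on the context-manager methods above: `__aenter__` runs `asetup()` and `__aexit__` closes the client only when the saver created it. A hypothetical caller (the URL and thread ID are placeholders; `aget_tuple` comes from the langgraph checkpointer interface):

```python
import asyncio

from langgraph.checkpoint.redis.aio import AsyncRedisSaver


async def main() -> None:
    # __aenter__ creates the indexes via asetup(); __aexit__ closes the
    # client because the saver owns it (we passed a URL, not a client).
    async with AsyncRedisSaver.from_conn_string("redis://localhost:6379") as saver:
        config = {"configurable": {"thread_id": "demo-thread"}}
        print(await saver.aget_tuple(config))  # None until a checkpoint exists


asyncio.run(main())
```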