diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
new file mode 100644
index 0000000..d8d7b54
--- /dev/null
+++ b/.github/workflows/lint.yml
@@ -0,0 +1,45 @@
+
+name: Lint
+
+on:
+  pull_request:
+  push:
+    branches:
+      - main
+
+env:
+  POETRY_VERSION: "1.8.3"
+
+jobs:
+  check:
+    name: Style-check ${{ matrix.python-version }}
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        # Only lint on the min and max supported Python versions.
+        # It's extremely unlikely that there's a lint issue on any version in between
+        # that doesn't show up on the min or max versions.
+        #
+        # GitHub rate-limits how many jobs can be running at any one time.
+        # Starting new jobs is also relatively slow,
+        # so linting on fewer versions makes CI faster.
+        python-version:
+          - "3.9"
+          - "3.12"
+
+    steps:
+      - uses: actions/checkout@v2
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v2
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Install Poetry
+        uses: snok/install-poetry@v1
+        with:
+          version: ${{ env.POETRY_VERSION }}
+      - name: Install dependencies
+        run: |
+          poetry install --all-extras
+      - name: Run lint
+        run: |
+          make lint
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
new file mode 100644
index 0000000..a71bb72
--- /dev/null
+++ b/.github/workflows/test.yml
@@ -0,0 +1,59 @@
+name: Test Suite
+
+on:
+  pull_request:
+
+  push:
+    branches:
+      - main
+
+env:
+  POETRY_VERSION: "1.8.3"
+
+jobs:
+  test:
+    name: Python ${{ matrix.python-version }} - ${{ matrix.connection }} [redis-stack ${{ matrix.redis-stack-version }}]
+    runs-on: ubuntu-latest
+
+    strategy:
+      fail-fast: false
+      matrix:
+        python-version: ['3.9', '3.10', '3.11', '3.12']
+        connection: ['hiredis', 'plain']
+        redis-stack-version: ['6.2.6-v9', 'latest', 'edge']
+
+    services:
+      redis:
+        image: redis/redis-stack-server:${{ matrix.redis-stack-version }}
+        ports:
+          - 6379:6379
+
+    steps:
+      - uses: actions/checkout@v2
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v4
+        with:
+          python-version: ${{ matrix.python-version }}
+          cache: 'pip'
+
+      - name: Install Poetry
+        uses: snok/install-poetry@v1
+        with:
+          version: ${{ env.POETRY_VERSION }}
+
+      - name: Install dependencies
+        run: |
+          poetry install --all-extras
+
+      - name: Install hiredis if needed
+        if: matrix.connection == 'hiredis'
+        run: |
+          poetry add hiredis
+
+      - name: Set Redis version
+        run: |
+          echo "REDIS_VERSION=${{ matrix.redis-stack-version }}" >> $GITHUB_ENV
+
+      - name: Run tests
+        run: |
+          make ci_test
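The "Set Redis version" step exports the matrix's redis-stack tag as REDIS_VERSION in the job environment. A minimal sketch of how a test module could branch on that variable — the helper and marker names here are illustrative, not part of this PR:

    import os

    import pytest

    # REDIS_VERSION is written to $GITHUB_ENV by the "Set Redis version" step;
    # default to "latest" for local runs where it is unset.
    REDIS_VERSION = os.environ.get("REDIS_VERSION", "latest")

    requires_recent_stack = pytest.mark.skipif(
        REDIS_VERSION == "6.2.6-v9",
        reason="feature unavailable on redis-stack 6.2.6-v9",
    )


    @requires_recent_stack
    def test_feature_missing_from_old_stack() -> None:
        ...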
diff --git a/Makefile b/Makefile
index 87a5921..69d1a8e 100644
--- a/Makefile
+++ b/Makefile
@@ -5,10 +5,13 @@
 ######################
 
 test:
-	poetry run pytest tests
+	poetry run pytest tests --run-api-tests
 
 test_watch:
 	poetry run ptw .
+
+ci_test:
+	poetry run pytest tests
 
 ######################
 # LINTING AND FORMATTING
@@ -32,4 +35,4 @@ lint lint_diff lint_package lint_tests:
 
 format format_diff:
 	poetry run ruff format $(PYTHON_FILES)
-	poetry run ruff check --select I --fix $(PYTHON_FILES)
\ No newline at end of file
+	poetry run ruff check --select I --fix $(PYTHON_FILES)
diff --git a/langgraph/checkpoint/redis/aio.py b/langgraph/checkpoint/redis/aio.py
index 8554645..03481d1 100644
--- a/langgraph/checkpoint/redis/aio.py
+++ b/langgraph/checkpoint/redis/aio.py
@@ -6,6 +6,7 @@ import json
 
 from collections.abc import AsyncIterator
 from contextlib import asynccontextmanager
+from functools import partial
 from types import TracebackType
 from typing import Any, List, Optional, Sequence, Tuple, Type, cast
 
@@ -27,6 +28,26 @@ from langgraph.checkpoint.redis.base import BaseRedisSaver
 from langgraph.constants import TASKS
 from redis.asyncio import Redis as AsyncRedis
+from redis.asyncio.client import Pipeline
+
+
+async def _write_obj_tx(
+    pipe: Pipeline,
+    key: str,
+    write_obj: dict[str, Any],
+    upsert_case: bool,
+) -> None:
+    exists: int = await pipe.exists(key)
+    if upsert_case:
+        if exists:
+            await pipe.json().set(key, "$.channel", write_obj["channel"])
+            await pipe.json().set(key, "$.type", write_obj["type"])
+            await pipe.json().set(key, "$.blob", write_obj["blob"])
+        else:
+            await pipe.json().set(key, "$", write_obj)
+    else:
+        if not exists:
+            await pipe.json().set(key, "$", write_obj)
 
 
 class AsyncRedisSaver(BaseRedisSaver[AsyncRedis, AsyncSearchIndex]):
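_write_obj_tx is shaped as a callback for redis-py's transaction() helper, which WATCHes the given keys, invokes the callback with a pipeline, and retries if a watched key changes before EXEC. A rough sketch of driving the helper above — the key name and payload are placeholders, and a Redis server with the RedisJSON module is assumed:

    import asyncio
    from functools import partial

    from redis.asyncio import Redis


    async def main() -> None:
        client = Redis.from_url("redis://localhost:6379")
        write_obj = {"channel": "c1", "type": "msgpack", "blob": "...", "idx": 0}
        # transaction() passes a Pipeline to the callback and handles the
        # WATCH/EXEC/retry loop around it; _write_obj_tx is the helper from
        # the hunk above.
        tx = partial(
            _write_obj_tx, key="demo:write:0", write_obj=write_obj, upsert_case=True
        )
        await client.transaction(tx, "demo:write:0")
        await client.aclose()


    asyncio.run(main())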
@@ -73,6 +94,7 @@ def create_indexes(self) -> None:
 
     async def __aenter__(self) -> AsyncRedisSaver:
         """Async context manager enter."""
+        await self.asetup()
         return self
 
     async def __aexit__(
@@ -82,15 +104,15 @@ async def __aexit__(
         exc_tb: Optional[TracebackType],
     ) -> None:
         """Async context manager exit."""
-        # Close client connections
-        if hasattr(self, "checkpoint_index") and hasattr(
-            self.checkpoint_index, "client"
-        ):
-            await self.checkpoint_index.client.aclose()
-        if hasattr(self, "channel_index") and hasattr(self.channel_index, "client"):
-            await self.channel_index.client.aclose()
-        if hasattr(self, "writes_index") and hasattr(self.writes_index, "client"):
-            await self.writes_index.client.aclose()
+        if self._owns_its_client:
+            await self._redis.aclose()  # type: ignore[attr-defined]
+            await self._redis.connection_pool.disconnect()
+
+        # Prevent RedisVL from attempting to close the client
+        # on an event loop in a separate thread.
+        self.checkpoints_index._redis_client = None
+        self.checkpoint_blobs_index._redis_client = None
+        self.checkpoint_writes_index._redis_client = None
 
     async def asetup(self) -> None:
         """Initialize Redis indexes asynchronously."""
@@ -418,38 +440,19 @@ async def aput_writes(
             }
             writes_objects.append(write_obj)
 
-        # For each write, check existence and then perform appropriate operation
-        async with self.checkpoints_index.client.pipeline(
-            transaction=False
-        ) as pipeline:
-            for write_obj in writes_objects:
-                key = self._make_redis_checkpoint_writes_key(
-                    thread_id,
-                    checkpoint_ns,
-                    checkpoint_id,
-                    task_id,
-                    write_obj["idx"],
-                )
-
-                # First check if key exists
-                key_exists = await self._redis.exists(key) == 1
-
-                if all(w[0] in WRITES_IDX_MAP for w in writes):
-                    # UPSERT case - only update specific fields
-                    if key_exists:
-                        # Update only channel, type, and blob fields
-                        pipeline.json().set(key, "$.channel", write_obj["channel"])
-                        pipeline.json().set(key, "$.type", write_obj["type"])
-                        pipeline.json().set(key, "$.blob", write_obj["blob"])
-                    else:
-                        # For new records, set the complete object
-                        pipeline.json().set(key, "$", write_obj)
-                else:
-                    # INSERT case - only insert if doesn't exist
-                    if not key_exists:
-                        pipeline.json().set(key, "$", write_obj)
-
-            await pipeline.execute()
+        upsert_case = all(w[0] in WRITES_IDX_MAP for w in writes)
+        for write_obj in writes_objects:
+            key = self._make_redis_checkpoint_writes_key(
+                thread_id,
+                checkpoint_ns,
+                checkpoint_id,
+                task_id,
+                write_obj["idx"],
+            )
+            tx = partial(
+                _write_obj_tx, key=key, write_obj=write_obj, upsert_case=upsert_case
+            )
+            await self._redis.transaction(tx, key)
 
     def put_writes(
         self,
@@ -542,18 +545,12 @@ async def from_conn_string(
         redis_client: Optional[AsyncRedis] = None,
         connection_args: Optional[dict[str, Any]] = None,
     ) -> AsyncIterator[AsyncRedisSaver]:
-        saver: Optional[AsyncRedisSaver] = None
-        try:
-            saver = cls(
-                redis_url=redis_url,
-                redis_client=redis_client,
-                connection_args=connection_args,
-            )
+        async with cls(
+            redis_url=redis_url,
+            redis_client=redis_client,
+            connection_args=connection_args,
+        ) as saver:
             yield saver
-        finally:
-            if saver and saver._owns_its_client:  # Ensure saver is not None
-                await saver._redis.aclose()  # type: ignore[attr-defined]
-                await saver._redis.connection_pool.disconnect()
 
     async def aget_channel_values(
         self, thread_id: str, checkpoint_ns: str = "", checkpoint_id: str = ""
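Because __aenter__ now runs asetup() and from_conn_string simply delegates to the context manager, both entry points create the indexes on entry and release the client on exit only when the saver owns it. A minimal usage sketch (the URL is a placeholder):

    import asyncio

    from langgraph.checkpoint.redis.aio import AsyncRedisSaver


    async def main() -> None:
        async with AsyncRedisSaver.from_conn_string("redis://localhost:6379") as saver:
            # Indexes already exist here; no explicit asetup() call is needed.
            config = {"configurable": {"thread_id": "t1", "checkpoint_ns": ""}}
            print(await saver.aget_tuple(config))


    asyncio.run(main())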
write_obj["type"]) + await pipe.json().set(key, "$.blob", write_obj["blob"]) + else: + await pipe.json().set(key, "$", write_obj) + + class AsyncShallowRedisSaver(BaseRedisSaver[AsyncRedis, AsyncSearchIndex]): """Async Redis implementation that only stores the most recent checkpoint.""" @@ -99,9 +113,27 @@ def __init__( redis_client=redis_client, connection_args=connection_args, ) - # self.lock = asyncio.Lock() self.loop = asyncio.get_running_loop() + async def __aenter__(self) -> AsyncShallowRedisSaver: + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc: Optional[BaseException], + tb: Optional[TracebackType], + ) -> None: + if self._owns_its_client: + await self._redis.aclose() # type: ignore[attr-defined] + await self._redis.connection_pool.disconnect() + + # Prevent RedisVL from attempting to close the client + # on an event loop in a separate thread. + self.checkpoints_index._redis_client = None + self.checkpoint_blobs_index._redis_client = None + self.checkpoint_writes_index._redis_client = None + @classmethod @asynccontextmanager async def from_conn_string( @@ -112,18 +144,12 @@ async def from_conn_string( connection_args: Optional[dict[str, Any]] = None, ) -> AsyncIterator[AsyncShallowRedisSaver]: """Create a new AsyncShallowRedisSaver instance.""" - saver: Optional[AsyncShallowRedisSaver] = None - try: - saver = cls( - redis_url=redis_url, - redis_client=redis_client, - connection_args=connection_args, - ) + async with cls( + redis_url=redis_url, + redis_client=redis_client, + connection_args=connection_args, + ) as saver: yield saver - finally: - if saver and saver._owns_its_client: - await saver._redis.aclose() # type: ignore[attr-defined] - await saver._redis.connection_pool.disconnect() async def asetup(self) -> None: """Initialize Redis indexes asynchronously.""" @@ -317,9 +343,9 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: # Ensure metadata matches CheckpointMetadata type sanitized_metadata = { - k.replace("\u0000", ""): v.replace("\u0000", "") - if isinstance(v, str) - else v + k.replace("\u0000", ""): ( + v.replace("\u0000", "") if isinstance(v, str) else v + ) for k, v in metadata_dict.items() } metadata = cast(CheckpointMetadata, sanitized_metadata) @@ -386,37 +412,22 @@ async def aput_writes( } writes_objects.append(write_obj) - # For each write, check existence and then perform appropriate operation - async with self.checkpoints_index.client.pipeline( - transaction=False - ) as pipeline: - for write_obj in writes_objects: - key = self._make_redis_checkpoint_writes_key( - thread_id, - checkpoint_ns, - checkpoint_id, - task_id, - write_obj["idx"], - ) - - # First check if key exists - key_exists = await self._redis.exists(key) == 1 - - if all(w[0] in WRITES_IDX_MAP for w in writes): - # UPSERT case - only update specific fields - if key_exists: - # Update only channel, type, and blob fields - pipeline.json().set(key, "$.channel", write_obj["channel"]) - pipeline.json().set(key, "$.type", write_obj["type"]) - pipeline.json().set(key, "$.blob", write_obj["blob"]) - else: - # For new records, set the complete object - pipeline.json().set(key, "$", write_obj) - else: - # INSERT case - only insert if doesn't exist - pipeline.json().set(key, "$", write_obj) - - await pipeline.execute() + upsert_case = all(w[0] in WRITES_IDX_MAP for w in writes) + for write_obj in writes_objects: + key = self._make_redis_checkpoint_writes_key( + thread_id, + checkpoint_ns, + checkpoint_id, + task_id, + 
write_obj["idx"], + ) + if upsert_case: + tx = partial(_write_obj_tx, key=key, write_obj=write_obj) + await self._redis.transaction(tx, key) + else: + # Unlike AsyncRedisSaver, the shallow implementation always overwrites + # the full object, so we don't check for key existence here. + await self._redis.json().set(key, "$", write_obj) async def aget_channel_values( self, thread_id: str, checkpoint_ns: str, checkpoint_id: str diff --git a/langgraph/docs/create-react-agent-memory.ipynb b/langgraph/docs/create-react-agent-memory.ipynb index 994cf63..fdd55be 100644 --- a/langgraph/docs/create-react-agent-memory.ipynb +++ b/langgraph/docs/create-react-agent-memory.ipynb @@ -292,3 +292,4 @@ "nbformat": 4, "nbformat_minor": 5 } + diff --git a/langgraph/store/redis/aio.py b/langgraph/store/redis/aio.py index f578217..fa88e8f 100644 --- a/langgraph/store/redis/aio.py +++ b/langgraph/store/redis/aio.py @@ -179,23 +179,12 @@ async def from_conn_string( index: Optional[IndexConfig] = None, ) -> AsyncIterator[AsyncRedisStore]: """Create store from Redis connection string.""" - store = cls(redis_url=conn_string, index=index) - try: + async with cls(redis_url=conn_string, index=index) as store: store._task = store.loop.create_task( store._run_background_tasks(store._aqueue, weakref.ref(store)) ) await store.setup() yield store - finally: - if hasattr(store, "_task"): - store._task.cancel() - try: - await store._task - except asyncio.CancelledError: - pass - if store._owns_client: - await store._redis.aclose() # type: ignore[attr-defined] - await store._redis.connection_pool.disconnect() def create_indexes(self) -> None: """Create async indices.""" @@ -221,8 +210,9 @@ async def __aexit__( except asyncio.CancelledError: pass - # if self._owns_client: - await self._redis.aclose() # type: ignore[attr-defined] + if self._owns_client: + await self._redis.aclose() # type: ignore[attr-defined] + await self._redis.connection_pool.disconnect() async def abatch(self, ops: Iterable[Op]) -> list[Result]: """Execute batch of operations asynchronously.""" @@ -306,6 +296,7 @@ async def _batch_get_ops( key_to_row = { json.loads(doc.json)["key"]: json.loads(doc.json) for doc in res.docs } + for idx, key in items: if key in key_to_row: results[idx] = _row_to_item(namespace, key_to_row[key]) @@ -482,7 +473,7 @@ async def _batch_search_ops( ) # Get matching store docs in pipeline - pipeline = self._redis.pipeline() + pipeline = self._redis.pipeline(transaction=False) result_map = {} # Map store key to vector result with distances for doc in vector_results: diff --git a/tests/conftest.py b/tests/conftest.py index 1546898..2d61c0c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -48,3 +48,31 @@ async def clear_redis(redis_url: str) -> None: client = Redis.from_url(redis_url) await client.flushall() await client.aclose() # type: ignore[attr-defined] + + +def pytest_addoption(parser: pytest.Parser) -> None: + parser.addoption( + "--run-api-tests", + action="store_true", + default=False, + help="Run tests that require API keys", + ) + + +def pytest_configure(config: pytest.Config) -> None: + config.addinivalue_line( + "markers", "requires_api_keys: mark test as requiring API keys" + ) + + +def pytest_collection_modifyitems( + config: pytest.Config, items: list[pytest.Item] +) -> None: + if config.getoption("--run-api-tests"): + return + skip_api = pytest.mark.skip( + reason="Skipping test because API keys are not provided. Use --run-api-tests to run these tests." 
diff --git a/tests/conftest.py b/tests/conftest.py
index 1546898..2d61c0c 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -48,3 +48,31 @@ async def clear_redis(redis_url: str) -> None:
     client = Redis.from_url(redis_url)
     await client.flushall()
     await client.aclose()  # type: ignore[attr-defined]
+
+
+def pytest_addoption(parser: pytest.Parser) -> None:
+    parser.addoption(
+        "--run-api-tests",
+        action="store_true",
+        default=False,
+        help="Run tests that require API keys",
+    )
+
+
+def pytest_configure(config: pytest.Config) -> None:
+    config.addinivalue_line(
+        "markers", "requires_api_keys: mark test as requiring API keys"
+    )
+
+
+def pytest_collection_modifyitems(
+    config: pytest.Config, items: list[pytest.Item]
+) -> None:
+    if config.getoption("--run-api-tests"):
+        return
+    skip_api = pytest.mark.skip(
+        reason="Skipping test because API keys are not provided. Use --run-api-tests to run these tests."
+    )
+    for item in items:
+        if item.get_closest_marker("requires_api_keys"):
+            item.add_marker(skip_api)
diff --git a/tests/test_async.py b/tests/test_async.py
index 500273d..ed535fa 100644
--- a/tests/test_async.py
+++ b/tests/test_async.py
@@ -539,6 +539,7 @@ def model() -> ChatOpenAI:
     return ChatOpenAI(model="gpt-4-turbo-preview", temperature=0)
 
 
+@pytest.mark.requires_api_keys
 @pytest.mark.asyncio
 async def test_async_redis_checkpointer(
     redis_url: str, tools: List[BaseTool], model: ChatOpenAI
diff --git a/tests/test_async_store.py b/tests/test_async_store.py
index ad266ff..0168e3a 100644
--- a/tests/test_async_store.py
+++ b/tests/test_async_store.py
@@ -480,6 +480,7 @@ async def test_large_batches(store: AsyncRedisStore) -> None:
     )
 
 
+@pytest.mark.requires_api_keys
 @pytest.mark.asyncio
 async def test_async_store_with_memory_persistence(
     redis_url: str,
@@ -501,8 +502,11 @@ async def test_async_store_with_memory_persistence(
         "distance_type": "cosine",
     }
 
-    async with AsyncRedisStore.from_conn_string(redis_url, index=index_config) as store:
+    async with AsyncRedisStore.from_conn_string(
+        redis_url, index=index_config
+    ) as store, AsyncRedisSaver.from_conn_string(redis_url) as checkpointer:
         await store.setup()
+        await checkpointer.asetup()
 
         model = ChatAnthropic(model="claude-3-5-sonnet-20240620")  # type: ignore[call-arg]
 
@@ -532,11 +536,6 @@ def call_model(
         builder.add_node("call_model", call_model)  # type:ignore[arg-type]
         builder.add_edge(START, "call_model")
 
-        checkpointer = None
-        async with AsyncRedisSaver.from_conn_string(redis_url) as cp:
-            await cp.asetup()
-            checkpointer = cp
-
         # Compile graph with store and checkpointer
         graph = builder.compile(checkpointer=checkpointer, store=store)
diff --git a/tests/test_store.py b/tests/test_store.py
index cad1146..fe25e10 100644
--- a/tests/test_store.py
+++ b/tests/test_store.py
@@ -426,6 +426,7 @@ def test_vector_update_with_score_verification(
     assert not any(r.key == "doc4" for r in results_new)
 
 
+@pytest.mark.requires_api_keys
 def test_store_with_memory_persistence(redis_url: str) -> None:
     """Test store functionality with memory persistence.
 
diff --git a/tests/test_sync.py b/tests/test_sync.py
index 174a1d0..9e4b0a4 100644
--- a/tests/test_sync.py
+++ b/tests/test_sync.py
@@ -434,6 +434,7 @@ def model() -> ChatOpenAI:
     return ChatOpenAI(model="gpt-4-turbo-preview", temperature=0)
 
 
+@pytest.mark.requires_api_keys
 def test_sync_redis_checkpointer(
     tools: list[BaseTool], model: ChatOpenAI, redis_url: str
 ) -> None:
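With these conftest hooks in place, gating a test on external credentials is a one-line marker: collection adds the skip unless --run-api-tests is given, which `make test` now passes and `make ci_test` omits. A sketch of how a new test would opt in:

    import pytest


    @pytest.mark.requires_api_keys
    def test_calls_external_llm() -> None:
        # Collected but skipped by default; runs under
        # `pytest --run-api-tests` (i.e., `make test`).
        ...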