Skip to content

Commit 3eb1d43

Browse files
committed
add finalizers, explicit disconnect methods, ctx mgrs
1 parent 2b171d0 commit 3eb1d43

File tree

5 files changed

+161
-6
lines changed

5 files changed

+161
-6
lines changed

redisvl/extensions/llmcache/semantic.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import asyncio
22
from typing import Any, Dict, List, Optional
3+
import weakref
34

45
from redis import Redis
5-
from redis.asyncio import Redis as AsyncRedis
66
from redisvl.extensions.constants import (
77
CACHE_VECTOR_FIELD_NAME,
88
ENTRY_ID_FIELD_NAME,
@@ -23,6 +23,7 @@
2323
from redisvl.query import RangeQuery
2424
from redisvl.query.filter import FilterExpression
2525
from redisvl.redis.connection import RedisConnectionFactory
26+
from redisvl.utils.log import get_logger
2627
from redisvl.utils.utils import (
2728
current_timestamp,
2829
deprecated_argument,
@@ -32,6 +33,9 @@
3233
from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer
3334

3435

36+
logger = get_logger("[RedisVL]")
37+
38+
3539
class SemanticCache(BaseLLMCache):
3640
"""Semantic Cache for Large Language Models."""
3741

@@ -149,6 +153,8 @@ def __init__(
149153

150154
# Create the search index in Redis
151155
self._index.create(overwrite=overwrite, drop=False)
156+
157+
weakref.finalize(self, self._finalize_async)
152158

153159
def _modify_schema(
154160
self,
@@ -702,3 +708,37 @@ async def aupdate(self, key: str, **kwargs) -> None:
702708
await aindex.load(data=[kwargs], keys=[key])
703709

704710
await self._async_refresh_ttl(key)
711+
712+
def _finalize_async(self):
713+
if self._index:
714+
self._index.disconnect()
715+
if self._aindex:
716+
try:
717+
loop = None
718+
try:
719+
loop = asyncio.get_running_loop()
720+
except RuntimeError:
721+
loop = asyncio.new_event_loop()
722+
asyncio.set_event_loop(loop)
723+
loop.run_until_complete(self._aindex.disconnect())
724+
except Exception as e:
725+
logger.info(f"Error disconnecting from index: {e}")
726+
727+
def disconnect(self):
728+
if self._index:
729+
self._index.disconnect()
730+
if self._aindex:
731+
asyncio.run(self._aindex.disconnect())
732+
733+
async def adisconnect(self):
734+
if self._index:
735+
self._index.disconnect()
736+
if self._aindex:
737+
await self._aindex.disconnect()
738+
self._aindex = None
739+
740+
async def __aenter__(self):
741+
return self
742+
743+
async def __aexit__(self, exc_type, exc_val, exc_tb):
744+
await self.adisconnect()

redisvl/index/index.py

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
Optional,
1717
Union,
1818
)
19+
import weakref
1920

2021
from redisvl.utils.utils import deprecated_argument, deprecated_function
2122

@@ -784,6 +785,12 @@ def info(self, name: Optional[str] = None) -> Dict[str, Any]:
784785
index_name = name or self.schema.index.name
785786
return self._info(index_name, self._redis_client) # type: ignore
786787

788+
def __enter__(self):
789+
return self
790+
791+
def __exit__(self, exc_type, exc_val, exc_tb):
792+
self.disconnect()
793+
787794

788795
class AsyncSearchIndex(BaseSearchIndex):
789796
"""A search index class for interacting with Redis as a vector database in
@@ -857,11 +864,8 @@ def __init__(
857864
self._connection_kwargs = connection_kwargs or {}
858865
self._lock = asyncio.Lock()
859866

860-
async def disconnect(self):
861-
"""Asynchronously disconnect and cleanup the underlying async redis connection."""
862-
if self._redis_client is not None:
863-
await self._redis_client.aclose() # type: ignore
864-
self._redis_client = None
867+
# Close connections when the object is garbage collected
868+
weakref.finalize(self, self._finalize_disconnect)
865869

866870
@classmethod
867871
async def from_existing(
@@ -1336,9 +1340,36 @@ async def _info(name: str, redis_client: aredis.Redis) -> Dict[str, Any]:
13361340
raise RedisSearchError(
13371341
f"Error while fetching {name} index info: {str(e)}"
13381342
) from e
1343+
1344+
async def disconnect(self):
1345+
"""Asynchronously disconnect and cleanup the underlying async redis connection."""
1346+
if self._redis_client is not None:
1347+
await self._redis_client.aclose() # type: ignore
1348+
self._redis_client = None
1349+
1350+
def disconnect_sync(self):
1351+
"""Synchronously disconnect and cleanup the underlying async redis connection."""
1352+
if self._redis_client is None:
1353+
return
1354+
loop = asyncio.get_running_loop()
1355+
if loop is None or not loop.is_running():
1356+
asyncio.run(self._redis_client.aclose()) # type: ignore
1357+
else:
1358+
loop.create_task(self.disconnect())
1359+
self._redis_client = None
13391360

13401361
async def __aenter__(self):
13411362
return self
13421363

13431364
async def __aexit__(self, exc_type, exc_val, exc_tb):
13441365
await self.disconnect()
1366+
1367+
def _finalize_disconnect(self):
1368+
try:
1369+
loop = asyncio.get_running_loop()
1370+
except RuntimeError:
1371+
loop = None
1372+
if loop is None or not loop.is_running():
1373+
asyncio.run(self.disconnect())
1374+
else:
1375+
loop.create_task(self.disconnect())

tests/integration/test_async_search_index.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,3 +313,33 @@ async def test_check_index_exists_before_info(async_index):
313313

314314
with pytest.raises(RedisSearchError):
315315
await async_index.info()
316+
317+
318+
@pytest.mark.asyncio
319+
async def test_search_index_async_context_manager(async_index):
320+
async with async_index:
321+
await async_index.create(overwrite=True, drop=True)
322+
assert async_index._redis_client
323+
assert async_index._redis_client is None
324+
325+
326+
@pytest.mark.asyncio
327+
async def test_search_index_context_manager_with_exception(async_index):
328+
async with async_index:
329+
await async_index.create(overwrite=True, drop=True)
330+
raise ValueError("test")
331+
assert async_index._redis_client is None
332+
333+
334+
@pytest.mark.asyncio
335+
async def test_search_index_disconnect(async_index):
336+
await async_index.create(overwrite=True, drop=True)
337+
await async_index.disconnect()
338+
assert async_index._redis_client is None
339+
340+
341+
@pytest.mark.asyncio
342+
async def test_search_engine_disconnect_sync(async_index):
343+
await async_index.create(overwrite=True, drop=True)
344+
async_index.disconnect_sync()
345+
assert async_index._redis_client is None

tests/integration/test_llmcache.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -941,3 +941,36 @@ def test_deprecated_dtype_argument(redis_url):
941941
redis_url=redis_url,
942942
overwrite=True,
943943
)
944+
945+
946+
947+
@pytest.mark.asyncio
948+
async def test_cache_async_context_manager(redis_url):
949+
async with SemanticCache(name="test_cache", redis_url=redis_url) as cache:
950+
await cache.astore("test prompt", "test response")
951+
assert cache._aindex
952+
assert cache._aindex is None
953+
954+
955+
@pytest.mark.asyncio
956+
async def test_cache_async_context_manager_with_exception(redis_url):
957+
async with SemanticCache(name="test_cache", redis_url=redis_url) as cache:
958+
await cache.astore("test prompt", "test response")
959+
raise ValueError("test")
960+
assert cache._aindex is None
961+
962+
963+
@pytest.mark.asyncio
964+
async def test_cache_async_disconnect(redis_url):
965+
cache = SemanticCache(name="test_cache", redis_url=redis_url)
966+
await cache.astore("test prompt", "test response")
967+
await cache.adisconnect()
968+
assert cache._aindex is None
969+
970+
971+
def test_cache_disconnect(redis_url):
972+
cache = SemanticCache(name="test_cache", redis_url=redis_url)
973+
cache.store("test prompt", "test response")
974+
cache.disconnect()
975+
# We keep this index object around because it isn't lazily created
976+
assert cache._index.client is None

tests/integration/test_search_index.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,3 +287,24 @@ def test_check_index_exists_before_info(index):
287287
def test_index_needs_valid_schema():
288288
with pytest.raises(ValueError, match=r"Must provide a valid IndexSchema object"):
289289
SearchIndex(schema="Not A Valid Schema") # type: ignore
290+
291+
292+
def test_search_index_context_manager(index):
293+
with index:
294+
index.create(overwrite=True, drop=True)
295+
assert index.client
296+
assert index.client is None
297+
298+
299+
def test_search_index_context_manager_with_exception(index):
300+
with pytest.raises(ValueError):
301+
with index:
302+
index.create(overwrite=True, drop=True)
303+
raise ValueError("test")
304+
assert index.client is None
305+
306+
307+
def test_search_index_disconnect(index):
308+
index.create(overwrite=True, drop=True)
309+
index.disconnect()
310+
assert index.client is None

0 commit comments

Comments
 (0)