|
16 | 16 | Optional, |
17 | 17 | Union, |
18 | 18 | ) |
| 19 | +import weakref |
19 | 20 |
|
20 | 21 | from redisvl.utils.utils import deprecated_argument, deprecated_function |
21 | 22 |
|
@@ -784,6 +785,12 @@ def info(self, name: Optional[str] = None) -> Dict[str, Any]: |
784 | 785 | index_name = name or self.schema.index.name |
785 | 786 | return self._info(index_name, self._redis_client) # type: ignore |
786 | 787 |
|
| 788 | + def __enter__(self): |
| 789 | + return self |
| 790 | + |
| 791 | + def __exit__(self, exc_type, exc_val, exc_tb): |
| 792 | + self.disconnect() |
| 793 | + |
787 | 794 |
|
788 | 795 | class AsyncSearchIndex(BaseSearchIndex): |
789 | 796 | """A search index class for interacting with Redis as a vector database in |
@@ -857,11 +864,8 @@ def __init__( |
857 | 864 | self._connection_kwargs = connection_kwargs or {} |
858 | 865 | self._lock = asyncio.Lock() |
859 | 866 |
|
860 | | - async def disconnect(self): |
861 | | - """Asynchronously disconnect and cleanup the underlying async redis connection.""" |
862 | | - if self._redis_client is not None: |
863 | | - await self._redis_client.aclose() # type: ignore |
864 | | - self._redis_client = None |
| 867 | + # Close connections when the object is garbage collected |
| 868 | + weakref.finalize(self, self._finalize_disconnect) |
865 | 869 |
|
866 | 870 | @classmethod |
867 | 871 | async def from_existing( |
@@ -1336,9 +1340,36 @@ async def _info(name: str, redis_client: aredis.Redis) -> Dict[str, Any]: |
1336 | 1340 | raise RedisSearchError( |
1337 | 1341 | f"Error while fetching {name} index info: {str(e)}" |
1338 | 1342 | ) from e |
| 1343 | + |
| 1344 | + async def disconnect(self): |
| 1345 | + """Asynchronously disconnect and cleanup the underlying async redis connection.""" |
| 1346 | + if self._redis_client is not None: |
| 1347 | + await self._redis_client.aclose() # type: ignore |
| 1348 | + self._redis_client = None |
| 1349 | + |
| 1350 | + def disconnect_sync(self): |
| 1351 | + """Synchronously disconnect and cleanup the underlying async redis connection.""" |
| 1352 | + if self._redis_client is None: |
| 1353 | + return |
| 1354 | + loop = asyncio.get_running_loop() |
| 1355 | + if loop is None or not loop.is_running(): |
| 1356 | + asyncio.run(self._redis_client.aclose()) # type: ignore |
| 1357 | + else: |
| 1358 | + loop.create_task(self.disconnect()) |
| 1359 | + self._redis_client = None |
1339 | 1360 |
|
1340 | 1361 | async def __aenter__(self): |
1341 | 1362 | return self |
1342 | 1363 |
|
1343 | 1364 | async def __aexit__(self, exc_type, exc_val, exc_tb): |
1344 | 1365 | await self.disconnect() |
| 1366 | + |
| 1367 | + def _finalize_disconnect(self): |
| 1368 | + try: |
| 1369 | + loop = asyncio.get_running_loop() |
| 1370 | + except RuntimeError: |
| 1371 | + loop = None |
| 1372 | + if loop is None or not loop.is_running(): |
| 1373 | + asyncio.run(self.disconnect()) |
| 1374 | + else: |
| 1375 | + loop.create_task(self.disconnect()) |
0 commit comments