11import asyncio
22import json
3+ import logging
34from os import replace
5+ from re import S
46import threading
57import warnings
68from functools import wraps
9+
710from typing import (
811 TYPE_CHECKING ,
912 Any ,
1821)
1922import weakref
2023
21- from redisvl .utils .utils import deprecated_argument , deprecated_function
24+ from redisvl .utils .utils import deprecated_argument , deprecated_function , sync_wrapper
2225
2326if TYPE_CHECKING :
2427 from redis .commands .search .aggregation import AggregateResult
2528 from redis .commands .search .document import Document
2629 from redis .commands .search .result import Result
2730 from redisvl .query .query import BaseQuery
28- import redis .asyncio
2931
3032import redis
3133import redis .asyncio as aredis
3840from redisvl .redis .connection import (
3941 RedisConnectionFactory ,
4042 convert_index_info_to_schema ,
41- validate_modules ,
4243)
4344from redisvl .redis .utils import convert_bytes
4445from redisvl .schema import IndexSchema , StorageType
@@ -279,14 +280,21 @@ def __init__(
279280
280281 self ._lib_name : Optional [str ] = kwargs .pop ("lib_name" , None )
281282
282- # Store connection parameters
283+ # Store connection parameters
283284 self .__redis_client = redis_client
284285 self ._redis_url = redis_url
285286 self ._connection_kwargs = connection_kwargs or {}
286- self ._lock = threading .Lock ()
287+ self ._lock = threading .Lock ()
288+
289+ self ._owns_redis_client = redis_client is None
290+ if self ._owns_redis_client :
291+ weakref .finalize (self , self .disconnect )
287292
288293 def disconnect (self ):
289294 """Disconnect from the Redis database."""
295+ if self ._owns_redis_client is False :
296+ print ("Index does not own client, not disconnecting" )
297+ return
290298 if self .__redis_client :
291299 self .__redis_client .close ()
292300 self .__redis_client = None
@@ -343,12 +351,12 @@ def from_existing(
343351 def client (self ) -> Optional [redis .Redis ]:
344352 """The underlying redis-py client object."""
345353 return self .__redis_client
346-
354+
347355 @property
348356 def _redis_client (self ) -> Optional [redis .Redis ]:
349357 """
350358 Get a Redis client instance.
351-
359+
352360 Lazily creates a Redis client instance if it doesn't exist.
353361 """
354362 if self .__redis_client is None :
@@ -359,7 +367,6 @@ def _redis_client(self) -> Optional[redis.Redis]:
359367 ** self ._connection_kwargs ,
360368 )
361369 return self .__redis_client
362-
363370
364371 @deprecated_function ("connect" , "Pass connection parameters in __init__." )
365372 def connect (self , redis_url : Optional [str ] = None , ** kwargs ):
@@ -371,8 +378,7 @@ def connect(self, redis_url: Optional[str] = None, **kwargs):
371378
372379 Args:
373380 redis_url (Optional[str], optional): The URL of the Redis server to
374- connect to. If not provided, the method defaults to using the
375- `REDIS_URL` environment variable.
381+ connect to.
376382
377383 Raises:
378384 redis.exceptions.ConnectionError: If the connection to the Redis
@@ -842,9 +848,9 @@ def __init__(
842848 schema (IndexSchema): Index schema object.
843849 redis_url (Optional[str], optional): The URL of the Redis server to
844850 connect to.
845- redis_client (Optional[aredis.Redis], optional ): An
851+ redis_client (Optional[aredis.Redis]): An
846852 instantiated redis client.
847- connection_kwargs (Dict[str, Any], optional ): Redis client connection
853+ connection_kwargs (Optional[ Dict[str, Any]] ): Redis client connection
848854 args.
849855 """
850856 if "redis_kwargs" in kwargs :
@@ -864,8 +870,9 @@ def __init__(
864870 self ._connection_kwargs = connection_kwargs or {}
865871 self ._lock = asyncio .Lock ()
866872
867- # Close connections when the object is garbage collected
868- weakref .finalize (self , self ._finalize_disconnect )
873+ self ._owns_redis_client = redis_client is None
874+ if self ._owns_redis_client :
875+ weakref .finalize (self , sync_wrapper (self .disconnect ))
869876
870877 @classmethod
871878 async def from_existing (
@@ -934,7 +941,7 @@ async def connect(self, redis_url: Optional[str] = None, **kwargs):
934941 await self .set_client (client )
935942
936943 @deprecated_function ("set_client" , "Pass connection parameters in __init__." )
937- async def set_client (self , redis_client : aredis .Redis ):
944+ async def set_client (self , redis_client : Union [ aredis .Redis , redis . Redis ] ):
938945 """
939946 [DEPRECATED] Manually set the Redis client to use with the search index.
940947 This method is deprecated; please provide connection parameters in __init__.
@@ -956,16 +963,17 @@ async def _get_client(self) -> aredis.Redis:
956963 kwargs ["url" ] = self ._redis_url
957964 self ._redis_client = (
958965 await RedisConnectionFactory ._get_aredis_connection (
959- required_modules = self .required_modules ,
960- ** kwargs
966+ required_modules = self .required_modules , ** kwargs
961967 )
962968 )
963969 await RedisConnectionFactory .validate_async_redis (
964970 self ._redis_client , self ._lib_name
965971 )
966972 return self ._redis_client
967973
968- async def _validate_client (self , redis_client : aredis .Redis ) -> aredis .Redis :
974+ async def _validate_client (
975+ self , redis_client : Union [aredis .Redis , redis .Redis ]
976+ ) -> aredis .Redis :
969977 if isinstance (redis_client , redis .Redis ):
970978 warnings .warn (
971979 "Converting sync Redis client to async client is deprecated "
@@ -1340,36 +1348,21 @@ async def _info(name: str, redis_client: aredis.Redis) -> Dict[str, Any]:
13401348 raise RedisSearchError (
13411349 f"Error while fetching { name } index info: { str (e )} "
13421350 ) from e
1343-
1351+
13441352 async def disconnect (self ):
1345- """Asynchronously disconnect and cleanup the underlying async redis connection."""
1353+ if self ._owns_redis_client is False :
1354+ return
13461355 if self ._redis_client is not None :
13471356 await self ._redis_client .aclose () # type: ignore
13481357 self ._redis_client = None
13491358
13501359 def disconnect_sync (self ):
1351- """Synchronously disconnect and cleanup the underlying async redis connection."""
1352- if self ._redis_client is None :
1360+ if self ._redis_client is None or self ._owns_redis_client is False :
13531361 return
1354- loop = asyncio .get_running_loop ()
1355- if loop is None or not loop .is_running ():
1356- asyncio .run (self ._redis_client .aclose ()) # type: ignore
1357- else :
1358- loop .create_task (self .disconnect ())
1359- self ._redis_client = None
1362+ sync_wrapper (self .disconnect )()
13601363
13611364 async def __aenter__ (self ):
13621365 return self
13631366
13641367 async def __aexit__ (self , exc_type , exc_val , exc_tb ):
13651368 await self .disconnect ()
1366-
1367- def _finalize_disconnect (self ):
1368- try :
1369- loop = asyncio .get_running_loop ()
1370- except RuntimeError :
1371- loop = None
1372- if loop is None or not loop .is_running ():
1373- asyncio .run (self .disconnect ())
1374- else :
1375- loop .create_task (self .disconnect ())
0 commit comments