Commit e9470d3

Rejigger the finalizers and more
1 parent 3eb1d43 commit e9470d3

File tree

6 files changed: +300 -161 lines changed

redisvl/extensions/llmcache/semantic.py

Lines changed: 19 additions & 23 deletions
@@ -25,6 +25,7 @@
 from redisvl.redis.connection import RedisConnectionFactory
 from redisvl.utils.log import get_logger
 from redisvl.utils.utils import (
+    sync_wrapper,
     current_timestamp,
     deprecated_argument,
     serialize,
@@ -133,6 +134,12 @@ def __init__(
             name, prefix, vectorizer.dims, vectorizer.dtype  # type: ignore
         )
         schema = self._modify_schema(schema, filterable_fields)
+
+        if redis_client:
+            self._owns_redis_client = False
+        else:
+            self._owns_redis_client = True
+
         self._index = SearchIndex(
             schema=schema,
             redis_client=redis_client,
@@ -153,8 +160,6 @@ def __init__(

         # Create the search index in Redis
         self._index.create(overwrite=overwrite, drop=False)
-
-        weakref.finalize(self, self._finalize_async)

     def _modify_schema(
         self,
@@ -317,7 +322,9 @@ async def _avectorize_prompt(self, prompt: Optional[str]) -> List[float]:
     def _check_vector_dims(self, vector: List[float]):
         """Checks the size of the provided vector and raises an error if it
         doesn't match the search index vector dimensions."""
-        schema_vector_dims = self._index.schema.fields[CACHE_VECTOR_FIELD_NAME].attrs.dims  # type: ignore
+        schema_vector_dims = self._index.schema.fields[
+            CACHE_VECTOR_FIELD_NAME
+        ].attrs.dims  # type: ignore
         validate_vector_dims(len(vector), schema_vector_dims)

     def check(
@@ -392,7 +399,8 @@ def check(
         # Search the cache!
         cache_search_results = self._index.query(query)
         redis_keys, cache_hits = self._process_cache_results(
-            cache_search_results, return_fields  # type: ignore
+            cache_search_results,
+            return_fields,  # type: ignore
         )
         # Extend TTL on keys
         for key in redis_keys:
@@ -473,7 +481,8 @@ async def acheck(
         # Search the cache!
         cache_search_results = await aindex.query(query)
         redis_keys, cache_hits = self._process_cache_results(
-            cache_search_results, return_fields  # type: ignore
+            cache_search_results,
+            return_fields,  # type: ignore
         )
         # Extend TTL on keys
         await asyncio.gather(*[self._async_refresh_ttl(key) for key in redis_keys])
@@ -646,7 +655,6 @@ def update(self, key: str, **kwargs) -> None:
         """
         if kwargs:
             for k, v in kwargs.items():
-
                 # Make sure the item is in the index schema
                 if k not in set(self._index.schema.field_names + [METADATA_FIELD_NAME]):
                     raise ValueError(f"{k} is not a valid field within the cache entry")
@@ -689,7 +697,6 @@ async def aupdate(self, key: str, **kwargs) -> None:

         if kwargs:
             for k, v in kwargs.items():
-
                 # Make sure the item is in the index schema
                 if k not in set(self._index.schema.field_names + [METADATA_FIELD_NAME]):
                     raise ValueError(f"{k} is not a valid field within the cache entry")
@@ -708,29 +715,18 @@ async def aupdate(self, key: str, **kwargs) -> None:
         await aindex.load(data=[kwargs], keys=[key])

         await self._async_refresh_ttl(key)
-
-    def _finalize_async(self):
-        if self._index:
-            self._index.disconnect()
-        if self._aindex:
-            try:
-                loop = None
-                try:
-                    loop = asyncio.get_running_loop()
-                except RuntimeError:
-                    loop = asyncio.new_event_loop()
-                    asyncio.set_event_loop(loop)
-                loop.run_until_complete(self._aindex.disconnect())
-            except Exception as e:
-                logger.info(f"Error disconnecting from index: {e}")

     def disconnect(self):
+        if self._owns_redis_client is False:
+            return
         if self._index:
             self._index.disconnect()
         if self._aindex:
-            asyncio.run(self._aindex.disconnect())
+            self._aindex.disconnect_sync()

     async def adisconnect(self):
+        if not self._owns_redis_client:
+            return
         if self._index:
             self._index.disconnect()
         if self._aindex:
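
Taken together, these changes make SemanticCache clean up its Redis connection only when it created that connection itself. A minimal usage sketch, assuming the constructor arguments and disconnect() behavior shown in this diff and a Redis instance at localhost:6379:

from redis import Redis
from redisvl.extensions.llmcache import SemanticCache

# Caller-owned client: the cache records _owns_redis_client = False, so
# disconnect()/adisconnect() return early and leave the connection open.
client = Redis.from_url("redis://localhost:6379")
cache = SemanticCache(name="llmcache", redis_client=client)
cache.disconnect()   # no-op; the caller is still responsible for `client`
client.close()

# Cache-owned client: built from a URL, so disconnect() really closes it.
cache = SemanticCache(name="llmcache", redis_url="redis://localhost:6379")
cache.disconnect()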

redisvl/index/index.py

Lines changed: 31 additions & 38 deletions
@@ -1,9 +1,12 @@
 import asyncio
 import json
+import logging
 from os import replace
+from re import S
 import threading
 import warnings
 from functools import wraps
+
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -18,14 +21,13 @@
 )
 import weakref

-from redisvl.utils.utils import deprecated_argument, deprecated_function
+from redisvl.utils.utils import deprecated_argument, deprecated_function, sync_wrapper

 if TYPE_CHECKING:
     from redis.commands.search.aggregation import AggregateResult
     from redis.commands.search.document import Document
     from redis.commands.search.result import Result
     from redisvl.query.query import BaseQuery
-    import redis.asyncio

 import redis
 import redis.asyncio as aredis
@@ -38,7 +40,6 @@
 from redisvl.redis.connection import (
     RedisConnectionFactory,
     convert_index_info_to_schema,
-    validate_modules,
 )
 from redisvl.redis.utils import convert_bytes
 from redisvl.schema import IndexSchema, StorageType
@@ -279,14 +280,21 @@ def __init__(

         self._lib_name: Optional[str] = kwargs.pop("lib_name", None)

-        # Store connection parameters
+        # Store connection parameters
         self.__redis_client = redis_client
         self._redis_url = redis_url
         self._connection_kwargs = connection_kwargs or {}
-        self._lock = threading.Lock()
+        self._lock = threading.Lock()
+
+        self._owns_redis_client = redis_client is None
+        if self._owns_redis_client:
+            weakref.finalize(self, self.disconnect)

     def disconnect(self):
         """Disconnect from the Redis database."""
+        if self._owns_redis_client is False:
+            print("Index does not own client, not disconnecting")
+            return
         if self.__redis_client:
             self.__redis_client.close()
             self.__redis_client = None
@@ -343,12 +351,12 @@ def from_existing(
     def client(self) -> Optional[redis.Redis]:
         """The underlying redis-py client object."""
         return self.__redis_client
-
+
     @property
     def _redis_client(self) -> Optional[redis.Redis]:
         """
         Get a Redis client instance.
-
+
         Lazily creates a Redis client instance if it doesn't exist.
         """
         if self.__redis_client is None:
@@ -359,7 +367,6 @@ def _redis_client(self) -> Optional[redis.Redis]:
                 **self._connection_kwargs,
             )
         return self.__redis_client
-

     @deprecated_function("connect", "Pass connection parameters in __init__.")
     def connect(self, redis_url: Optional[str] = None, **kwargs):
@@ -371,8 +378,7 @@ def connect(self, redis_url: Optional[str] = None, **kwargs):

         Args:
             redis_url (Optional[str], optional): The URL of the Redis server to
-                connect to. If not provided, the method defaults to using the
-                `REDIS_URL` environment variable.
+                connect to.

         Raises:
             redis.exceptions.ConnectionError: If the connection to the Redis
@@ -842,9 +848,9 @@ def __init__(
             schema (IndexSchema): Index schema object.
             redis_url (Optional[str], optional): The URL of the Redis server to
                 connect to.
-            redis_client (Optional[aredis.Redis], optional): An
+            redis_client (Optional[aredis.Redis]): An
                 instantiated redis client.
-            connection_kwargs (Dict[str, Any], optional): Redis client connection
+            connection_kwargs (Optional[Dict[str, Any]]): Redis client connection
                 args.
         """
         if "redis_kwargs" in kwargs:
@@ -864,8 +870,9 @@ def __init__(
         self._connection_kwargs = connection_kwargs or {}
         self._lock = asyncio.Lock()

-        # Close connections when the object is garbage collected
-        weakref.finalize(self, self._finalize_disconnect)
+        self._owns_redis_client = redis_client is None
+        if self._owns_redis_client:
+            weakref.finalize(self, sync_wrapper(self.disconnect))

     @classmethod
     async def from_existing(
@@ -934,7 +941,7 @@ async def connect(self, redis_url: Optional[str] = None, **kwargs):
         await self.set_client(client)

     @deprecated_function("set_client", "Pass connection parameters in __init__.")
-    async def set_client(self, redis_client: aredis.Redis):
+    async def set_client(self, redis_client: Union[aredis.Redis, redis.Redis]):
         """
         [DEPRECATED] Manually set the Redis client to use with the search index.
         This method is deprecated; please provide connection parameters in __init__.
@@ -956,16 +963,17 @@ async def _get_client(self) -> aredis.Redis:
                 kwargs["url"] = self._redis_url
             self._redis_client = (
                 await RedisConnectionFactory._get_aredis_connection(
-                    required_modules=self.required_modules,
-                    **kwargs
+                    required_modules=self.required_modules, **kwargs
                 )
             )
             await RedisConnectionFactory.validate_async_redis(
                 self._redis_client, self._lib_name
             )
         return self._redis_client

-    async def _validate_client(self, redis_client: aredis.Redis) -> aredis.Redis:
+    async def _validate_client(
+        self, redis_client: Union[aredis.Redis, redis.Redis]
+    ) -> aredis.Redis:
         if isinstance(redis_client, redis.Redis):
             warnings.warn(
                 "Converting sync Redis client to async client is deprecated "
@@ -1340,36 +1348,21 @@ async def _info(name: str, redis_client: aredis.Redis) -> Dict[str, Any]:
             raise RedisSearchError(
                 f"Error while fetching {name} index info: {str(e)}"
             ) from e
-
+
     async def disconnect(self):
-        """Asynchronously disconnect and cleanup the underlying async redis connection."""
+        if self._owns_redis_client is False:
+            return
         if self._redis_client is not None:
             await self._redis_client.aclose()  # type: ignore
             self._redis_client = None

     def disconnect_sync(self):
-        """Synchronously disconnect and cleanup the underlying async redis connection."""
-        if self._redis_client is None:
+        if self._redis_client is None or self._owns_redis_client is False:
             return
-        loop = asyncio.get_running_loop()
-        if loop is None or not loop.is_running():
-            asyncio.run(self._redis_client.aclose())  # type: ignore
-        else:
-            loop.create_task(self.disconnect())
-        self._redis_client = None
+        sync_wrapper(self.disconnect)()

     async def __aenter__(self):
         return self

     async def __aexit__(self, exc_type, exc_val, exc_tb):
         await self.disconnect()
-
-    def _finalize_disconnect(self):
-        try:
-            loop = asyncio.get_running_loop()
-        except RuntimeError:
-            loop = None
-        if loop is None or not loop.is_running():
-            asyncio.run(self.disconnect())
-        else:
-            loop.create_task(self.disconnect())
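
The index classes follow the same rule: a weakref.finalize callback is registered only when the index created its own client, and disconnect() becomes a no-op for borrowed clients. A standalone sketch of that ownership-gated finalizer pattern (the class name and the client stand-in below are invented for illustration; only the _owns_redis_client idea comes from this diff):

import weakref

class OwnedConnectionExample:
    def __init__(self, client=None):
        # Borrow the caller's client, or create one we are responsible for.
        self._client = client if client is not None else object()
        self._owns_redis_client = client is None
        if self._owns_redis_client:
            # Only clean up connections this object created itself.
            weakref.finalize(self, self.disconnect)

    def disconnect(self):
        if self._owns_redis_client is False:
            return                 # borrowed client: the caller closes it
        if self._client is not None:
            print("closing owned client")
            self._client = None

owned = OwnedConnectionExample()             # finalizer registered
borrowed = OwnedConnectionExample(object())  # no finalizer; disconnect() is a no-op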

redisvl/utils/utils.py

Lines changed: 36 additions & 2 deletions
@@ -1,19 +1,23 @@
+import asyncio
 import inspect
 import json
+import logging
 import warnings
 from contextlib import contextmanager
 from enum import Enum
 from functools import wraps
 from time import time
-from typing import Any, Callable, Dict, Optional
+from typing import Any, Callable, Coroutine, Dict, Optional
 from warnings import warn

 from pydantic import BaseModel
 from ulid import ULID

+from redisvl.utils.log import get_logger
+

 def create_ulid() -> str:
-    """Generate a unique indentifier to group related Redis documents."""
+    """Generate a unique identifier to group related Redis documents."""
     return str(ULID())


@@ -159,3 +163,33 @@ def wrapper(*args, **kwargs):
         return wrapper

     return decorator
+
+
+def sync_wrapper(fn: Callable[[], Coroutine[Any, Any, Any]]) -> Callable[[], None]:
+    def wrapper():
+        try:
+            loop = asyncio.get_running_loop()
+        except RuntimeError:
+            loop = None
+        try:
+            if loop is None or not loop.is_running():
+                loop = asyncio.new_event_loop()
+                asyncio.set_event_loop(loop)
+            task = loop.create_task(fn())
+            loop.run_until_complete(task)
+        except RuntimeError:
+            # This could happen if an object stored an event loop and now
+            # that event loop is closed. There's nothing we can do other than
+            # advise the user to use explicit cleanup methods.
+            #
+            # Uses logging module instead of get_logger() to avoid I/O errors
+            # if the wrapped function is called as a finalizer.
+            logging.info(
+                f"Could not run the async function {fn.__name__} because the event loop is closed. "
+                "This usually means the object was not properly cleaned up. Please use explicit "
+                "cleanup methods (e.g., disconnect(), close()) or use the object as an async "
+                "context manager.",
+            )
+            return
+
+    return wrapper
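
sync_wrapper is the piece that lets an async cleanup coroutine run from synchronous code; it is what the async index above hands to weakref.finalize. A small usage sketch (the Connection class and aclose() are invented for illustration; only sync_wrapper comes from this commit):

import asyncio

from redisvl.utils.utils import sync_wrapper

class Connection:
    async def aclose(self):
        await asyncio.sleep(0)   # stand-in for releasing async resources
        print("connection closed")

conn = Connection()

# sync_wrapper turns the zero-argument coroutine function into a plain
# callable: it creates (or reuses) an event loop and runs the coroutine
# to completion, logging instead of raising if the loop is already closed.
close = sync_wrapper(conn.aclose)
close()   # prints "connection closed"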
