Skip to content

Commit a11ceda

Browse files
unleash the caching on the subclass vectorizers
1 parent 2e8b167 commit a11ceda

File tree

19 files changed

+1879
-1344
lines changed

19 files changed

+1879
-1344
lines changed

poetry.lock

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

redisvl/extensions/cache/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,5 @@
66
"""
77

88
from redisvl.extensions.cache.base import BaseCache
9-
from redisvl.extensions.cache.embeddings import EmbeddingsCache
10-
from redisvl.extensions.cache.llm import SemanticCache
119

12-
__all__ = ["BaseCache", "EmbeddingsCache", "SemanticCache"]
10+
__all__ = ["BaseCache"]

redisvl/extensions/cache/base.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from redis import Redis
1010
from redis.asyncio import Redis as AsyncRedis
1111

12+
from redisvl.redis.connection import RedisConnectionFactory
13+
1214

1315
class BaseCache:
1416
"""Base abstract cache interface for all RedisVL caches.
@@ -121,10 +123,15 @@ async def _get_async_redis_client(self) -> AsyncRedis:
121123
AsyncRedis: An async Redis client instance.
122124
"""
123125
if not hasattr(self, "_async_redis_client") or self._async_redis_client is None:
124-
# Create new async Redis client
125-
url = self.redis_kwargs["redis_url"]
126-
kwargs = self.redis_kwargs["connection_kwargs"]
127-
self._async_redis_client = AsyncRedis.from_url(url, **kwargs) # type: ignore
126+
client = self.redis_kwargs.get("redis_client")
127+
if isinstance(client, Redis):
128+
self._async_redis_client = RedisConnectionFactory.sync_to_async_redis(
129+
client
130+
)
131+
else:
132+
url = self.redis_kwargs["redis_url"]
133+
kwargs = self.redis_kwargs["connection_kwargs"]
134+
self._async_redis_client = RedisConnectionFactory.get_async_redis_connection(url, **kwargs) # type: ignore
128135
return self._async_redis_client
129136

130137
def expire(self, key: str, ttl: Optional[int] = None) -> None:

redisvl/extensions/cache/llm/semantic.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from redisvl.index import AsyncSearchIndex, SearchIndex
2323
from redisvl.query import VectorRangeQuery
2424
from redisvl.query.filter import FilterExpression
25-
from redisvl.redis.connection import RedisConnectionFactory
2625
from redisvl.redis.utils import hashify
2726
from redisvl.utils.log import get_logger
2827
from redisvl.utils.utils import (
@@ -31,7 +30,7 @@
3130
serialize,
3231
validate_vector_dims,
3332
)
34-
from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer
33+
from redisvl.utils.vectorize.base import BaseVectorizer
3534

3635
logger = get_logger("[RedisVL]")
3736

@@ -105,6 +104,8 @@ def __init__(
105104
)
106105
self._vectorizer = vectorizer
107106
else:
107+
from redisvl.utils.vectorize.text.huggingface import HFTextVectorizer
108+
108109
# Create a default vectorizer
109110
vectorizer_kwargs = kwargs
110111
if dtype:
@@ -183,11 +184,8 @@ def _modify_schema(
183184
async def _get_async_index(self) -> AsyncSearchIndex:
184185
"""Lazily construct the async search index class."""
185186
# Construct async index if necessary
186-
async_client = None
187187
if self._aindex is None:
188-
client = self.redis_kwargs.get("redis_client")
189-
if isinstance(client, Redis):
190-
async_client = RedisConnectionFactory.sync_to_async_redis(client)
188+
async_client = await self._get_async_redis_client()
191189
self._aindex = AsyncSearchIndex(
192190
schema=self._index.schema,
193191
redis_client=async_client,

redisvl/extensions/router/schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ class RoutingConfig(BaseModel):
6262
"""Configuration for routing behavior."""
6363

6464
"""The maximum number of top matches to return."""
65-
max_k: Annotated[int, Field(strict=True, default=1, gt=0)] = 1
65+
max_k: Annotated[int, Field(strict=True, gt=0)] = 1
6666
"""Aggregation method to use to classify queries."""
6767
aggregation_method: DistanceAggregationMethod = Field(
6868
default=DistanceAggregationMethod.avg

redisvl/extensions/router/semantic.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,8 @@
2121
from redisvl.redis.utils import convert_bytes, hashify, make_dict
2222
from redisvl.utils.log import get_logger
2323
from redisvl.utils.utils import deprecated_argument, model_to_dict
24-
from redisvl.utils.vectorize import (
25-
BaseVectorizer,
26-
HFTextVectorizer,
27-
vectorizer_from_dict,
28-
)
24+
from redisvl.utils.vectorize.base import BaseVectorizer
25+
from redisvl.utils.vectorize.text.huggingface import HFTextVectorizer
2926

3027
logger = get_logger(__name__)
3128

@@ -483,6 +480,8 @@ def from_dict(
483480
}
484481
router = SemanticRouter.from_dict(router_data)
485482
"""
483+
from redisvl.utils.vectorize import vectorizer_from_dict
484+
486485
try:
487486
name = data["name"]
488487
routes_data = data["routes"]

redisvl/extensions/session_manager/semantic_session.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from redisvl.query import FilterQuery, RangeQuery
2121
from redisvl.query.filter import Tag
2222
from redisvl.utils.utils import deprecated_argument, validate_vector_dims
23-
from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer
23+
from redisvl.utils.vectorize.base import BaseVectorizer
2424

2525

2626
class SemanticSessionManager(BaseSessionManager):
@@ -82,6 +82,8 @@ def __init__(
8282
f"Provided dtype {dtype} does not match vectorizer dtype {vectorizer.dtype}"
8383
)
8484
else:
85+
from redisvl.utils.vectorize.text.huggingface import HFTextVectorizer
86+
8587
vectorizer_kwargs = kwargs
8688

8789
if dtype:

0 commit comments

Comments
 (0)