Skip to content

Commit e1704e1

Browse files
committed
Deprecate dtype arg to SemanticCache, SemanticRouter
Instead of using `dtype`, pass a vectorizer instance to the constructor to specify the vectorizer's dtype.
1 parent 9865faa commit e1704e1

File tree

5 files changed

+151
-20
lines changed

5 files changed

+151
-20
lines changed

redisvl/extensions/llmcache/semantic.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from redisvl.index import AsyncSearchIndex, SearchIndex
2323
from redisvl.query import RangeQuery
2424
from redisvl.query.filter import FilterExpression
25-
from redisvl.utils.utils import current_timestamp, serialize, validate_vector_dims
25+
from redisvl.utils.utils import current_timestamp, deprecated_argument, serialize, validate_vector_dims
2626
from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer
2727

2828

@@ -32,6 +32,7 @@ class SemanticCache(BaseLLMCache):
3232
_index: SearchIndex
3333
_aindex: Optional[AsyncSearchIndex] = None
3434

35+
@deprecated_argument("dtype", "vectorizer")
3536
def __init__(
3637
self,
3738
name: str = "llmcache",

redisvl/extensions/router/semantic.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from redisvl.query import RangeQuery
2121
from redisvl.redis.utils import convert_bytes, hashify, make_dict
2222
from redisvl.utils.log import get_logger
23-
from redisvl.utils.utils import model_to_dict
23+
from redisvl.utils.utils import deprecated_argument, model_to_dict
2424
from redisvl.utils.vectorize import (
2525
BaseVectorizer,
2626
HFTextVectorizer,
@@ -47,6 +47,7 @@ class SemanticRouter(BaseModel):
4747
class Config:
4848
arbitrary_types_allowed = True
4949

50+
@deprecated_argument("dtype", "vectorizer")
5051
def __init__(
5152
self,
5253
name: str,
@@ -72,9 +73,17 @@ def __init__(
7273
connection_kwargs (Dict[str, Any]): The connection arguments
7374
for the redis client. Defaults to empty {}.
7475
"""
75-
# Set vectorizer default
76-
if vectorizer is None:
77-
dtype = kwargs.get("dtype")
76+
dtype = kwargs.get("dtype")
77+
78+
# Validate a provided vectorizer or set the default
79+
if vectorizer:
80+
if not isinstance(vectorizer, BaseVectorizer):
81+
raise TypeError("Must provide a valid redisvl.vectorizer class.")
82+
if dtype and vectorizer.dtype != dtype:
83+
raise ValueError(
84+
f"Provided dtype {dtype} does not match vectorizer dtype {vectorizer.dtype}"
85+
)
86+
else:
7887
vectorizer_kwargs = {"dtype": dtype} if dtype else {}
7988
vectorizer = HFTextVectorizer(**vectorizer_kwargs)
8089

@@ -87,10 +96,9 @@ def __init__(
8796
vectorizer=vectorizer,
8897
routing_config=routing_config,
8998
)
90-
self._initialize_index(
91-
redis_client, redis_url, overwrite, vectorizer.dtype, **connection_kwargs
92-
)
99+
self._initialize_index(redis_client, redis_url, overwrite, **connection_kwargs)
93100

101+
@deprecated_argument("dtype")
94102
def _initialize_index(
95103
self,
96104
redis_client: Optional[Redis] = None,
@@ -101,7 +109,7 @@ def _initialize_index(
101109
):
102110
"""Initialize the search index and handle Redis connection."""
103111
schema = SemanticRouterIndexSchema.from_params(
104-
self.name, self.vectorizer.dims, dtype
112+
self.name, self.vectorizer.dims, self.vectorizer.dtype
105113
)
106114
self._index = SearchIndex(schema=schema)
107115

@@ -170,9 +178,7 @@ def _add_routes(self, routes: List[Route]):
170178
for route in routes:
171179
# embed route references as a single batch
172180
reference_vectors = self.vectorizer.embed_many(
173-
[reference for reference in route.references],
174-
as_buffer=True,
175-
dtype=self._index.schema.fields[ROUTE_VECTOR_FIELD_NAME].attrs.datatype, # type: ignore[union-attr]
181+
[reference for reference in route.references], as_buffer=True
176182
)
177183
# set route references
178184
for i, reference in enumerate(route.references):
@@ -249,7 +255,6 @@ def _classify_route(
249255
vector_field_name=ROUTE_VECTOR_FIELD_NAME,
250256
distance_threshold=distance_threshold,
251257
return_fields=["route_name"],
252-
dtype=self._index.schema.fields[ROUTE_VECTOR_FIELD_NAME].attrs.datatype, # type: ignore[union-attr]
253258
)
254259

255260
aggregate_request = self._build_aggregate_request(
@@ -302,7 +307,6 @@ def _classify_multi_route(
302307
vector_field_name=ROUTE_VECTOR_FIELD_NAME,
303308
distance_threshold=distance_threshold,
304309
return_fields=["route_name"],
305-
dtype=self._index.schema.fields[ROUTE_VECTOR_FIELD_NAME].attrs.datatype, # type: ignore[union-attr]
306310
)
307311
aggregate_request = self._build_aggregate_request(
308312
vector_range_query, aggregation_method, max_k

redisvl/utils/utils.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
from functools import wraps
12
import json
23
from enum import Enum
34
from time import time
4-
from typing import Any, Dict
5+
from typing import Any, Callable, Dict, Optional
56
from uuid import uuid4
7+
from warnings import warn
68

79
from pydantic.v1 import BaseModel
810

@@ -57,3 +59,32 @@ def serialize(data: Dict[str, Any]) -> str:
5759
def deserialize(data: str) -> Dict[str, Any]:
5860
"""Deserialize the input from a string."""
5961
return json.loads(data)
62+
63+
64+
def deprecated_argument(argument: str, replacement: Optional[str] = None) -> Callable:
65+
"""
66+
Decorator to warn if a deprecated argument is passed.
67+
68+
When the wrapped function is called, the decorator will warn if the
69+
deprecated argument is passed as an argument or keyword argument.
70+
"""
71+
72+
message = f"Argument {argument} is deprecated and will be removed in the next major release."
73+
if replacement:
74+
message += f" Use {replacement} instead."
75+
76+
def wrapper(func):
77+
@wraps(func)
78+
def inner(*args, **kwargs):
79+
argument_names = func.__code__.co_varnames
80+
81+
if argument in argument_names:
82+
warn(message, DeprecationWarning, stacklevel=2)
83+
elif argument in kwargs:
84+
warn(message, DeprecationWarning, stacklevel=2)
85+
86+
return func(*args, **kwargs)
87+
88+
return inner
89+
90+
return wrapper

tests/integration/test_llmcache.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -885,11 +885,10 @@ def test_bad_dtype_connecting_to_existing_cache(redis_url):
885885
name="float64_cache", dtype="float16", redis_url=redis_url
886886
)
887887

888-
889888
def test_vectorizer_dtype_mismatch():
890889
with pytest.raises(ValueError):
891890
SemanticCache(
892-
name="test_cache",
891+
name="test_dtype_mismatch",
893892
dtype="float32",
894893
vectorizer=HFTextVectorizer(dtype="float16"),
895894
)
@@ -898,12 +897,17 @@ def test_vectorizer_dtype_mismatch():
898897
def test_invalid_vectorizer():
899898
with pytest.raises(TypeError):
900899
SemanticCache(
901-
name="test_cache",
900+
name="test_invalid_vectorizer",
902901
vectorizer="invalid_vectorizer", # type: ignore
903902
)
904903

905904

906905
def test_passes_through_dtype_to_default_vectorizer():
907906
# The default is float32, so we should see float64 if we pass it in.
908-
cache = SemanticCache(name="test_cache", dtype="float64")
907+
cache = SemanticCache(name="test_pass_through_dtype)", dtype="float64")
909908
assert cache._vectorizer.dtype == "float64"
909+
910+
911+
def test_deprecated_dtype_argument():
912+
with pytest.warns(DeprecationWarning):
913+
SemanticCache(name="test_deprecated_dtype", dtype="float32")

tests/unit/test_utils.py

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
import numpy as np
22
import pytest
3-
from ml_dtypes import bfloat16
43

54
from redisvl.redis.utils import (
65
array_to_buffer,
76
buffer_to_array,
87
convert_bytes,
98
make_dict,
109
)
10+
from redisvl.utils.utils import deprecated_argument
1111

1212

1313
def test_even_number_of_elements():
@@ -146,3 +146,94 @@ def test_conversion_with_invalid_floats():
146146
array = [float("inf"), float("-inf"), float("nan")]
147147
result = array_to_buffer(array, "float16")
148148
assert len(result) > 0 # Simple check to ensure it returns anything
149+
150+
151+
class TestDeprecatedArgument:
152+
def test_deprecation_warning_text_with_replacement(self):
153+
@deprecated_argument("dtype", "vectorizer")
154+
def test_func(dtype=None, vectorizer=None):
155+
pass
156+
157+
with pytest.warns(DeprecationWarning) as record:
158+
test_func(dtype="float32")
159+
160+
assert len(record) == 1
161+
assert str(record[0].message) == (
162+
"Argument dtype is deprecated and will be removed"
163+
" in the next major release. Use vectorizer instead."
164+
)
165+
166+
def test_deprecation_warning_text_without_replacement(self):
167+
@deprecated_argument("dtype")
168+
def test_func(dtype=None):
169+
pass
170+
171+
with pytest.warns(DeprecationWarning) as record:
172+
test_func(dtype="float32")
173+
174+
assert len(record) == 1
175+
assert str(record[0].message) == (
176+
"Argument dtype is deprecated and will be removed"
177+
" in the next major release."
178+
)
179+
180+
def test_function_argument(self):
181+
@deprecated_argument("dtype", "vectorizer")
182+
def test_func(dtype=None, vectorizer=None):
183+
pass
184+
185+
with pytest.warns(DeprecationWarning):
186+
test_func(dtype="float32")
187+
188+
def test_function_keyword_argument(self):
189+
@deprecated_argument("dtype", "vectorizer")
190+
def test_func(dtype=None, vectorizer=None):
191+
pass
192+
193+
with pytest.warns(DeprecationWarning):
194+
test_func(vectorizer="float32")
195+
196+
def test_class_method_argument(self):
197+
class TestClass:
198+
@deprecated_argument("dtype", "vectorizer")
199+
def test_method(self, dtype=None, vectorizer=None):
200+
pass
201+
202+
with pytest.warns(DeprecationWarning):
203+
TestClass().test_method(dtype="float32")
204+
205+
def test_class_method_keyword_argument(self):
206+
class TestClass:
207+
@deprecated_argument("dtype", "vectorizer")
208+
def test_method(self, dtype=None, vectorizer=None):
209+
pass
210+
211+
with pytest.warns(DeprecationWarning):
212+
TestClass().test_method(vectorizer="float32")
213+
214+
def test_class_init_argument(self):
215+
class TestClass:
216+
@deprecated_argument("dtype", "vectorizer")
217+
def __init__(self, dtype=None, vectorizer=None):
218+
pass
219+
220+
with pytest.warns(DeprecationWarning):
221+
TestClass(dtype="float32")
222+
223+
def test_class_init_keyword_argument(self):
224+
class TestClass:
225+
@deprecated_argument("dtype", "vectorizer")
226+
def __init__(self, dtype=None, vectorizer=None):
227+
pass
228+
229+
with pytest.warns(DeprecationWarning):
230+
TestClass(dtype="float32")
231+
232+
async def test_async_function_argument(self):
233+
@deprecated_argument("dtype", "vectorizer")
234+
async def test_func(dtype=None, vectorizer=None):
235+
return 1
236+
237+
with pytest.warns(DeprecationWarning):
238+
result = await test_func(dtype="float32")
239+
assert result == 1

0 commit comments

Comments
 (0)