Skip to content

Commit 9865faa

Browse files
committed
Require dtype argument and vectorizer dtype to match
1 parent 123be06 commit 9865faa

File tree

2 files changed

+29
-1
lines changed

2 files changed

+29
-1
lines changed

redisvl/extensions/llmcache/semantic.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,12 +86,17 @@ def __init__(
8686
else:
8787
prefix = name
8888

89+
dtype = kwargs.get("dtype")
90+
8991
# Validate a provided vectorizer or set the default
9092
if vectorizer:
9193
if not isinstance(vectorizer, BaseVectorizer):
9294
raise TypeError("Must provide a valid redisvl.vectorizer class.")
95+
if dtype and vectorizer.dtype != dtype:
96+
raise ValueError(
97+
f"Provided dtype {dtype} does not match vectorizer dtype {vectorizer.dtype}"
98+
)
9399
else:
94-
dtype = kwargs.get("dtype")
95100
vectorizer_kwargs = {"dtype": dtype} if dtype else {}
96101

97102
vectorizer = HFTextVectorizer(

tests/integration/test_llmcache.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -884,3 +884,26 @@ def test_bad_dtype_connecting_to_existing_cache(redis_url):
884884
bad_type = SemanticCache(
885885
name="float64_cache", dtype="float16", redis_url=redis_url
886886
)
887+
888+
889+
def test_vectorizer_dtype_mismatch():
890+
with pytest.raises(ValueError):
891+
SemanticCache(
892+
name="test_cache",
893+
dtype="float32",
894+
vectorizer=HFTextVectorizer(dtype="float16"),
895+
)
896+
897+
898+
def test_invalid_vectorizer():
899+
with pytest.raises(TypeError):
900+
SemanticCache(
901+
name="test_cache",
902+
vectorizer="invalid_vectorizer", # type: ignore
903+
)
904+
905+
906+
def test_passes_through_dtype_to_default_vectorizer():
907+
# The default is float32, so we should see float64 if we pass it in.
908+
cache = SemanticCache(name="test_cache", dtype="float64")
909+
assert cache._vectorizer.dtype == "float64"

0 commit comments

Comments
 (0)