File tree Expand file tree Collapse file tree 2 files changed +29
-1
lines changed
redisvl/extensions/llmcache Expand file tree Collapse file tree 2 files changed +29
-1
lines changed Original file line number Diff line number Diff line change @@ -86,12 +86,17 @@ def __init__(
8686 else :
8787 prefix = name
8888
89+ dtype = kwargs .get ("dtype" )
90+
8991 # Validate a provided vectorizer or set the default
9092 if vectorizer :
9193 if not isinstance (vectorizer , BaseVectorizer ):
9294 raise TypeError ("Must provide a valid redisvl.vectorizer class." )
95+ if dtype and vectorizer .dtype != dtype :
96+ raise ValueError (
97+ f"Provided dtype { dtype } does not match vectorizer dtype { vectorizer .dtype } "
98+ )
9399 else :
94- dtype = kwargs .get ("dtype" )
95100 vectorizer_kwargs = {"dtype" : dtype } if dtype else {}
96101
97102 vectorizer = HFTextVectorizer (
Original file line number Diff line number Diff line change @@ -884,3 +884,26 @@ def test_bad_dtype_connecting_to_existing_cache(redis_url):
884884 bad_type = SemanticCache (
885885 name = "float64_cache" , dtype = "float16" , redis_url = redis_url
886886 )
887+
888+
889+ def test_vectorizer_dtype_mismatch ():
890+ with pytest .raises (ValueError ):
891+ SemanticCache (
892+ name = "test_cache" ,
893+ dtype = "float32" ,
894+ vectorizer = HFTextVectorizer (dtype = "float16" ),
895+ )
896+
897+
898+ def test_invalid_vectorizer ():
899+ with pytest .raises (TypeError ):
900+ SemanticCache (
901+ name = "test_cache" ,
902+ vectorizer = "invalid_vectorizer" , # type: ignore
903+ )
904+
905+
906+ def test_passes_through_dtype_to_default_vectorizer ():
907+ # The default is float32, so we should see float64 if we pass it in.
908+ cache = SemanticCache (name = "test_cache" , dtype = "float64" )
909+ assert cache ._vectorizer .dtype == "float64"
You can’t perform that action at this time.
0 commit comments