@@ -86,12 +86,28 @@ def __init__(
8686 else :
8787 prefix = name
8888
89- # Set vectorizer default
90- if vectorizer is None :
89+ # Validate a provided vectorizer or set the default
90+ if vectorizer :
91+ if not isinstance (vectorizer , BaseVectorizer ):
92+ raise TypeError ("Must provide a valid redisvl.vectorizer class." )
93+ else :
94+ dtype = kwargs .get ("dtype" )
95+ vectorizer_kwargs = {"dtype" : dtype } if dtype else {}
96+
9197 vectorizer = HFTextVectorizer (
92- model = "sentence-transformers/all-mpnet-base-v2"
98+ model = "sentence-transformers/all-mpnet-base-v2" ,
99+ ** vectorizer_kwargs ,
93100 )
94101
102+ self ._vectorizer = vectorizer
103+
104+ # Create semantic cache schema and index
105+ schema = SemanticCacheIndexSchema .from_params (
106+ name , prefix , vectorizer .dims , vectorizer .dtype
107+ )
108+ schema = self ._modify_schema (schema , filterable_fields )
109+ self ._index = SearchIndex (schema = schema )
110+
95111 # Process fields and other settings
96112 self .set_threshold (distance_threshold )
97113 self .return_fields = [
@@ -103,14 +119,6 @@ def __init__(
103119 METADATA_FIELD_NAME ,
104120 ]
105121
106- # Create semantic cache schema and index
107- dtype = kwargs .get ("dtype" , "float32" )
108- schema = SemanticCacheIndexSchema .from_params (
109- name , prefix , vectorizer .dims , dtype
110- )
111- schema = self ._modify_schema (schema , filterable_fields )
112- self ._index = SearchIndex (schema = schema )
113-
114122 # Handle redis connection
115123 if redis_client :
116124 self ._index .set_client (redis_client )
@@ -128,20 +136,9 @@ def __init__(
128136 "If you wish to overwrite the index schema, set overwrite=True during initialization."
129137 )
130138
131- # Create the search index
139+ # Create the search index in Redis
132140 self ._index .create (overwrite = overwrite , drop = False )
133141
134- # Initialize and validate vectorizer
135- if not isinstance (vectorizer , BaseVectorizer ):
136- raise TypeError ("Must provide a valid redisvl.vectorizer class." )
137-
138- validate_vector_dims (
139- vectorizer .dims ,
140- self ._index .schema .fields [CACHE_VECTOR_FIELD_NAME ].attrs .dims , # type: ignore
141- )
142- self ._vectorizer = vectorizer
143- self ._dtype = self .index .schema .fields [CACHE_VECTOR_FIELD_NAME ].attrs .datatype # type: ignore[union-attr]
144-
145142 def _modify_schema (
146143 self ,
147144 schema : SemanticCacheIndexSchema ,
@@ -290,7 +287,7 @@ def _vectorize_prompt(self, prompt: Optional[str]) -> List[float]:
290287 if not isinstance (prompt , str ):
291288 raise TypeError ("Prompt must be a string." )
292289
293- return self ._vectorizer .embed (prompt , dtype = self . _dtype )
290+ return self ._vectorizer .embed (prompt )
294291
295292 async def _avectorize_prompt (self , prompt : Optional [str ]) -> List [float ]:
296293 """Converts a text prompt to its vector representation using the
@@ -372,7 +369,7 @@ def check(
372369 num_results = num_results ,
373370 return_score = True ,
374371 filter_expression = filter_expression ,
375- dtype = self ._dtype ,
372+ dtype = self ._vectorizer . dtype ,
376373 )
377374
378375 # Search the cache!
@@ -543,7 +540,7 @@ def store(
543540 # Load cache entry with TTL
544541 ttl = ttl or self ._ttl
545542 keys = self ._index .load (
546- data = [cache_entry .to_dict (self ._dtype )],
543+ data = [cache_entry .to_dict (self ._vectorizer . dtype )],
547544 ttl = ttl ,
548545 id_field = ENTRY_ID_FIELD_NAME ,
549546 )
@@ -607,7 +604,7 @@ async def astore(
607604 # Load cache entry with TTL
608605 ttl = ttl or self ._ttl
609606 keys = await aindex .load (
610- data = [cache_entry .to_dict (self ._dtype )],
607+ data = [cache_entry .to_dict (self ._vectorizer . dtype )],
611608 ttl = ttl ,
612609 id_field = ENTRY_ID_FIELD_NAME ,
613610 )
0 commit comments