2020from redisvl .query import RangeQuery
2121from redisvl .redis .utils import convert_bytes , hashify , make_dict
2222from redisvl .utils .log import get_logger
23- from redisvl .utils .utils import model_to_dict
23+ from redisvl .utils .utils import deprecated_argument , model_to_dict
2424from redisvl .utils .vectorize import (
2525 BaseVectorizer ,
2626 HFTextVectorizer ,
@@ -47,6 +47,7 @@ class SemanticRouter(BaseModel):
4747 class Config :
4848 arbitrary_types_allowed = True
4949
50+ @deprecated_argument ("dtype" , "vectorizer" )
5051 def __init__ (
5152 self ,
5253 name : str ,
@@ -72,9 +73,17 @@ def __init__(
7273 connection_kwargs (Dict[str, Any]): The connection arguments
7374 for the redis client. Defaults to empty {}.
7475 """
75- # Set vectorizer default
76- if vectorizer is None :
77- dtype = kwargs .get ("dtype" )
76+ dtype = kwargs .get ("dtype" )
77+
78+ # Validate a provided vectorizer or set the default
79+ if vectorizer :
80+ if not isinstance (vectorizer , BaseVectorizer ):
81+ raise TypeError ("Must provide a valid redisvl.vectorizer class." )
82+ if dtype and vectorizer .dtype != dtype :
83+ raise ValueError (
84+ f"Provided dtype { dtype } does not match vectorizer dtype { vectorizer .dtype } "
85+ )
86+ else :
7887 vectorizer_kwargs = {"dtype" : dtype } if dtype else {}
7988 vectorizer = HFTextVectorizer (** vectorizer_kwargs )
8089
@@ -87,10 +96,9 @@ def __init__(
8796 vectorizer = vectorizer ,
8897 routing_config = routing_config ,
8998 )
90- self ._initialize_index (
91- redis_client , redis_url , overwrite , vectorizer .dtype , ** connection_kwargs
92- )
99+ self ._initialize_index (redis_client , redis_url , overwrite , ** connection_kwargs )
93100
101+ @deprecated_argument ("dtype" )
94102 def _initialize_index (
95103 self ,
96104 redis_client : Optional [Redis ] = None ,
@@ -101,7 +109,7 @@ def _initialize_index(
101109 ):
102110 """Initialize the search index and handle Redis connection."""
103111 schema = SemanticRouterIndexSchema .from_params (
104- self .name , self .vectorizer .dims , dtype
112+ self .name , self .vectorizer .dims , self . vectorizer . dtype
105113 )
106114 self ._index = SearchIndex (schema = schema )
107115
@@ -170,9 +178,7 @@ def _add_routes(self, routes: List[Route]):
170178 for route in routes :
171179 # embed route references as a single batch
172180 reference_vectors = self .vectorizer .embed_many (
173- [reference for reference in route .references ],
174- as_buffer = True ,
175- dtype = self ._index .schema .fields [ROUTE_VECTOR_FIELD_NAME ].attrs .datatype , # type: ignore[union-attr]
181+ [reference for reference in route .references ], as_buffer = True
176182 )
177183 # set route references
178184 for i , reference in enumerate (route .references ):
@@ -249,7 +255,6 @@ def _classify_route(
249255 vector_field_name = ROUTE_VECTOR_FIELD_NAME ,
250256 distance_threshold = distance_threshold ,
251257 return_fields = ["route_name" ],
252- dtype = self ._index .schema .fields [ROUTE_VECTOR_FIELD_NAME ].attrs .datatype , # type: ignore[union-attr]
253258 )
254259
255260 aggregate_request = self ._build_aggregate_request (
@@ -302,7 +307,6 @@ def _classify_multi_route(
302307 vector_field_name = ROUTE_VECTOR_FIELD_NAME ,
303308 distance_threshold = distance_threshold ,
304309 return_fields = ["route_name" ],
305- dtype = self ._index .schema .fields [ROUTE_VECTOR_FIELD_NAME ].attrs .datatype , # type: ignore[union-attr]
306310 )
307311 aggregate_request = self ._build_aggregate_request (
308312 vector_range_query , aggregation_method , max_k
0 commit comments