11from typing import Any , Dict , List , Optional , Set , Tuple , Union
22
3+ from pydantic import BaseModel , field_validator
34from redis .commands .search .aggregation import AggregateRequest , Desc
45
56from redisvl .query .filter import FilterExpression
67from redisvl .redis .utils import array_to_buffer
8+ from redisvl .schema .fields import VectorDataType
79from redisvl .utils .token_escaper import TokenEscaper
810from redisvl .utils .utils import lazy_import
911
1012nltk = lazy_import ("nltk" )
1113nltk_stopwords = lazy_import ("nltk.corpus.stopwords" )
1214
1315
16+ class Vector (BaseModel ):
17+ """
18+ Simple object containing the necessary arguments to perform a multi vector query.
19+ """
20+
21+ vector : Union [List [float ], bytes ]
22+ field_name : str
23+ dtype : str = "float32"
24+ weight : float = 1.0
25+
26+ @field_validator ("dtype" )
27+ @classmethod
28+ def validate_dtype (cls , dtype : str ) -> str :
29+ try :
30+ VectorDataType (dtype .upper ())
31+ except ValueError :
32+ raise ValueError (
33+ f"Invalid data type: { dtype } . Supported types are: { [t .lower () for t in VectorDataType ]} "
34+ )
35+
36+ return dtype
37+
38+
1439class AggregationQuery (AggregateRequest ):
1540 """
1641 Base class for aggregation queries used to create aggregation queries for Redis.
@@ -241,17 +266,33 @@ class MultiVectorQuery(AggregationQuery):
241266
242267 .. code-block:: python
243268
244- from redisvl.query import MultiVectorQuery
269+ from redisvl.query import MultiVectorQuery, Vector
245270 from redisvl.index import SearchIndex
246271
247272 index = SearchIndex.from_yaml("path/to/index.yaml")
248273
274+ vector_1 = Vector(
275+ vector=[0.1, 0.2, 0.3],
276+ field_name="text_vector",
277+ dtype="float32",
278+ weight=0.7,
279+ )
280+ vector_2 = Vector(
281+ vector=[0.5, 0.5],
282+ field_name="image_vector",
283+ dtype="bfloat16",
284+ weight=0.2,
285+ )
286+ vector_3 = Vector(
287+ vector=[0.1, 0.2, 0.3],
288+ field_name="text_vector",
289+ dtype="float64",
290+ weight=0.5,
291+ )
292+
249293 query = MultiVectorQuery(
250- vectors=[[0.1, 0.2, 0.3], [0.5, 0.5], [0.1, 0.1, 0.1, 0.1]],
251- vector_field_names=["text_vector", "image_vector", "feature_vector"]
294+ vectors=[vector_1, vector_2, vector_3],
252295 filter_expression=None,
253- weights=[0.7, 0.2, 0.5],
254- dtypes=["float32", "bfloat16", "float64"],
255296 num_results=10,
256297 return_fields=["field1", "field2"],
257298 dialect=2,
@@ -260,14 +301,13 @@ class MultiVectorQuery(AggregationQuery):
260301 results = index.query(query)
261302 """
262303
304+ _vectors : List [Vector ]
305+
263306 def __init__ (
264307 self ,
265- vectors : Union [bytes , List [bytes ], List [float ], List [List [float ]]],
266- vector_field_names : Union [str , List [str ]],
267- weights : List [float ] = [1.0 ],
308+ vectors : Union [Vector , List [Vector ]],
268309 return_fields : Optional [List [str ]] = None ,
269310 filter_expression : Optional [Union [str , FilterExpression ]] = None ,
270- dtypes : List [str ] = ["float32" ],
271311 num_results : int = 10 ,
272312 return_score : bool = False ,
273313 dialect : int = 2 ,
@@ -276,87 +316,39 @@ def __init__(
276316 Instantiates a MultiVectorQuery object.
277317
278318 Args:
279- vectors (Union[bytes, List[bytes], List[float], List[List[float]]): The vectors to perform vector similarity search.
280- vector_field_names (Union[str, List[str]]): The vector field names to search in.
281- weights (List[float]): The weights of the vector similarity.
282- Documents will be scored as:
283- score = (w1) * score1 + (w2) * score2 + (w3) * score3 + ...
284- Defaults to [1.0], which corresponds to equal weighting
319+ vectors (Union[Vector, List[Vector]]): The Vectors to perform vector similarity search.
285320 return_fields (Optional[List[str]], optional): The fields to return. Defaults to None.
286321 filter_expression (Optional[Union[str, FilterExpression]]): The filter expression to use.
287322 Defaults to None.
288- dtypes (List[str]): The data types of the vectors. Defaults to ["float32"] for all vectors.
289323 num_results (int, optional): The number of results to return. Defaults to 10.
290324 return_score (bool): Whether to return the combined vector similarity score.
291325 Defaults to False.
292326 dialect (int, optional): The Redis dialect version. Defaults to 2.
293-
294- Raises:
295- ValueError: The number of vectors, vector field names, and weights do not agree.
296327 """
297328
298329 self ._filter_expression = filter_expression
299- self ._dtypes = dtypes
300330 self ._num_results = num_results
301331
302- if any ([len (x ) == 0 for x in [vectors , vector_field_names , weights , dtypes ]]):
303- raise ValueError (
304- f"""The number of vectors and vector field names must be equal.
305- If weights or dtypes are specified their number must match the number of vectors and vector field names also.
306- Length of vectors list: { len (vectors ) = }
307- Length of vector_field_names list: { len (vector_field_names ) = }
308- Length of weights list: { len (weights ) = }
309- length of dtypes list: { len (dtypes ) = }
310- """
311- )
312-
313- if isinstance (vectors , bytes ) or isinstance (vectors [0 ], float ):
332+ if isinstance (vectors , Vector ):
314333 self ._vectors = [vectors ]
315334 else :
316335 self ._vectors = vectors # type: ignore
317336
318- if isinstance (vector_field_names , str ):
319- self ._vector_field_names = [vector_field_names ]
320- else :
321- self ._vector_field_names = vector_field_names
322-
323- if len (weights ) == 1 :
324- self ._weights = weights * len (vectors )
325- else :
326- self ._weights = weights
327-
328- if len (dtypes ) == 1 :
329- self ._dtypes = dtypes * len (vectors )
330- else :
331- self ._dtypes = dtypes
332-
333- num_vectors = len (self ._vectors )
334- if any (
335- [
336- len (x ) != num_vectors # type: ignore
337- for x in [self ._vector_field_names , self ._weights , self ._dtypes ]
338- ]
339- ):
340- raise ValueError (
341- f"""The number of vectors and vector field names must be equal.
342- If weights or dtypes are specified their number must match the number of vectors and vector field names also.
343- Length of vectors list: { len (self ._vectors ) = }
344- Length of vector_field_names list: { len (self ._vector_field_names ) = }
345- Length of weights list: { len (self ._weights ) = }
346- Length of dtypes list: { len (self ._dtypes ) = }
347- """
337+ if not all ([isinstance (v , Vector ) for v in self ._vectors ]):
338+ raise TypeError (
339+ "vector arugment must be a Vector object or list of Vector objects."
348340 )
349341
350342 query_string = self ._build_query_string ()
351343 super ().__init__ (query_string )
352344
353345 # calculate the respective vector similarities
354- for i in range (len (vectors )):
346+ for i in range (len (self . _vectors )):
355347 self .apply (** {f"score_{ i } " : f"(2 - @distance_{ i } )/2" })
356348
357349 # construct the scoring string based on the vector similarity scores and weights
358350 combined_scores = []
359- for i , w in enumerate (self ._weights ):
351+ for i , w in enumerate ([ v . weight for v in self ._vectors ] ):
360352 combined_scores .append (f"@score_{ i } * { w } " )
361353 combined_score_string = " + " .join (combined_scores )
362354
@@ -375,7 +367,9 @@ def params(self) -> Dict[str, Any]:
375367 Dict[str, Any]: The parameters for the aggregation.
376368 """
377369 params = {}
378- for i , (vector , dtype ) in enumerate (zip (self ._vectors , self ._dtypes )):
370+ for i , (vector , dtype ) in enumerate (
371+ [(v .vector , v .dtype ) for v in self ._vectors ]
372+ ):
379373 if isinstance (vector , list ):
380374 vector = array_to_buffer (vector , dtype = dtype ) # type: ignore
381375 params [f"vector_{ i } " ] = vector
@@ -387,7 +381,7 @@ def _build_query_string(self) -> str:
387381 # base KNN query
388382 range_queries = []
389383 for i , (vector , field ) in enumerate (
390- zip ( self . _vectors , self . _vector_field_names )
384+ [( v . vector , v . field_name ) for v in self . _vectors ]
391385 ):
392386 range_queries .append (
393387 f"@{ field } :[VECTOR_RANGE 2.0 $vector_{ i } ]=>{{$YIELD_DISTANCE_AS: distance_{ i } }}"
0 commit comments