Skip to content

Commit e23b652

Browse files
refactors MultiVectorQuery to accept Vector objects
1 parent 406f420 commit e23b652

File tree

4 files changed

+193
-203
lines changed

4 files changed

+193
-203
lines changed

redisvl/query/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
1-
from redisvl.query.aggregate import AggregationQuery, HybridQuery, MultiVectorQuery
1+
from redisvl.query.aggregate import (
2+
AggregationQuery,
3+
HybridQuery,
4+
MultiVectorQuery,
5+
Vector,
6+
)
27
from redisvl.query.query import (
38
BaseQuery,
49
BaseVectorQuery,
@@ -22,4 +27,5 @@
2227
"AggregationQuery",
2328
"HybridQuery",
2429
"MultiVectorQuery",
30+
"Vector",
2531
]

redisvl/query/aggregate.py

Lines changed: 60 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,41 @@
11
from typing import Any, Dict, List, Optional, Set, Tuple, Union
22

3+
from pydantic import BaseModel, field_validator
34
from redis.commands.search.aggregation import AggregateRequest, Desc
45

56
from redisvl.query.filter import FilterExpression
67
from redisvl.redis.utils import array_to_buffer
8+
from redisvl.schema.fields import VectorDataType
79
from redisvl.utils.token_escaper import TokenEscaper
810
from redisvl.utils.utils import lazy_import
911

1012
nltk = lazy_import("nltk")
1113
nltk_stopwords = lazy_import("nltk.corpus.stopwords")
1214

1315

16+
class Vector(BaseModel):
17+
"""
18+
Simple object containing the necessary arguments to perform a multi vector query.
19+
"""
20+
21+
vector: Union[List[float], bytes]
22+
field_name: str
23+
dtype: str = "float32"
24+
weight: float = 1.0
25+
26+
@field_validator("dtype")
27+
@classmethod
28+
def validate_dtype(cls, dtype: str) -> str:
29+
try:
30+
VectorDataType(dtype.upper())
31+
except ValueError:
32+
raise ValueError(
33+
f"Invalid data type: {dtype}. Supported types are: {[t.lower() for t in VectorDataType]}"
34+
)
35+
36+
return dtype
37+
38+
1439
class AggregationQuery(AggregateRequest):
1540
"""
1641
Base class for aggregation queries used to create aggregation queries for Redis.
@@ -241,17 +266,33 @@ class MultiVectorQuery(AggregationQuery):
241266
242267
.. code-block:: python
243268
244-
from redisvl.query import MultiVectorQuery
269+
from redisvl.query import MultiVectorQuery, Vector
245270
from redisvl.index import SearchIndex
246271
247272
index = SearchIndex.from_yaml("path/to/index.yaml")
248273
274+
vector_1 = Vector(
275+
vector=[0.1, 0.2, 0.3],
276+
field_name="text_vector",
277+
dtype="float32",
278+
weight=0.7,
279+
)
280+
vector_2 = Vector(
281+
vector=[0.5, 0.5],
282+
field_name="image_vector",
283+
dtype="bfloat16",
284+
weight=0.2,
285+
)
286+
vector_3 = Vector(
287+
vector=[0.1, 0.2, 0.3],
288+
field_name="text_vector",
289+
dtype="float64",
290+
weight=0.5,
291+
)
292+
249293
query = MultiVectorQuery(
250-
vectors=[[0.1, 0.2, 0.3], [0.5, 0.5], [0.1, 0.1, 0.1, 0.1]],
251-
vector_field_names=["text_vector", "image_vector", "feature_vector"]
294+
vectors=[vector_1, vector_2, vector_3],
252295
filter_expression=None,
253-
weights=[0.7, 0.2, 0.5],
254-
dtypes=["float32", "bfloat16", "float64"],
255296
num_results=10,
256297
return_fields=["field1", "field2"],
257298
dialect=2,
@@ -260,14 +301,13 @@ class MultiVectorQuery(AggregationQuery):
260301
results = index.query(query)
261302
"""
262303

304+
_vectors: List[Vector]
305+
263306
def __init__(
264307
self,
265-
vectors: Union[bytes, List[bytes], List[float], List[List[float]]],
266-
vector_field_names: Union[str, List[str]],
267-
weights: List[float] = [1.0],
308+
vectors: Union[Vector, List[Vector]],
268309
return_fields: Optional[List[str]] = None,
269310
filter_expression: Optional[Union[str, FilterExpression]] = None,
270-
dtypes: List[str] = ["float32"],
271311
num_results: int = 10,
272312
return_score: bool = False,
273313
dialect: int = 2,
@@ -276,87 +316,39 @@ def __init__(
276316
Instantiates a MultiVectorQuery object.
277317
278318
Args:
279-
vectors (Union[bytes, List[bytes], List[float], List[List[float]]): The vectors to perform vector similarity search.
280-
vector_field_names (Union[str, List[str]]): The vector field names to search in.
281-
weights (List[float]): The weights of the vector similarity.
282-
Documents will be scored as:
283-
score = (w1) * score1 + (w2) * score2 + (w3) * score3 + ...
284-
Defaults to [1.0], which corresponds to equal weighting
319+
vectors (Union[Vector, List[Vector]]): The Vectors to perform vector similarity search.
285320
return_fields (Optional[List[str]], optional): The fields to return. Defaults to None.
286321
filter_expression (Optional[Union[str, FilterExpression]]): The filter expression to use.
287322
Defaults to None.
288-
dtypes (List[str]): The data types of the vectors. Defaults to ["float32"] for all vectors.
289323
num_results (int, optional): The number of results to return. Defaults to 10.
290324
return_score (bool): Whether to return the combined vector similarity score.
291325
Defaults to False.
292326
dialect (int, optional): The Redis dialect version. Defaults to 2.
293-
294-
Raises:
295-
ValueError: The number of vectors, vector field names, and weights do not agree.
296327
"""
297328

298329
self._filter_expression = filter_expression
299-
self._dtypes = dtypes
300330
self._num_results = num_results
301331

302-
if any([len(x) == 0 for x in [vectors, vector_field_names, weights, dtypes]]):
303-
raise ValueError(
304-
f"""The number of vectors and vector field names must be equal.
305-
If weights or dtypes are specified their number must match the number of vectors and vector field names also.
306-
Length of vectors list: {len(vectors) = }
307-
Length of vector_field_names list: {len(vector_field_names) = }
308-
Length of weights list: {len(weights) = }
309-
length of dtypes list: {len(dtypes) = }
310-
"""
311-
)
312-
313-
if isinstance(vectors, bytes) or isinstance(vectors[0], float):
332+
if isinstance(vectors, Vector):
314333
self._vectors = [vectors]
315334
else:
316335
self._vectors = vectors # type: ignore
317336

318-
if isinstance(vector_field_names, str):
319-
self._vector_field_names = [vector_field_names]
320-
else:
321-
self._vector_field_names = vector_field_names
322-
323-
if len(weights) == 1:
324-
self._weights = weights * len(vectors)
325-
else:
326-
self._weights = weights
327-
328-
if len(dtypes) == 1:
329-
self._dtypes = dtypes * len(vectors)
330-
else:
331-
self._dtypes = dtypes
332-
333-
num_vectors = len(self._vectors)
334-
if any(
335-
[
336-
len(x) != num_vectors # type: ignore
337-
for x in [self._vector_field_names, self._weights, self._dtypes]
338-
]
339-
):
340-
raise ValueError(
341-
f"""The number of vectors and vector field names must be equal.
342-
If weights or dtypes are specified their number must match the number of vectors and vector field names also.
343-
Length of vectors list: {len(self._vectors) = }
344-
Length of vector_field_names list: {len(self._vector_field_names) = }
345-
Length of weights list: {len(self._weights) = }
346-
Length of dtypes list: {len(self._dtypes) = }
347-
"""
337+
if not all([isinstance(v, Vector) for v in self._vectors]):
338+
raise TypeError(
339+
"vector arugment must be a Vector object or list of Vector objects."
348340
)
349341

350342
query_string = self._build_query_string()
351343
super().__init__(query_string)
352344

353345
# calculate the respective vector similarities
354-
for i in range(len(vectors)):
346+
for i in range(len(self._vectors)):
355347
self.apply(**{f"score_{i}": f"(2 - @distance_{i})/2"})
356348

357349
# construct the scoring string based on the vector similarity scores and weights
358350
combined_scores = []
359-
for i, w in enumerate(self._weights):
351+
for i, w in enumerate([v.weight for v in self._vectors]):
360352
combined_scores.append(f"@score_{i} * {w}")
361353
combined_score_string = " + ".join(combined_scores)
362354

@@ -375,7 +367,9 @@ def params(self) -> Dict[str, Any]:
375367
Dict[str, Any]: The parameters for the aggregation.
376368
"""
377369
params = {}
378-
for i, (vector, dtype) in enumerate(zip(self._vectors, self._dtypes)):
370+
for i, (vector, dtype) in enumerate(
371+
[(v.vector, v.dtype) for v in self._vectors]
372+
):
379373
if isinstance(vector, list):
380374
vector = array_to_buffer(vector, dtype=dtype) # type: ignore
381375
params[f"vector_{i}"] = vector
@@ -387,7 +381,7 @@ def _build_query_string(self) -> str:
387381
# base KNN query
388382
range_queries = []
389383
for i, (vector, field) in enumerate(
390-
zip(self._vectors, self._vector_field_names)
384+
[(v.vector, v.field_name) for v in self._vectors]
391385
):
392386
range_queries.append(
393387
f"@{field}:[VECTOR_RANGE 2.0 $vector_{i}]=>{{$YIELD_DISTANCE_AS: distance_{i}}}"

0 commit comments

Comments
 (0)