Skip to content

Commit 7275f70

Browse files
working multivector query class and tests
1 parent d0dff0b commit 7275f70

File tree

5 files changed

+559
-133
lines changed

5 files changed

+559
-133
lines changed

redisvl/query/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from redisvl.query.aggregate import AggregationQuery, HybridQuery
1+
from redisvl.query.aggregate import AggregationQuery, HybridQuery, MultiVectorQuery
22
from redisvl.query.query import (
33
BaseQuery,
44
BaseVectorQuery,
@@ -21,4 +21,5 @@
2121
"TextQuery",
2222
"AggregationQuery",
2323
"HybridQuery",
24+
"MultiVectorQuery",
2425
]

redisvl/query/aggregate.py

Lines changed: 63 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -231,76 +231,35 @@ def __str__(self) -> str:
231231

232232
class MultiVectorQuery(AggregationQuery):
233233
"""
234-
MultiVectorQuery allows for search over multiple vector fields in a document simulateously.
235-
The final score will be a weighted combination of the individual vector similarity scores
236-
following the formula:
234+
MultiVectorQuery allows for search over multiple vector fields in a document simulateously.
235+
The final score will be a weighted combination of the individual vector similarity scores
236+
following the formula:
237237
238-
score = (w_1 * score_1 + w_2 * score_2 + w_3 * score_3 + ... ) / (w_1 + w_2 + w_3 + ...)
239-
240-
Vectors may be of different size and datatype.
241-
242-
.. code-block:: python
243-
244-
from redisvl.query import MultiVectorQuery
245-
from redisvl.index import SearchIndex
246-
247-
index = SearchIndex.from_yaml("path/to/index.yaml")
248-
249-
query = MultiVectorQuery(
250-
vectors=[[0.1, 0.2, 0.3], [0.5, 0.5], [0.1, 0.1, 0.1, 0.1]],
251-
vector_field_names=["text_vector", "image_vector", "feature_vector"]
252-
filter_expression=None,
253-
weights=[0.7],
254-
dtypes=["float32", "float32", "float32"],
255-
num_results=10,
256-
return_fields=["field1", "field2"],
257-
dialect=2,
258-
)
259-
260-
results = index.query(query)
261-
262-
263-
264-
FT.AGGREGATE multi_vector_test
265-
"@user_embedding:[VECTOR_RANGE 2.0 $vector_0]=>{$YIELD_DISTANCE_AS: distance_0}
266-
| @image_embedding:[VECTOR_RANGE 2.0 $vector_1]=>{$YIELD_DISTANCE_AS: distance_1}"
267-
PARAMS 4
268-
vector_0 "\xcd\xcc\xcc=\xcd\xcc\xcc=\x00\x00\x00?"
269-
vector_1 "\x9a\x99\x99\x99\x99\x99\xb9?\x9a\x99\x99\x99\x99\x99\xb9?\x9a\x99\x99\x99\x99\x99\xb9?\x9a\x99\x99\x99\x99\x99\xb9?\x9a\x99\x99\x99\x99\x99\xb9?"
270-
APPLY "(2 - @distance_0)/2" AS score_0
271-
APPLY "(2 - @distance_1)/2" AS score_1
272-
DIALECT 2
273-
APPLY "(@score_0 + @score_1)" AS combined_score
274-
SORTBY 2 @combined_score
275-
ASC
276-
MAX 10
277-
LOAD 2 score_0 score_1
238+
score = (w_1 * score_1 + w_2 * score_2 + w_3 * score_3 + ... )
278239
240+
Vectors may be of different size and datatype, but must be indexed using the 'cosine' distance_metric.
279241
242+
.. code-block:: python
280243
244+
from redisvl.query import MultiVectorQuery
245+
from redisvl.index import SearchIndex
281246
247+
index = SearchIndex.from_yaml("path/to/index.yaml")
282248
283-
FT.AGGREGATE 'idx:characters'
284-
"@embedding1:[VECTOR_RANGE .7 $vector1]=>{$YIELD_DISTANCE_AS: vector_distance1}
285-
| @embedding2:[VECTOR_RANGE 1.0 $vector2]=>{$YIELD_DISTANCE_AS: vector_distance2}
286-
| @embedding3:[VECTOR_RANGE 1.7 $vector3]=>{$YIELD_DISTANCE_AS: vector_distance3}
287-
| @name:(James)"
288-
### ADDSCORES
289-
### SCORER BM25STD.NORM
290-
### LOAD 2 created_at @embedding
291-
APPLY '(2 - @vector_distance1)/2' as v1
292-
APPLY '(2 - @vector_distance2)/2' as v2
293-
APPLY '(2 - @vector_distance3)/2' as v3
294-
APPLY '(@__score * 0.3 + (@v1 * 0.3) + (@v2 * 1.2) + (@v3 * 0.1))' AS final_score
295-
PARAMS 6 vector1 "\xe4\xd6..." vector2 "\x89\xa0..." vector3 "\x3c\x19..."
296-
SORTBY 2 @final_score DESC
297-
DIALECT 2
298-
LIMIT 0 100
249+
query = MultiVectorQuery(
250+
vectors=[[0.1, 0.2, 0.3], [0.5, 0.5], [0.1, 0.1, 0.1, 0.1]],
251+
vector_field_names=["text_vector", "image_vector", "feature_vector"]
252+
filter_expression=None,
253+
weights=[0.7, 0.2, 0.5],
254+
dtypes=["float32", "bfloat16", "float64"],
255+
num_results=10,
256+
return_fields=["field1", "field2"],
257+
dialect=2,
258+
)
299259
260+
results = index.query(query)
300261
"""
301262

302-
DISTANCE_ID: str = "vector_distance"
303-
304263
def __init__(
305264
self,
306265
vectors: Union[bytes, List[bytes], List[float], List[List[float]]],
@@ -340,58 +299,69 @@ def __init__(
340299
self._dtypes = dtypes
341300
self._num_results = num_results
342301

343-
if len(vectors) == 0 or len(vector_field_names) == 0 or len(weights) == 0:
302+
if any([len(x) == 0 for x in [vectors, vector_field_names, weights, dtypes]]):
344303
raise ValueError(
345304
f"""The number of vectors and vector field names must be equal.
346-
If weights are specified their number must match the number of vectors and vector field names also.
347-
Length of vectors list: {len(vectors) = }
348-
Length of vector_field_names list: {len(vector_field_names) = }
349-
Length of weights list: {len(weights) = }
350-
"""
305+
If weights or dtypes are specified their number must match the number of vectors and vector field names also.
306+
Length of vectors list: {len(vectors) = }
307+
Length of vector_field_names list: {len(vector_field_names) = }
308+
Length of weights list: {len(weights) = }
309+
length of dtypes list: {len(dtypes) = }
310+
"""
351311
)
352312

353313
if isinstance(vectors, bytes) or isinstance(vectors[0], float):
354314
self._vectors = [vectors]
355315
else:
356-
self._vectors = vectors
316+
self._vectors = vectors # type: ignore
317+
357318
if isinstance(vector_field_names, str):
358319
self._vector_field_names = [vector_field_names]
359320
else:
360321
self._vector_field_names = vector_field_names
322+
361323
if len(weights) == 1:
362324
self._weights = weights * len(vectors)
363325
else:
364326
self._weights = weights
327+
365328
if len(dtypes) == 1:
366329
self._dtypes = dtypes * len(vectors)
367330
else:
368331
self._dtypes = dtypes
369332

370-
if (len(self._vectors) != len(self._vector_field_names)) or (
371-
len(self._vectors) != len(self._weights)
333+
num_vectors = len(self._vectors)
334+
if any(
335+
[
336+
len(x) != num_vectors # type: ignore
337+
for x in [self._vector_field_names, self._weights, self._dtypes]
338+
]
372339
):
373340
raise ValueError(
374341
f"""The number of vectors and vector field names must be equal.
375-
If weights are specified their number must match the number of vectors and vector field names also.
376-
Length of vectors list: {len(self._vectors) = }
377-
Length of vector_field_names list: {len(self._vector_field_names) = }
378-
Length of weights list: {len(self._weights) = }
379-
"""
342+
If weights or dtypes are specified their number must match the number of vectors and vector field names also.
343+
Length of vectors list: {len(self._vectors) = }
344+
Length of vector_field_names list: {len(self._vector_field_names) = }
345+
Length of weights list: {len(self._weights) = }
346+
Length of dtypes list: {len(self._dtypes) = }
347+
"""
380348
)
381349

382350
query_string = self._build_query_string()
383351
super().__init__(query_string)
384352

353+
# calculate the respective vector similarities
354+
for i in range(len(vectors)):
355+
self.apply(**{f"score_{i}": f"(2 - @distance_{i})/2"})
356+
385357
# construct the scoring string based on the vector similarity scores and weights
386358
combined_scores = []
387359
for i, w in enumerate(self._weights):
388360
combined_scores.append(f"@score_{i} * {w}")
389361
combined_score_string = " + ".join(combined_scores)
390-
combined_score_string = f"'({combined_score_string})'"
391362

392363
self.apply(combined_score=combined_score_string)
393364

394-
# self.add_scores()
395365
self.sort_by(Desc("@combined_score"), max=num_results) # type: ignore
396366
self.dialect(dialect)
397367
if return_fields:
@@ -405,43 +375,34 @@ def params(self) -> Dict[str, Any]:
405375
Dict[str, Any]: The parameters for the aggregation.
406376
"""
407377
params = {}
408-
for i, (vector, vector_field, dtype) in enumerate(zip(
409-
self._vectors, self._vector_field_names, self._dtypes
410-
)):
378+
for i, (vector, dtype) in enumerate(zip(self._vectors, self._dtypes)):
411379
if isinstance(vector, list):
412-
vector = array_to_buffer(vector, dtype=dtype)
380+
vector = array_to_buffer(vector, dtype=dtype) # type: ignore
413381
params[f"vector_{i}"] = vector
414382
return params
415383

416384
def _build_query_string(self) -> str:
417385
"""Build the full query string for text search with optional filtering."""
418386

419-
filter_expression = self._filter_expression
420-
if isinstance(self._filter_expression, FilterExpression):
421-
filter_expression = str(self._filter_expression)
422-
423387
# base KNN query
424-
knn_queries = []
425388
range_queries = []
426-
for i, (vector, field) in enumerate(zip(self._vectors, self._vector_field_names)):
427-
knn_queries.append(f"[KNN {self._num_results} @{field} $vector_{i} AS distance_{i}]")
428-
range_queries.append(f"@{field}:[VECTOR_RANGE 2.0 $vector_{i}]=>{{$YIELD_DISTANCE_AS: distance_{i}}}")
389+
for i, (vector, field) in enumerate(
390+
zip(self._vectors, self._vector_field_names)
391+
):
392+
range_queries.append(
393+
f"@{field}:[VECTOR_RANGE 2.0 $vector_{i}]=>{{$YIELD_DISTANCE_AS: distance_{i}}}"
394+
)
429395

430-
knn_query = " | ".join(knn_queries) ## knn_queries format doesn't work
431-
knn_query = " | ".join(range_queries)
396+
range_query = " | ".join(range_queries)
432397

433-
# calculate the respective vector similarities
434-
apply_string = ""
435-
for i, (vector, field_name, weight) in enumerate(
436-
zip(self._vectors, self._vector_field_names, self._weights)
437-
):
438-
apply_string += f'APPLY "(2 - @distance_{i})/2" AS score_{i} '
398+
filter_expression = self._filter_expression
399+
if isinstance(self._filter_expression, FilterExpression):
400+
filter_expression = str(self._filter_expression)
439401

440-
return (
441-
f"{knn_query} {filter_expression} {apply_string}"
442-
if filter_expression
443-
else f"{knn_query} {apply_string}"
444-
)
402+
if filter_expression:
403+
return f"({range_query}) AND ({filter_expression})"
404+
else:
405+
return f"{range_query}"
445406

446407
def __str__(self) -> str:
447408
"""Return the string representation of the query."""

0 commit comments

Comments
 (0)