Skip to content

Commit e03ae4d

Browse files
renames HybridQuery to AggregateHybridQuery
1 parent 69407f6 commit e03ae4d

File tree

4 files changed

+156
-40
lines changed

4 files changed

+156
-40
lines changed

redisvl/query/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from redisvl.query.aggregate import (
2+
AggregateHybridQuery,
23
AggregationQuery,
34
HybridQuery,
45
MultiVectorQuery,
@@ -25,6 +26,7 @@
2526
"CountQuery",
2627
"TextQuery",
2728
"AggregationQuery",
29+
"AggregateHybridQuery",
2830
"HybridQuery",
2931
"MultiVectorQuery",
3032
"Vector",

redisvl/query/aggregate.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import warnings
12
from typing import Any, Dict, List, Optional, Set, Tuple, Union
23

34
from pydantic import BaseModel, field_validator, model_validator
@@ -53,20 +54,20 @@ def __init__(self, query_string):
5354
super().__init__(query_string)
5455

5556

56-
class HybridQuery(AggregationQuery):
57+
class AggregateHybridQuery(AggregationQuery):
5758
"""
58-
HybridQuery combines text and vector search in Redis.
59+
AggregateHybridQuery combines text and vector search in Redis.
5960
It allows you to perform a hybrid search using both text and vector similarity.
6061
It scores documents based on a weighted combination of text and vector similarity.
6162
6263
.. code-block:: python
6364
64-
from redisvl.query import HybridQuery
65+
from redisvl.query import AggregateHybridQuery
6566
from redisvl.index import SearchIndex
6667
6768
index = SearchIndex.from_yaml("path/to/index.yaml")
6869
69-
query = HybridQuery(
70+
query = AggregateHybridQuery(
7071
text="example text",
7172
text_field_name="text_field",
7273
vector=[0.1, 0.2, 0.3],
@@ -105,7 +106,7 @@ def __init__(
105106
text_weights: Optional[Dict[str, float]] = None,
106107
):
107108
"""
108-
Instantiates a HybridQuery object.
109+
Instantiates an AggregateHybridQuery object.
109110
110111
Args:
111112
text (str): The text to search for.
@@ -313,6 +314,26 @@ def __str__(self) -> str:
313314
return " ".join([str(x) for x in self.build_args()])
314315

315316

317+
class HybridQuery(AggregateHybridQuery):
    """Backward compatibility wrapper for AggregateHybridQuery.

    Accepts the same positional and keyword arguments as
    AggregateHybridQuery and forwards them unchanged. The only behavior
    added here is a DeprecationWarning emitted at construction time.

    .. deprecated::
        HybridQuery is a backward compatibility wrapper around AggregateHybridQuery
        and will eventually be replaced with a new hybrid query implementation.
        To maintain current functionality please use AggregateHybridQuery directly.
    """

    def __init__(self, *args, **kwargs):
        """Emit a DeprecationWarning, then build the underlying query.

        Args:
            *args: Positional arguments passed through to
                AggregateHybridQuery.__init__.
            **kwargs: Keyword arguments passed through to
                AggregateHybridQuery.__init__.
        """
        # stacklevel=2 attributes the warning to the code that instantiates
        # HybridQuery rather than to this __init__ frame.
        warnings.warn(
            "HybridQuery is a backward compatibility wrapper around AggregateHybridQuery "
            "and will eventually be replaced with a new hybrid query implementation. "
            "to maintain current functionality please use AggregateHybridQuery directly.",
            DeprecationWarning,
            stacklevel=2,
        )
        super().__init__(*args, **kwargs)
335+
336+
316337
class MultiVectorQuery(AggregationQuery):
317338
"""
318339
MultiVectorQuery allows for search over multiple vector fields in a document simultaneously.

tests/integration/test_aggregation.py

Lines changed: 56 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pytest
22

33
from redisvl.index import SearchIndex
4-
from redisvl.query import HybridQuery, MultiVectorQuery, Vector
4+
from redisvl.query import AggregateHybridQuery, HybridQuery, MultiVectorQuery, Vector
55
from redisvl.query.filter import FilterExpression, Geo, GeoRadius, Num, Tag, Text
66
from redisvl.redis.utils import array_to_buffer
77
from tests.conftest import skip_if_redis_version_below
@@ -89,7 +89,7 @@ def test_hybrid_query(index):
8989
vector_field = "user_embedding"
9090
return_fields = ["user", "credit_score", "age", "job", "location", "description"]
9191

92-
hybrid_query = HybridQuery(
92+
hybrid_query = AggregateHybridQuery(
9393
text=text,
9494
text_field_name=text_field,
9595
vector=vector,
@@ -115,7 +115,7 @@ def test_hybrid_query(index):
115115
assert doc["job"] in ["engineer", "doctor", "dermatologist", "CEO", "dentist"]
116116
assert doc["credit_score"] in ["high", "low", "medium"]
117117

118-
hybrid_query = HybridQuery(
118+
hybrid_query = AggregateHybridQuery(
119119
text=text,
120120
text_field_name=text_field,
121121
vector=vector,
@@ -141,7 +141,7 @@ def test_empty_query_string():
141141

142142
# test if text is empty
143143
with pytest.raises(ValueError):
144-
hybrid_query = HybridQuery(
144+
hybrid_query = AggregateHybridQuery(
145145
text=text,
146146
text_field_name=text_field,
147147
vector=vector,
@@ -151,7 +151,7 @@ def test_empty_query_string():
151151
# test if text becomes empty after stopwords are removed
152152
text = "with a for but and" # will all be removed as default stopwords
153153
with pytest.raises(ValueError):
154-
hybrid_query = HybridQuery(
154+
hybrid_query = AggregateHybridQuery(
155155
text=text,
156156
text_field_name=text_field,
157157
vector=vector,
@@ -169,7 +169,7 @@ def test_hybrid_query_with_filter(index):
169169
return_fields = ["user", "credit_score", "age", "job", "location", "description"]
170170
filter_expression = (Tag("credit_score") == ("high")) & (Num("age") > 30)
171171

172-
hybrid_query = HybridQuery(
172+
hybrid_query = AggregateHybridQuery(
173173
text=text,
174174
text_field_name=text_field,
175175
vector=vector,
@@ -195,7 +195,7 @@ def test_hybrid_query_with_geo_filter(index):
195195
return_fields = ["user", "credit_score", "age", "job", "location", "description"]
196196
filter_expression = Geo("location") == GeoRadius(-122.4194, 37.7749, 1000, "m")
197197

198-
hybrid_query = HybridQuery(
198+
hybrid_query = AggregateHybridQuery(
199199
text=text,
200200
text_field_name=text_field,
201201
vector=vector,
@@ -219,7 +219,7 @@ def test_hybrid_query_alpha(index, alpha):
219219
vector = [0.1, 0.1, 0.5]
220220
vector_field = "user_embedding"
221221

222-
hybrid_query = HybridQuery(
222+
hybrid_query = AggregateHybridQuery(
223223
text=text,
224224
text_field_name=text_field,
225225
vector=vector,
@@ -247,7 +247,7 @@ def test_hybrid_query_stopwords(index):
247247
vector_field = "user_embedding"
248248
alpha = 0.5
249249

250-
hybrid_query = HybridQuery(
250+
hybrid_query = AggregateHybridQuery(
251251
text=text,
252252
text_field_name=text_field,
253253
vector=vector,
@@ -282,7 +282,7 @@ def test_hybrid_query_with_text_filter(index):
282282
filter_expression = Text(text_field) == ("medical")
283283

284284
# make sure we can still apply filters to the same text field we are querying
285-
hybrid_query = HybridQuery(
285+
hybrid_query = AggregateHybridQuery(
286286
text=text,
287287
text_field_name=text_field,
288288
vector=vector,
@@ -300,7 +300,7 @@ def test_hybrid_query_with_text_filter(index):
300300
filter_expression = (Text(text_field) == ("medical")) & (
301301
(Text(text_field) != ("research"))
302302
)
303-
hybrid_query = HybridQuery(
303+
hybrid_query = AggregateHybridQuery(
304304
text=text,
305305
text_field_name=text_field,
306306
vector=vector,
@@ -330,7 +330,7 @@ def test_hybrid_query_word_weights(index, scorer):
330330
weights = {"medical": 3.4, "cancers": 5}
331331

332332
# test we can run a query with text weights
333-
weighted_query = HybridQuery(
333+
weighted_query = AggregateHybridQuery(
334334
text=text,
335335
text_field_name=text_field,
336336
vector=vector,
@@ -344,7 +344,7 @@ def test_hybrid_query_word_weights(index, scorer):
344344
assert len(weighted_results) == 7
345345

346346
# test that weights do change the scores on results
347-
unweighted_query = HybridQuery(
347+
unweighted_query = AggregateHybridQuery(
348348
text=text,
349349
text_field_name=text_field,
350350
vector=vector,
@@ -363,7 +363,7 @@ def test_hybrid_query_word_weights(index, scorer):
363363

364364
# test that weights do change the document score and order of results
365365
weights = {"medical": 5, "cancers": 3.4} # switch the weights
366-
weighted_query = HybridQuery(
366+
weighted_query = AggregateHybridQuery(
367367
text=text,
368368
text_field_name=text_field,
369369
vector=vector,
@@ -377,7 +377,7 @@ def test_hybrid_query_word_weights(index, scorer):
377377
assert weighted_results != unweighted_results
378378

379379
# test assigning weights on construction is equivalent to setting them on the query object
380-
new_query = HybridQuery(
380+
new_query = AggregateHybridQuery(
381381
text=text,
382382
text_field_name=text_field,
383383
vector=vector,
@@ -743,3 +743,44 @@ def test_multivector_query_mixed_index(index):
743743
assert (
744744
float(r["combined_score"]) - score <= 0.0001
745745
) # allow for small floating point error
746+
747+
748+
def test_hybrid_query_backward_compatibility(index):
    """Check that the deprecated HybridQuery still works like AggregateHybridQuery.

    Runs the same search through AggregateHybridQuery and through the legacy
    HybridQuery wrapper, verifying that the wrapper emits a
    DeprecationWarning on construction, is an AggregateHybridQuery instance,
    and returns the same number of results from the index.
    """
    skip_if_redis_version_below(index.client, "7.2.0")

    text = "a medical professional with expertise in lung cancer"
    text_field = "description"
    vector = [0.1, 0.1, 0.5]
    vector_field = "user_embedding"
    return_fields = ["user", "credit_score", "age", "job", "location", "description"]

    hybrid_query = AggregateHybridQuery(
        text=text,
        text_field_name=text_field,
        vector=vector,
        vector_field_name=vector_field,
        return_fields=return_fields,
    )

    results = index.query(hybrid_query)
    assert len(results) == 7
    for result in results:
        assert result["user"] in [
            "john",
            "derrick",
            "nancy",
            "tyler",
            "tim",
            "taimur",
            "joe",
            "mary",
        ]

    # Constructing the legacy class must warn; keep the instance so the
    # wrapper itself is exercised against the index below.
    with pytest.warns(DeprecationWarning, match="HybridQuery"):
        legacy_query = HybridQuery(
            text=text,
            text_field_name=text_field,
            vector=vector,
            vector_field_name=vector_field,
            return_fields=return_fields,
        )

    # The wrapper must remain usable anywhere the renamed class is accepted.
    assert isinstance(legacy_query, AggregateHybridQuery)

    legacy_results = index.query(legacy_query)
    assert len(legacy_results) == len(results)

0 commit comments

Comments
 (0)