33from redis .commands .search .result import Result
44
55from redisvl .index import SearchIndex
6- from redisvl .query import HybridAggregationQuery
6+ from redisvl .query import HybridQuery
77from redisvl .query .filter import FilterExpression , Geo , GeoRadius , Num , Tag , Text
88from redisvl .redis .connection import compare_versions
99from redisvl .redis .utils import array_to_buffer
@@ -70,15 +70,15 @@ def test_aggregation_query(index):
7070 vector_field = "user_embedding"
7171 return_fields = ["user" , "credit_score" , "age" , "job" , "location" , "description" ]
7272
73- hybrid_query = HybridAggregationQuery (
73+ hybrid_query = HybridQuery (
7474 text = text ,
7575 text_field_name = text_field ,
7676 vector = vector ,
7777 vector_field_name = vector_field ,
7878 return_fields = return_fields ,
7979 )
8080
81- results = index .aggregate_query (hybrid_query )
81+ results = index .query (hybrid_query )
8282 assert isinstance (results , list )
8383 assert len (results ) == 7
8484 for doc in results :
@@ -96,15 +96,15 @@ def test_aggregation_query(index):
9696 assert doc ["job" ] in ["engineer" , "doctor" , "dermatologist" , "CEO" , "dentist" ]
9797 assert doc ["credit_score" ] in ["high" , "low" , "medium" ]
9898
99- hybrid_query = HybridAggregationQuery (
99+ hybrid_query = HybridQuery (
100100 text = text ,
101101 text_field_name = text_field ,
102102 vector = vector ,
103103 vector_field_name = vector_field ,
104104 num_results = 3 ,
105105 )
106106
107- results = index .aggregate_query (hybrid_query )
107+ results = index .query (hybrid_query )
108108 assert len (results ) == 3
109109 assert (
110110 results [0 ]["hybrid_score" ]
@@ -122,7 +122,7 @@ def test_empty_query_string():
122122
123123 # test if text is empty
124124 with pytest .raises (ValueError ):
125- hybrid_query = HybridAggregationQuery (
125+ hybrid_query = HybridQuery (
126126 text = text ,
127127 text_field_name = text_field ,
128128 vector = vector ,
@@ -132,7 +132,7 @@ def test_empty_query_string():
132132 # test if text becomes empty after stopwords are removed
133133 text = "with a for but and" # will all be removed as default stopwords
134134 with pytest .raises (ValueError ):
135- hybrid_query = HybridAggregationQuery (
135+ hybrid_query = HybridQuery (
136136 text = text ,
137137 text_field_name = text_field ,
138138 vector = vector ,
@@ -152,7 +152,7 @@ def test_aggregation_query_with_filter(index):
152152 return_fields = ["user" , "credit_score" , "age" , "job" , "location" , "description" ]
153153 filter_expression = (Tag ("credit_score" ) == ("high" )) & (Num ("age" ) > 30 )
154154
155- hybrid_query = HybridAggregationQuery (
155+ hybrid_query = HybridQuery (
156156 text = text ,
157157 text_field_name = text_field ,
158158 vector = vector ,
@@ -161,7 +161,7 @@ def test_aggregation_query_with_filter(index):
161161 return_fields = return_fields ,
162162 )
163163
164- results = index .aggregate_query (hybrid_query )
164+ results = index .query (hybrid_query )
165165 assert len (results ) == 2
166166 for result in results :
167167 assert result ["credit_score" ] == "high"
@@ -180,7 +180,7 @@ def test_aggregation_query_with_geo_filter(index):
180180 return_fields = ["user" , "credit_score" , "age" , "job" , "location" , "description" ]
181181 filter_expression = Geo ("location" ) == GeoRadius (- 122.4194 , 37.7749 , 1000 , "m" )
182182
183- hybrid_query = HybridAggregationQuery (
183+ hybrid_query = HybridQuery (
184184 text = text ,
185185 text_field_name = text_field ,
186186 vector = vector ,
@@ -189,7 +189,7 @@ def test_aggregation_query_with_geo_filter(index):
189189 return_fields = return_fields ,
190190 )
191191
192- results = index .aggregate_query (hybrid_query )
192+ results = index .query (hybrid_query )
193193 assert len (results ) == 3
194194 for result in results :
195195 assert result ["location" ] is not None
@@ -206,15 +206,15 @@ def test_aggregate_query_alpha(index, alpha):
206206 vector = [0.1 , 0.1 , 0.5 ]
207207 vector_field = "user_embedding"
208208
209- hybrid_query = HybridAggregationQuery (
209+ hybrid_query = HybridQuery (
210210 text = text ,
211211 text_field_name = text_field ,
212212 vector = vector ,
213213 vector_field_name = vector_field ,
214214 alpha = alpha ,
215215 )
216216
217- results = index .aggregate_query (hybrid_query )
217+ results = index .query (hybrid_query )
218218 assert len (results ) == 7
219219 for result in results :
220220 score = alpha * float (result ["vector_similarity" ]) + (1 - alpha ) * float (
@@ -236,7 +236,7 @@ def test_aggregate_query_stopwords(index):
236236 vector_field = "user_embedding"
237237 alpha = 0.5
238238
239- hybrid_query = HybridAggregationQuery (
239+ hybrid_query = HybridQuery (
240240 text = text ,
241241 text_field_name = text_field ,
242242 vector = vector ,
@@ -250,7 +250,7 @@ def test_aggregate_query_stopwords(index):
250250 assert "medical" not in query_string
251251 assert "expertize" not in query_string
252252
253- results = index .aggregate_query (hybrid_query )
253+ results = index .query (hybrid_query )
254254 assert len (results ) == 7
255255 for result in results :
256256 score = alpha * float (result ["vector_similarity" ]) + (1 - alpha ) * float (
@@ -273,7 +273,7 @@ def test_aggregate_query_with_text_filter(index):
273273 filter_expression = Text (text_field ) == ("medical" )
274274
275275 # make sure we can still apply filters to the same text field we are querying
276- hybrid_query = HybridAggregationQuery (
276+ hybrid_query = HybridQuery (
277277 text = text ,
278278 text_field_name = text_field ,
279279 vector = vector ,
@@ -283,15 +283,15 @@ def test_aggregate_query_with_text_filter(index):
283283 return_fields = ["job" , "description" ],
284284 )
285285
286- results = index .aggregate_query (hybrid_query )
286+ results = index .query (hybrid_query )
287287 assert len (results ) == 2
288288 for result in results :
289289 assert "medical" in result [text_field ].lower ()
290290
291291 filter_expression = (Text (text_field ) == ("medical" )) & (
292292 (Text (text_field ) != ("research" ))
293293 )
294- hybrid_query = HybridAggregationQuery (
294+ hybrid_query = HybridQuery (
295295 text = text ,
296296 text_field_name = text_field ,
297297 vector = vector ,
@@ -301,7 +301,7 @@ def test_aggregate_query_with_text_filter(index):
301301 return_fields = ["description" ],
302302 )
303303
304- results = index .aggregate_query (hybrid_query )
304+ results = index .query (hybrid_query )
305305 assert len (results ) == 2
306306 for result in results :
307307 assert "medical" in result [text_field ].lower ()
0 commit comments