Skip to content

Commit ff44041

Browse files
abstracts AggregationQuery to follow BaseQuery calls in search index
1 parent 9b1dc18 commit ff44041

File tree

8 files changed

+99
-76
lines changed

8 files changed

+99
-76
lines changed

redisvl/index/index.py

Lines changed: 31 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
BaseVectorQuery,
4747
CountQuery,
4848
FilterQuery,
49-
HybridAggregationQuery,
49+
HybridQuery,
5050
)
5151
from redisvl.query.filter import FilterExpression
5252
from redisvl.redis.connection import (
@@ -686,35 +686,8 @@ def fetch(self, id: str) -> Optional[Dict[str, Any]]:
686686
return convert_bytes(obj[0])
687687
return None
688688

689-
def aggregate_query(
690-
self, aggregation_query: AggregationQuery
691-
) -> List[Dict[str, Any]]:
692-
"""Execute an aggretation query and processes the results.
693-
694-
This method takes an AggregationHyridQuery object directly, runs the search, and
695-
handles post-processing of the search.
696-
697-
Args:
698-
aggregation_query (AggregationQuery): The aggregation query to run.
699-
700-
Returns:
701-
List[Result]: A list of search results.
702-
703-
.. code-block:: python
704-
705-
from redisvl.query import HybridAggregationQuery
706-
707-
aggregation = HybridAggregationQuery(
708-
text="the text to search for",
709-
text_field="description",
710-
vector=[0.16, -0.34, 0.98, 0.23],
711-
vector_field="embedding",
712-
num_results=3
713-
)
714-
715-
results = index.aggregate_query(aggregation_query)
716-
717-
"""
689+
def _aggregate(self, aggregation_query: AggregationQuery) -> List[Dict[str, Any]]:
690+
"""Execute an aggregation query and process the results."""
718691
results = self.aggregate(
719692
aggregation_query, query_params=aggregation_query.params # type: ignore[attr-defined]
720693
)
@@ -846,7 +819,7 @@ def _query(self, query: BaseQuery) -> List[Dict[str, Any]]:
846819
results = self.search(query.query, query_params=query.params)
847820
return process_results(results, query=query, schema=self.schema)
848821

849-
def query(self, query: BaseQuery) -> List[Dict[str, Any]]:
822+
def query(self, query: Union[BaseQuery, AggregationQuery]) -> List[Dict[str, Any]]:
850823
"""Execute a query on the index.
851824
852825
This method takes a BaseQuery object directly, runs the search, and
@@ -871,7 +844,10 @@ def query(self, query: BaseQuery) -> List[Dict[str, Any]]:
871844
results = index.query(query)
872845
873846
"""
874-
return self._query(query)
847+
if isinstance(query, AggregationQuery):
848+
return self._aggregate(query)
849+
else:
850+
return self._query(query)
875851

876852
def paginate(self, query: BaseQuery, page_size: int = 30) -> Generator:
877853
"""Execute a given query against the index and return results in
@@ -1377,6 +1353,19 @@ async def fetch(self, id: str) -> Optional[Dict[str, Any]]:
13771353
return convert_bytes(obj[0])
13781354
return None
13791355

1356+
async def _aggregate(
1357+
self, aggregation_query: AggregationQuery
1358+
) -> List[Dict[str, Any]]:
1359+
"""Execute an aggregation query and process the results."""
1360+
results = await self.aggregate(
1361+
aggregation_query, query_params=aggregation_query.params # type: ignore[attr-defined]
1362+
)
1363+
return process_aggregate_results(
1364+
results,
1365+
query=aggregation_query,
1366+
storage_type=self.schema.index.storage_type,
1367+
)
1368+
13801369
async def aggregate(self, *args, **kwargs) -> "AggregateResult":
13811370
"""Perform an aggregation operation against the index.
13821371
@@ -1500,14 +1489,16 @@ async def _query(self, query: BaseQuery) -> List[Dict[str, Any]]:
15001489
results = await self.search(query.query, query_params=query.params)
15011490
return process_results(results, query=query, schema=self.schema)
15021491

1503-
async def query(self, query: BaseQuery) -> List[Dict[str, Any]]:
1492+
async def query(
1493+
self, query: Union[BaseQuery, AggregationQuery]
1494+
) -> List[Dict[str, Any]]:
15041495
"""Asynchronously execute a query on the index.
15051496
1506-
This method takes a BaseQuery object directly, runs the search, and
1507-
handles post-processing of the search.
1497+
This method takes a BaseQuery or AggregationQuery object directly, runs
1498+
the search, and handles post-processing of the search.
15081499
15091500
Args:
1510-
query (BaseQuery): The query to run.
1501+
query (Union[BaseQuery, AggregationQuery]): The query to run.
15111502
15121503
Returns:
15131504
List[Result]: A list of search results.
@@ -1524,7 +1515,10 @@ async def query(self, query: BaseQuery) -> List[Dict[str, Any]]:
15241515
15251516
results = await index.query(query)
15261517
"""
1527-
return await self._query(query)
1518+
if isinstance(query, AggregationQuery):
1519+
return await self._aggregate(query)
1520+
else:
1521+
return await self._query(query)
15281522

15291523
async def paginate(self, query: BaseQuery, page_size: int = 30) -> AsyncGenerator:
15301524
"""Execute a given query against the index and return results in

redisvl/query/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from redisvl.query.aggregate import AggregationQuery, HybridAggregationQuery
1+
from redisvl.query.aggregate import AggregationQuery, HybridQuery
22
from redisvl.query.query import (
33
BaseQuery,
44
BaseVectorQuery,
@@ -20,5 +20,5 @@
2020
"CountQuery",
2121
"TextQuery",
2222
"AggregationQuery",
23-
"HybridAggregationQuery",
23+
"HybridQuery",
2424
]

redisvl/query/aggregate.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@ def __init__(self, query_string):
1919
super().__init__(query_string)
2020

2121

22-
class HybridAggregationQuery(AggregationQuery):
22+
class HybridQuery(AggregationQuery):
2323
"""
24-
HybridAggregationQuery combines text and vector search in Redis.
24+
HybridQuery combines text and vector search in Redis.
2525
It allows you to perform a hybrid search using both text and vector similarity.
2626
It scores documents based on a weighted combination of text and vector similarity.
2727
"""
@@ -45,7 +45,7 @@ def __init__(
4545
dialect: int = 2,
4646
):
4747
"""
48-
Instantiages a HybridAggregationQuery object.
48+
Instantiates a HybridQuery object.
4949
5050
Args:
5151
text (str): The text to search for.
@@ -75,12 +75,12 @@ def __init__(
7575
TypeError: If the stopwords are not a set, list, or tuple of strings.
7676
7777
.. code-block:: python
78-
from redisvl.query.aggregate import HybridAggregationQuery
78+
from redisvl.query import HybridQuery
7979
from redisvl.index import SearchIndex
8080
81-
index = SearchIndex("my_index")
81+
index = SearchIndex.from_yaml("index.yaml")
8282
83-
query = HybridAggregationQuery(
83+
query = HybridQuery(
8484
text="example text",
8585
text_field_name="text_field",
8686
vector=[0.1, 0.2, 0.3],
@@ -92,10 +92,10 @@ def __init__(
9292
num_results=10,
9393
return_fields=["field1", "field2"],
9494
stopwords="english",
95-
dialect=4,
95+
dialect=2,
9696
)
9797
98-
results = index.aggregate_query(query)
98+
results = index.query(query)
9999
"""
100100

101101
if not text.strip():

redisvl/query/query.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -712,7 +712,7 @@ def __init__(
712712
text (str): The text string to perform the text search with.
713713
text_field_name (str): The name of the document field to perform text search on.
714714
text_scorer (str, optional): The text scoring algorithm to use.
715-
Defaults to BM25STD. Options are {TFIDF, BM25STD, BM25, DOCNORM, DISMAX, DOCSCORE}.
715+
Defaults to BM25STD. Options are {TFIDF, BM25STD, BM25, TFIDF.DOCNORM, DISMAX, DOCSCORE}.
716716
See https://redis.io/docs/latest/develop/interact/search-and-query/advanced-concepts/scoring/
717717
filter_expression (Union[str, FilterExpression], optional): A filter to apply
718718
along with the text search. Defaults to None.
@@ -740,6 +740,25 @@ def __init__(
740740
Raises:
741741
ValueError: if stopwords language string cannot be loaded.
742742
TypeError: If stopwords is not a valid iterable set of strings.
743+
744+
.. code-block:: python
745+
from redisvl.query import TextQuery
746+
from redisvl.index import SearchIndex
747+
748+
index = SearchIndex.from_yaml("index.yaml")
749+
750+
query = TextQuery(
751+
text="example text",
752+
text_field_name="text_field",
753+
text_scorer="BM25STD",
754+
filter_expression=None,
755+
num_results=10,
756+
return_fields=["field1", "field2"],
757+
stopwords="english",
758+
dialect=2,
759+
)
760+
761+
results = index.query(query)
743762
"""
744763
self._text = text
745764
self._text_field = text_field_name

tests/integration/test_aggregation.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from redis.commands.search.result import Result
44

55
from redisvl.index import SearchIndex
6-
from redisvl.query import HybridAggregationQuery
6+
from redisvl.query import HybridQuery
77
from redisvl.query.filter import FilterExpression, Geo, GeoRadius, Num, Tag, Text
88
from redisvl.redis.connection import compare_versions
99
from redisvl.redis.utils import array_to_buffer
@@ -70,15 +70,15 @@ def test_aggregation_query(index):
7070
vector_field = "user_embedding"
7171
return_fields = ["user", "credit_score", "age", "job", "location", "description"]
7272

73-
hybrid_query = HybridAggregationQuery(
73+
hybrid_query = HybridQuery(
7474
text=text,
7575
text_field_name=text_field,
7676
vector=vector,
7777
vector_field_name=vector_field,
7878
return_fields=return_fields,
7979
)
8080

81-
results = index.aggregate_query(hybrid_query)
81+
results = index.query(hybrid_query)
8282
assert isinstance(results, list)
8383
assert len(results) == 7
8484
for doc in results:
@@ -96,15 +96,15 @@ def test_aggregation_query(index):
9696
assert doc["job"] in ["engineer", "doctor", "dermatologist", "CEO", "dentist"]
9797
assert doc["credit_score"] in ["high", "low", "medium"]
9898

99-
hybrid_query = HybridAggregationQuery(
99+
hybrid_query = HybridQuery(
100100
text=text,
101101
text_field_name=text_field,
102102
vector=vector,
103103
vector_field_name=vector_field,
104104
num_results=3,
105105
)
106106

107-
results = index.aggregate_query(hybrid_query)
107+
results = index.query(hybrid_query)
108108
assert len(results) == 3
109109
assert (
110110
results[0]["hybrid_score"]
@@ -122,7 +122,7 @@ def test_empty_query_string():
122122

123123
# test if text is empty
124124
with pytest.raises(ValueError):
125-
hybrid_query = HybridAggregationQuery(
125+
hybrid_query = HybridQuery(
126126
text=text,
127127
text_field_name=text_field,
128128
vector=vector,
@@ -132,7 +132,7 @@ def test_empty_query_string():
132132
# test if text becomes empty after stopwords are removed
133133
text = "with a for but and" # will all be removed as default stopwords
134134
with pytest.raises(ValueError):
135-
hybrid_query = HybridAggregationQuery(
135+
hybrid_query = HybridQuery(
136136
text=text,
137137
text_field_name=text_field,
138138
vector=vector,
@@ -152,7 +152,7 @@ def test_aggregation_query_with_filter(index):
152152
return_fields = ["user", "credit_score", "age", "job", "location", "description"]
153153
filter_expression = (Tag("credit_score") == ("high")) & (Num("age") > 30)
154154

155-
hybrid_query = HybridAggregationQuery(
155+
hybrid_query = HybridQuery(
156156
text=text,
157157
text_field_name=text_field,
158158
vector=vector,
@@ -161,7 +161,7 @@ def test_aggregation_query_with_filter(index):
161161
return_fields=return_fields,
162162
)
163163

164-
results = index.aggregate_query(hybrid_query)
164+
results = index.query(hybrid_query)
165165
assert len(results) == 2
166166
for result in results:
167167
assert result["credit_score"] == "high"
@@ -180,7 +180,7 @@ def test_aggregation_query_with_geo_filter(index):
180180
return_fields = ["user", "credit_score", "age", "job", "location", "description"]
181181
filter_expression = Geo("location") == GeoRadius(-122.4194, 37.7749, 1000, "m")
182182

183-
hybrid_query = HybridAggregationQuery(
183+
hybrid_query = HybridQuery(
184184
text=text,
185185
text_field_name=text_field,
186186
vector=vector,
@@ -189,7 +189,7 @@ def test_aggregation_query_with_geo_filter(index):
189189
return_fields=return_fields,
190190
)
191191

192-
results = index.aggregate_query(hybrid_query)
192+
results = index.query(hybrid_query)
193193
assert len(results) == 3
194194
for result in results:
195195
assert result["location"] is not None
@@ -206,15 +206,15 @@ def test_aggregate_query_alpha(index, alpha):
206206
vector = [0.1, 0.1, 0.5]
207207
vector_field = "user_embedding"
208208

209-
hybrid_query = HybridAggregationQuery(
209+
hybrid_query = HybridQuery(
210210
text=text,
211211
text_field_name=text_field,
212212
vector=vector,
213213
vector_field_name=vector_field,
214214
alpha=alpha,
215215
)
216216

217-
results = index.aggregate_query(hybrid_query)
217+
results = index.query(hybrid_query)
218218
assert len(results) == 7
219219
for result in results:
220220
score = alpha * float(result["vector_similarity"]) + (1 - alpha) * float(
@@ -236,7 +236,7 @@ def test_aggregate_query_stopwords(index):
236236
vector_field = "user_embedding"
237237
alpha = 0.5
238238

239-
hybrid_query = HybridAggregationQuery(
239+
hybrid_query = HybridQuery(
240240
text=text,
241241
text_field_name=text_field,
242242
vector=vector,
@@ -250,7 +250,7 @@ def test_aggregate_query_stopwords(index):
250250
assert "medical" not in query_string
251251
assert "expertize" not in query_string
252252

253-
results = index.aggregate_query(hybrid_query)
253+
results = index.query(hybrid_query)
254254
assert len(results) == 7
255255
for result in results:
256256
score = alpha * float(result["vector_similarity"]) + (1 - alpha) * float(
@@ -273,7 +273,7 @@ def test_aggregate_query_with_text_filter(index):
273273
filter_expression = Text(text_field) == ("medical")
274274

275275
# make sure we can still apply filters to the same text field we are querying
276-
hybrid_query = HybridAggregationQuery(
276+
hybrid_query = HybridQuery(
277277
text=text,
278278
text_field_name=text_field,
279279
vector=vector,
@@ -283,15 +283,15 @@ def test_aggregate_query_with_text_filter(index):
283283
return_fields=["job", "description"],
284284
)
285285

286-
results = index.aggregate_query(hybrid_query)
286+
results = index.query(hybrid_query)
287287
assert len(results) == 2
288288
for result in results:
289289
assert "medical" in result[text_field].lower()
290290

291291
filter_expression = (Text(text_field) == ("medical")) & (
292292
(Text(text_field) != ("research"))
293293
)
294-
hybrid_query = HybridAggregationQuery(
294+
hybrid_query = HybridQuery(
295295
text=text,
296296
text_field_name=text_field,
297297
vector=vector,
@@ -301,7 +301,7 @@ def test_aggregate_query_with_text_filter(index):
301301
return_fields=["description"],
302302
)
303303

304-
results = index.aggregate_query(hybrid_query)
304+
results = index.query(hybrid_query)
305305
assert len(results) == 2
306306
for result in results:
307307
assert "medical" in result[text_field].lower()

0 commit comments

Comments
 (0)