-
Notifications
You must be signed in to change notification settings - Fork 76
Expand file tree
/
Copy pathaggregate.py
More file actions
388 lines (315 loc) · 13.6 KB
/
aggregate.py
File metadata and controls
388 lines (315 loc) · 13.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
import warnings
from typing import Any, Dict, List, Optional, Set, Union
from pydantic import BaseModel, field_validator, model_validator
from redis.commands.search.aggregation import AggregateRequest, Desc
from typing_extensions import Self
from redisvl.query.filter import FilterExpression
from redisvl.redis.utils import array_to_buffer
from redisvl.schema.fields import VectorDataType
from redisvl.utils.full_text_query_helper import FullTextQueryHelper
from redisvl.utils.utils import lazy_import
# Module handles obtained through redisvl's ``lazy_import`` helper (presumably
# deferring the real import until first attribute access -- confirm against
# redisvl.utils.utils). NOTE(review): neither name is referenced by the classes
# visible in this module; stopword handling is delegated to FullTextQueryHelper.
nltk = lazy_import("nltk")
nltk_stopwords = lazy_import("nltk.corpus.stopwords")
class Vector(BaseModel):
    """
    Bundle of a single query vector and the settings needed to use it in a
    multi vector query.

    Attributes are validated by pydantic: ``dtype`` must name a member of
    ``VectorDataType`` (case-insensitive), and a ``vector`` given as a list of
    floats is serialized to bytes once the model has been built.
    """

    # Raw query vector; lists are converted to a byte string after validation.
    vector: Union[List[float], bytes]
    # Name of the indexed vector field this vector is compared against.
    field_name: str
    # Datatype used when serializing a list-valued ``vector``.
    dtype: str = "float32"
    # Multiplier applied to this vector's similarity score in the combined score.
    weight: float = 1.0

    @field_validator("dtype")
    @classmethod
    def validate_dtype(cls, dtype: str) -> str:
        """
        Check that ``dtype`` names a supported vector datatype.

        Args:
            dtype (str): The datatype name; compared case-insensitively.

        Returns:
            str: The value exactly as it was supplied (it is not normalized).

        Raises:
            ValueError: If ``dtype`` is not a member of ``VectorDataType``.
        """
        try:
            VectorDataType(dtype.upper())
        except ValueError:
            # Only build the list of valid names when it is needed for the message.
            supported = [t.lower() for t in VectorDataType]
            raise ValueError(
                f"Invalid data type: {dtype}. Supported types are: {supported}"
            )
        return dtype

    @model_validator(mode="after")
    def validate_vector(self) -> Self:
        """
        Ensure ``vector`` is stored as bytes.

        A vector that already is a byte string is kept untouched; anything else
        is serialized with ``array_to_buffer`` using this object's ``dtype``.

        Returns:
            Self: The same model instance.
        """
        if not isinstance(self.vector, bytes):
            self.vector = array_to_buffer(self.vector, self.dtype)
        return self
class AggregationQuery(AggregateRequest):
    """
    Common parent of the aggregation-based query classes in this module.

    It adds no behavior of its own on top of ``AggregateRequest``; it only
    gives the Redis aggregation queries built here a shared base type.
    """

    def __init__(self, query_string):
        """
        Create the aggregation request.

        Args:
            query_string: The query string handed unchanged to
                ``AggregateRequest``.
        """
        # Pure pass-through: all state lives in the AggregateRequest base.
        super().__init__(query_string)
class AggregateHybridQuery(AggregationQuery):
    """
    Hybrid (full-text + vector) search expressed as a Redis aggregation.

    Each document receives a text score and a vector similarity, and the two
    are blended into one ``hybrid_score`` used for ranking.

    .. code-block:: python

        from redisvl.query import AggregateHybridQuery
        from redisvl.index import SearchIndex

        index = SearchIndex.from_yaml("path/to/index.yaml")

        query = AggregateHybridQuery(
            text="example text",
            text_field_name="text_field",
            vector=[0.1, 0.2, 0.3],
            vector_field_name="vector_field",
            text_scorer="BM25STD",
            filter_expression=None,
            alpha=0.7,
            dtype="float32",
            num_results=10,
            return_fields=["field1", "field2"],
            stopwords="english",
            dialect=2,
        )

        results = index.query(query)
    """

    # Alias under which the KNN clause yields the vector distance.
    DISTANCE_ID: str = "vector_distance"
    # Name of the query parameter that carries the serialized vector.
    VECTOR_PARAM: str = "vector"

    def __init__(
        self,
        text: str,
        text_field_name: str,
        vector: Union[bytes, List[float]],
        vector_field_name: str,
        text_scorer: str = "BM25STD",
        filter_expression: Optional[Union[str, FilterExpression]] = None,
        alpha: float = 0.7,
        dtype: str = "float32",
        num_results: int = 10,
        return_fields: Optional[List[str]] = None,
        stopwords: Optional[Union[str, Set[str]]] = "english",
        dialect: int = 2,
        text_weights: Optional[Dict[str, float]] = None,
    ):
        """
        Build an AggregateHybridQuery.

        Args:
            text (str): Text to search for.
            text_field_name (str): Name of the text field to search in.
            vector (Union[bytes, List[float]]): Vector used for the similarity search.
            vector_field_name (str): Name of the vector field to search in.
            text_scorer (str, optional): Text scorer to use. Options are {TFIDF,
                TFIDF.DOCNORM, BM25, DISMAX, DOCSCORE, BM25STD}. Defaults to "BM25STD".
            filter_expression (Optional[Union[str, FilterExpression]], optional): Filter
                applied to the query. Defaults to None.
            alpha (float, optional): Weight of the vector similarity. The final score is
                hybrid_score = (alpha) * vector_score + (1-alpha) * text_score.
                Defaults to 0.7.
            dtype (str, optional): Datatype of the vector. Defaults to "float32".
            num_results (int, optional): Number of results to return. Defaults to 10.
            return_fields (Optional[List[str]], optional): Fields to load into the
                results. Defaults to None.
            stopwords (Optional[Union[str, Set[str]]], optional): Stopwords stripped from
                ``text`` before it is used for search. A language name such as "english"
                or "german" selects that language's default stopword set; a list, set,
                or tuple of strings is used as given; None disables stopword removal.
                Defaults to "english".
                Note: this is query-time (client-side) filtering. Index-level
                (server-side) stopwords are configured via IndexInfo.stopwords, and
                combining query-time stopwords with index-level STOPWORDS 0 is
                counterproductive.
            dialect (int, optional): Redis dialect version. Defaults to 2.
            text_weights (Optional[Dict[str, float]]): Per-word importance weights for the
                query text. Defaults to None, which leaves the text_scorer score
                unmodified.

        Note:
            AggregateHybridQuery is executed with FT.AGGREGATE, which does NOT support
            runtime parameters. When runtime parameters (ef_runtime,
            search_window_size, etc.) are needed, use VectorQuery or VectorRangeQuery,
            which are executed with FT.SEARCH.

        Raises:
            ValueError: If the text string is empty, or becomes empty once stopwords
                have been removed.
            TypeError: If the stopwords are not a set, list, or tuple of strings.
        """
        if not text.strip():
            raise ValueError("text string cannot be empty")

        # Text side of the query.
        self._text = text
        self._text_field = text_field_name
        self._filter_expression = filter_expression

        # Vector side of the query.
        self._vector = vector
        self._vector_field = vector_field_name
        self._dtype = dtype

        # Scoring / paging settings.
        self._alpha = alpha
        self._num_results = num_results

        # The helper must exist before the query string is assembled.
        self._ft_helper = FullTextQueryHelper(
            stopwords=stopwords,
            text_weights=text_weights,
        )

        super().__init__(self._build_query_string())

        self.scorer(text_scorer)
        self.add_scores()

        # Expose the vector distance as a similarity and the text score under a name.
        similarity_expr = f"(2 - @{self.DISTANCE_ID})/2"
        self.apply(vector_similarity=similarity_expr, text_score="@__score")

        # Blend both scores; alpha weights the vector part, the rest goes to text.
        text_share = 1 - alpha
        self.apply(
            hybrid_score=f"{text_share}*@text_score + {alpha}*@vector_similarity"
        )

        self.sort_by(Desc("@hybrid_score"), max=num_results)  # type: ignore
        self.dialect(dialect)

        if return_fields:
            self.load(*return_fields)  # type: ignore[arg-type]

    @property
    def params(self) -> Dict[str, Any]:
        """Return the parameters for the aggregation.

        A list-valued vector is serialized with the configured dtype on every
        access; a vector already supplied as bytes is passed through.

        Returns:
            Dict[str, Any]: Mapping of the vector parameter name to its bytes.
        """
        raw = self._vector
        blob = array_to_buffer(raw, dtype=self._dtype) if isinstance(raw, list) else raw
        return {self.VECTOR_PARAM: blob}

    @property
    def stopwords(self) -> Set[str]:
        """Return the stopwords used in the query.

        Returns:
            Set[str]: The stopwords held by the full-text helper.
        """
        return self._ft_helper.stopwords

    @property
    def text_weights(self) -> Dict[str, float]:
        """Get the text weights.

        Returns:
            Dictionary of word:weight mappings held by the full-text helper.
        """
        return self._ft_helper.text_weights

    def set_text_weights(self, weights: Dict[str, float]):
        """Set or update the text weights for the query.

        The query string is rebuilt afterwards so the new weights take effect.

        Args:
            weights: Dictionary of word:weight mappings
        """
        self._ft_helper.set_text_weights(weights)
        self._query = self._build_query_string()

    def _build_query_string(self) -> str:
        """Assemble the text query (with optional filter) followed by the KNN clause."""
        text_clause = self._ft_helper.build_query_string(
            self._text, self._text_field, self._filter_expression
        )
        # KNN over the vector field; the distance is yielded under DISTANCE_ID.
        return (
            f"{text_clause}=>[KNN {self._num_results} @{self._vector_field} "
            f"${self.VECTOR_PARAM} AS {self.DISTANCE_ID}]"
        )

    def __str__(self) -> str:
        """Return the string representation of the query."""
        return " ".join(str(arg) for arg in self.build_args())
class MultiVectorQuery(AggregationQuery):
    """
    MultiVectorQuery allows for search over multiple vector fields in a document simultaneously.
    The final score will be a weighted combination of the individual vector similarity scores
    following the formula:

    score = (w_1 * score_1 + w_2 * score_2 + w_3 * score_3 + ... )

    Vectors may be of different size and datatype, but must be indexed using the 'cosine' distance_metric.

    .. code-block:: python

        from redisvl.query import MultiVectorQuery, Vector
        from redisvl.index import SearchIndex

        index = SearchIndex.from_yaml("path/to/index.yaml")

        vector_1 = Vector(
            vector=[0.1, 0.2, 0.3],
            field_name="text_vector",
            dtype="float32",
            weight=0.7,
        )
        vector_2 = Vector(
            vector=[0.5, 0.5],
            field_name="image_vector",
            dtype="bfloat16",
            weight=0.2,
        )
        vector_3 = Vector(
            vector=[0.1, 0.2, 0.3],
            field_name="text_vector",
            dtype="float64",
            weight=0.5,
        )

        query = MultiVectorQuery(
            vectors=[vector_1, vector_2, vector_3],
            filter_expression=None,
            num_results=10,
            return_fields=["field1", "field2"],
            dialect=2,
        )

        results = index.query(query)
    """

    _vectors: List[Vector]

    def __init__(
        self,
        vectors: Union[Vector, List[Vector]],
        return_fields: Optional[List[str]] = None,
        filter_expression: Optional[Union[str, FilterExpression]] = None,
        num_results: int = 10,
        dialect: int = 2,
    ):
        """
        Instantiates a MultiVectorQuery object.

        Args:
            vectors (Union[Vector, List[Vector]]): The Vectors to perform vector similarity search.
            return_fields (Optional[List[str]], optional): The fields to return. Defaults to None.
            filter_expression (Optional[Union[str, FilterExpression]]): The filter expression to use.
                Defaults to None.
            num_results (int, optional): The number of results to return. Defaults to 10.
            dialect (int, optional): The Redis dialect version. Defaults to 2.

        Raises:
            TypeError: If ``vectors`` is not a Vector or an iterable made up only of
                Vector objects.
            ValueError: If ``vectors`` is empty; without at least one vector neither a
                query string nor a combined score expression can be built.
        """
        self._filter_expression = filter_expression
        self._num_results = num_results

        type_error_message = (
            "vector argument must be a Vector object or list of Vector objects."
        )
        if isinstance(vectors, Vector):
            candidates = [vectors]
        else:
            try:
                # Copy so later mutation of the caller's list cannot desynchronize
                # the query string from the parameters.
                candidates = list(vectors)  # type: ignore[arg-type]
            except TypeError:
                # e.g. None or another non-iterable: report the documented error
                # instead of an unrelated "object is not iterable" message.
                raise TypeError(type_error_message) from None

        if not all(isinstance(v, Vector) for v in candidates):
            raise TypeError(type_error_message)
        if not candidates:
            raise ValueError("at least one Vector must be provided.")

        self._vectors = candidates

        super().__init__(self._build_query_string())

        # calculate the respective vector similarities; a document that did not
        # yield a distance for a given vector scores 0 for that vector
        for i, _ in enumerate(self._vectors):
            self.apply(
                **{
                    f"score_{i}": f"case(exists(@distance_{i}), (2 - @distance_{i})/2, 0)"
                }
            )

        # construct the scoring string based on the vector similarity scores and weights
        combined_score_string = " + ".join(
            f"@score_{i} * {v.weight}" for i, v in enumerate(self._vectors)
        )
        self.apply(combined_score=combined_score_string)

        self.sort_by(Desc("@combined_score"), max=num_results)  # type: ignore
        self.dialect(dialect)

        if return_fields:
            self.load(*return_fields)  # type: ignore[arg-type]

    @property
    def params(self) -> Dict[str, Any]:
        """Return the parameters for the aggregation.

        Returns:
            Dict[str, Any]: One ``vector_<i>`` entry per Vector, in the same order
            used when the query string was built.
        """
        return {f"vector_{i}": v.vector for i, v in enumerate(self._vectors)}

    def _build_query_string(self) -> str:
        """Build the query string: one vector range clause per Vector, with optional filtering.

        Each clause yields its distance as ``distance_<i>`` and reads its vector from
        the ``$vector_<i>`` parameter; the clauses are joined with ``|`` (union).

        Returns:
            str: The union of range clauses, combined with the filter expression
            when one was supplied.
        """
        # Radius 2.0 pairs with the (2 - distance)/2 scoring above -- presumably the
        # full cosine-distance range so no document is excluded; TODO confirm.
        range_query = " | ".join(
            f"@{v.field_name}:[VECTOR_RANGE 2.0 $vector_{i}]=>{{$YIELD_DISTANCE_AS: distance_{i}}}"
            for i, v in enumerate(self._vectors)
        )

        filter_expression = self._filter_expression
        if isinstance(filter_expression, FilterExpression):
            filter_expression = str(filter_expression)

        if filter_expression:
            # NOTE(review): Redis query syntax intersects adjacent clauses implicitly;
            # confirm the literal AND token is intended here and is not parsed as a
            # search term (e.g. on indexes created with STOPWORDS 0).
            return f"({range_query}) AND ({filter_expression})"
        return range_query

    def __str__(self) -> str:
        """Return the string representation of the query."""
        return " ".join(str(x) for x in self.build_args())