Skip to content

Commit 299566d

Browse files
adds word weights to HybridQuery class
1 parent f2a7f98 commit 299566d

File tree

2 files changed

+134
-6
lines changed

2 files changed

+134
-6
lines changed

redisvl/query/aggregate.py

Lines changed: 57 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ def __init__(
9494
return_fields: Optional[List[str]] = None,
9595
stopwords: Optional[Union[str, Set[str]]] = "english",
9696
dialect: int = 2,
97+
text_weights: Optional[Dict[str, float]] = None,
9798
):
9899
"""
99100
Instantiates a HybridQuery object.
@@ -119,7 +120,9 @@ def __init__(
119120
set, or tuple of strings is provided then those will be used as stopwords.
120121
Defaults to "english". if set to "None" then no stopwords will be removed.
121122
dialect (int, optional): The Redis dialect version. Defaults to 2.
122-
123+
text_weights (Optional[Dict[str, float]): The importance weighting of individual words
124+
within the query text. Defaults to None, as no modifications will be made to the
125+
text_scorer score.
123126
Raises:
124127
ValueError: If the text string is empty, or if the text string becomes empty after
125128
stopwords are removed.
@@ -138,6 +141,7 @@ def __init__(
138141
self._dtype = dtype
139142
self._num_results = num_results
140143
self._set_stopwords(stopwords)
144+
self._text_weights = self._parse_text_weights(text_weights)
141145

142146
query_string = self._build_query_string()
143147
super().__init__(query_string)
@@ -225,13 +229,60 @@ def _tokenize_and_escape_query(self, user_query: str) -> str:
225229
)
226230
for token in user_query.split()
227231
]
228-
tokenized = " | ".join(
229-
[token for token in tokens if token and token not in self._stopwords]
230-
)
232+
##tokenized = " | ".join(
233+
## [token for token in tokens if token and token not in self._stopwords]
234+
##)
231235

232-
if not tokenized:
236+
token_list = [
237+
token for token in tokens if token and token not in self._stopwords
238+
]
239+
for i, token in enumerate(token_list):
240+
if token in self._text_weights:
241+
token_list[i] = f"{token}=>{{weight:{self._text_weights[token]}}}"
242+
243+
if not token_list:
233244
raise ValueError("text string cannot be empty after removing stopwords")
234-
return tokenized
245+
return " | ".join(token_list)
246+
247+
def _parse_text_weights(
248+
self, weights: Optional[Dict[str, float]]
249+
) -> Dict[str, float]:
250+
parsed_weights: Dict[str, float] = {}
251+
if not weights:
252+
return parsed_weights
253+
for word, weight in weights.items():
254+
word = word.strip().lower()
255+
if not word or " " in word:
256+
raise ValueError(
257+
f"Only individual words may be weighted. Got {{ {word}:{weight} }}"
258+
)
259+
if (
260+
not (isinstance(weight, float) or isinstance(weight, int))
261+
or weight < 0.0
262+
):
263+
raise ValueError(
264+
f"Weights must be positive number. Got {{ {word}:{weight} }}"
265+
)
266+
parsed_weights[word] = weight
267+
return parsed_weights
268+
269+
def set_text_weights(self, weights: Dict[str, float]):
270+
"""Set or update the text weights for the query.
271+
272+
Args:
273+
text_weights: Dictionary of word:weight mappings
274+
"""
275+
self._text_weights = self._parse_text_weights(weights)
276+
self._built_query_string = None
277+
278+
@property
279+
def text_weights(self) -> Dict[str, float]:
280+
"""Get the text weights.
281+
282+
Returns:
283+
Dictionary of word:weight mappings.
284+
"""
285+
return self._text_weights
235286

236287
def _build_query_string(self) -> str:
237288
"""Build the full query string for text search with optional filtering."""

tests/unit/test_aggregation_types.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,83 @@ def test_hybrid_query_with_string_filter():
196196
assert "AND" not in query_string_wildcard
197197

198198

199+
def test_hybrid_query_text_weights():
200+
# verify word weights get added into the raw Redis query syntax
201+
vector = [0.1, 0.1, 0.5]
202+
vector_field = "user_embedding"
203+
204+
query = HybridQuery(
205+
text="query string alpha bravo delta tango alpha",
206+
text_field_name="description",
207+
vector=vector,
208+
vector_field_name=vector_field,
209+
text_weights={"alpha": 2, "delta": 0.555, "gamma": 0.95},
210+
)
211+
212+
assert (
213+
str(query)
214+
== "(~@description:(query | string | alpha=>{weight:2} | bravo | delta=>{weight:0.555} | tango | alpha=>{weight:2}))=>[KNN 10 @user_embedding $vector AS vector_distance] SCORER BM25STD ADDSCORES DIALECT 2 APPLY (2 - @vector_distance)/2 AS vector_similarity APPLY @__score AS text_score APPLY 0.30000000000000004*@text_score + 0.7*@vector_similarity AS hybrid_score SORTBY 2 @hybrid_score DESC MAX 10"
215+
)
216+
217+
# raise an error if weights are not positive floats
218+
with pytest.raises(ValueError):
219+
_ = HybridQuery(
220+
text="sample text query",
221+
text_field_name="description",
222+
vector=vector,
223+
vector_field_name=vector_field,
224+
text_weights={"first": 0.2, "second": -0.1},
225+
)
226+
227+
with pytest.raises(ValueError):
228+
_ = HybridQuery(
229+
text="sample text query",
230+
text_field_name="description",
231+
vector=vector,
232+
vector_field_name=vector_field,
233+
text_weights={"first": 0.2, "second": "0.1"},
234+
)
235+
236+
# no error is weights dictiionary is empty or None
237+
query = HybridQuery(
238+
text="sample text query",
239+
text_field_name="description",
240+
vector=vector,
241+
vector_field_name=vector_field,
242+
text_weights={},
243+
)
244+
assert query
245+
246+
query = HybridQuery(
247+
text="sample text query",
248+
text_field_name="description",
249+
vector=vector,
250+
vector_field_name=vector_field,
251+
text_weights=None,
252+
)
253+
assert query
254+
255+
# no error if the words in weights dictionary don't appear in query
256+
query = HybridQuery(
257+
text="sample text query",
258+
text_field_name="description",
259+
vector=vector,
260+
vector_field_name=vector_field,
261+
text_weights={"alpha": 0.2, "bravo": 0.4},
262+
)
263+
assert query
264+
265+
# we can access the word weights on a query object
266+
assert query.text_weights == {"alpha": 0.2, "bravo": 0.4}
267+
268+
# we can change the text weights on a query object
269+
query.set_text_weights(weights={"new": 0.3, "words": 0.125, "here": 99})
270+
assert query.text_weights == {"new": 0.3, "words": 0.125, "here": 99}
271+
272+
query.set_text_weights(weights={})
273+
assert query.text_weights == {}
274+
275+
199276
def test_multi_vector_query():
200277
# test we require Vector objects
201278
with pytest.raises(TypeError):

0 commit comments

Comments
 (0)