Skip to content

Commit f2a7f98

Browse files
adds unit tests for text word weighting in TextQuery class
1 parent ee12a6f commit f2a7f98

File tree

2 files changed

+79
-8
lines changed

2 files changed

+79
-8
lines changed

redisvl/query/query.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1155,7 +1155,9 @@ def _tokenize_and_escape_query(self, user_query: str) -> str:
11551155
)
11561156
for token in user_query.split()
11571157
]
1158-
token_list = [token for token in tokens if token and token not in self._stopwords]
1158+
token_list = [
1159+
token for token in tokens if token and token not in self._stopwords
1160+
]
11591161
for i, token in enumerate(token_list):
11601162
if token in self._text_weights:
11611163
token_list[i] = f"{token}=>{{weight:{self._text_weights[token]}}}"
@@ -1227,18 +1229,29 @@ def text_field_name(self) -> Union[str, Dict[str, float]]:
12271229
return field
12281230
return self._field_weights.copy()
12291231

1230-
def _parse_text_weights(self, weights: Dict[str, float]) -> Dict[str, float]:
1231-
parsed_weights = {}
1232+
def _parse_text_weights(
1233+
self, weights: Optional[Dict[str, float]]
1234+
) -> Dict[str, float]:
1235+
parsed_weights: Dict[str, float] = {}
1236+
if not weights:
1237+
return parsed_weights
12321238
for word, weight in weights.items():
12331239
word = word.strip().lower()
12341240
if not word or " " in word:
1235-
raise ValueError("Only individual words may be weighted. Got {{ {word}:{weight} }}")
1236-
if not isinstance(weight, float) or weight <0.0:
1237-
raise ValueError("Weights must be positive floats. Got {{ {word}:{weight} }}")
1241+
raise ValueError(
1242+
f"Only individual words may be weighted. Got {{ {word}:{weight} }}"
1243+
)
1244+
if (
1245+
not (isinstance(weight, float) or isinstance(weight, int))
1246+
or weight < 0.0
1247+
):
1248+
raise ValueError(
1249+
f"Weights must be positive number. Got {{ {word}:{weight} }}"
1250+
)
12381251
parsed_weights[word] = weight
12391252
return parsed_weights
12401253

1241-
def set_text_weights(self, weights:Dict[str, float]):
1254+
def set_text_weights(self, weights: Dict[str, float]):
12421255
"""Set or update the text weights for the query.
12431256
12441257
Args:
@@ -1248,7 +1261,7 @@ def set_text_weights(self, weights:Dict[str, float]):
12481261
self._built_query_string = None
12491262

12501263
@property
1251-
def text_weights() -> Dict[str, float]:
1264+
def text_weights(self) -> Dict[str, float]:
12521265
"""Get the text weights.
12531266
12541267
Returns:

tests/unit/test_query_types.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,64 @@ def test_text_query_with_string_filter():
333333
assert "AND" not in query_string_wildcard
334334

335335

336+
def test_text_query_word_weights():
337+
# verify word weights get added into the raw Redis query syntax
338+
query = TextQuery(
339+
text="query string alpha bravo delta tango alpha",
340+
text_field_name="description",
341+
text_weights={"alpha": 2, "delta": 0.555, "gamma": 0.95},
342+
)
343+
344+
assert (
345+
str(query)
346+
== "@description:(query | string | alpha=>{weight:2} | bravo | delta=>{weight:0.555} | tango | alpha=>{weight:2}) SCORER BM25STD WITHSCORES DIALECT 2 LIMIT 0 10"
347+
)
348+
349+
# raise an error if a weight is negative or is not a number
350+
with pytest.raises(ValueError):
351+
_ = TextQuery(
352+
text="sample text query",
353+
text_field_name="description",
354+
text_weights={"first": 0.2, "second": -0.1},
355+
)
356+
357+
with pytest.raises(ValueError):
358+
_ = TextQuery(
359+
text="sample text query",
360+
text_field_name="description",
361+
text_weights={"first": 0.2, "second": "0.1"},
362+
)
363+
364+
# no error if the weights dictionary is empty or None
365+
query = TextQuery(
366+
text="sample text query", text_field_name="description", text_weights={}
367+
)
368+
assert query
369+
370+
query = TextQuery(
371+
text="sample text query", text_field_name="description", text_weights=None
372+
)
373+
assert query
374+
375+
# no error if the words in weights dictionary don't appear in query
376+
query = TextQuery(
377+
text="sample text query",
378+
text_field_name="description",
379+
text_weights={"alpha": 0.2, "bravo": 0.4},
380+
)
381+
assert query
382+
383+
# we can access the word weights on a query object
384+
assert query.text_weights == {"alpha": 0.2, "bravo": 0.4}
385+
386+
# we can change the text weights on a query object
387+
query.set_text_weights(weights={"new": 0.3, "words": 0.125, "here": 99})
388+
assert query.text_weights == {"new": 0.3, "words": 0.125, "here": 99}
389+
390+
query.set_text_weights(weights={})
391+
assert query.text_weights == {}
392+
393+
336394
@pytest.mark.parametrize(
337395
"query",
338396
[

0 commit comments

Comments
 (0)