Skip to content

Commit ee12a6f

Browse files
wip: adding word weights to TextQuery string
1 parent f1b592f commit ee12a6f

File tree

1 file changed

+40
-4
lines changed

1 file changed

+40
-4
lines changed

redisvl/query/query.py

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1028,6 +1028,7 @@ def __init__(
10281028
in_order: bool = False,
10291029
params: Optional[Dict[str, Any]] = None,
10301030
stopwords: Optional[Union[str, Set[str]]] = "english",
1031+
text_weights: Optional[Dict[str, float]] = None,
10311032
):
10321033
"""A query for running a full text search, along with an optional filter expression.
10331034
@@ -1064,13 +1065,16 @@ def __init__(
10641065
a default set of stopwords for that language will be used. Users may specify
10651066
their own stop words by providing a List or Set of words. if set to None,
10661067
then no words will be removed. Defaults to 'english'.
1067-
1068+
text_weights (Optional[Dict[str, float]): The importance weighting of individual words
1069+
within the query text. Defaults to None, as no modifications will be made to the
1070+
text_scorer score.
10681071
Raises:
10691072
ValueError: if stopwords language string cannot be loaded.
10701073
TypeError: If stopwords is not a valid iterable set of strings.
10711074
"""
10721075
self._text = text
10731076
self._field_weights = self._parse_field_weights(text_field_name)
1077+
self._text_weights = self._parse_text_weights(text_weights)
10741078
self._num_results = num_results
10751079

10761080
self._set_stopwords(stopwords)
@@ -1151,9 +1155,12 @@ def _tokenize_and_escape_query(self, user_query: str) -> str:
11511155
)
11521156
for token in user_query.split()
11531157
]
1154-
return " | ".join(
1155-
[token for token in tokens if token and token not in self._stopwords]
1156-
)
1158+
token_list = [token for token in tokens if token and token not in self._stopwords]
1159+
for i, token in enumerate(token_list):
1160+
if token in self._text_weights:
1161+
token_list[i] = f"{token}=>{{weight:{self._text_weights[token]}}}"
1162+
1163+
return " | ".join(token_list)
11571164

11581165
def _parse_field_weights(
11591166
self, field_spec: Union[str, Dict[str, float]]
@@ -1220,6 +1227,35 @@ def text_field_name(self) -> Union[str, Dict[str, float]]:
12201227
return field
12211228
return self._field_weights.copy()
12221229

1230+
def _parse_text_weights(self, weights: Dict[str, float]) -> Dict[str, float]:
1231+
parsed_weights = {}
1232+
for word, weight in weights.items():
1233+
word = word.strip().lower()
1234+
if not word or " " in word:
1235+
raise ValueError("Only individual words may be weighted. Got {{ {word}:{weight} }}")
1236+
if not isinstance(weight, float) or weight <0.0:
1237+
raise ValueError("Weights must be positive floats. Got {{ {word}:{weight} }}")
1238+
parsed_weights[word] = weight
1239+
return parsed_weights
1240+
1241+
def set_text_weights(self, weights:Dict[str, float]):
1242+
"""Set or update the text weights for the query.
1243+
1244+
Args:
1245+
text_weights: Dictionary of word:weight mappings
1246+
"""
1247+
self._text_weights = self._parse_text_weights(weights)
1248+
self._built_query_string = None
1249+
1250+
@property
1251+
def text_weights() -> Dict[str, float]:
1252+
"""Get the text weights.
1253+
1254+
Returns:
1255+
Dictionary of word:weight mappings.
1256+
"""
1257+
return self._text_weights
1258+
12231259
def _build_query_string(self) -> str:
12241260
"""Build the full query string for text search with optional filtering."""
12251261
filter_expression = self._filter_expression

0 commit comments

Comments
 (0)