@@ -94,6 +94,7 @@ def __init__(
9494 return_fields : Optional [List [str ]] = None ,
9595 stopwords : Optional [Union [str , Set [str ]]] = "english" ,
9696 dialect : int = 2 ,
97+ text_weights : Optional [Dict [str , float ]] = None ,
9798 ):
9899 """
99100 Instantiates a HybridQuery object.
@@ -119,7 +120,9 @@ def __init__(
119120 set, or tuple of strings is provided then those will be used as stopwords.
120121 Defaults to "english". if set to "None" then no stopwords will be removed.
121122 dialect (int, optional): The Redis dialect version. Defaults to 2.
122-
123+ text_weights (Optional[Dict[str, float]): The importance weighting of individual words
124+ within the query text. Defaults to None, as no modifications will be made to the
125+ text_scorer score.
123126 Raises:
124127 ValueError: If the text string is empty, or if the text string becomes empty after
125128 stopwords are removed.
@@ -138,6 +141,7 @@ def __init__(
138141 self ._dtype = dtype
139142 self ._num_results = num_results
140143 self ._set_stopwords (stopwords )
144+ self ._text_weights = self ._parse_text_weights (text_weights )
141145
142146 query_string = self ._build_query_string ()
143147 super ().__init__ (query_string )
@@ -225,13 +229,60 @@ def _tokenize_and_escape_query(self, user_query: str) -> str:
225229 )
226230 for token in user_query .split ()
227231 ]
228- tokenized = " | " .join (
229- [token for token in tokens if token and token not in self ._stopwords ]
230- )
232+ ## tokenized = " | ".join(
233+ ## [token for token in tokens if token and token not in self._stopwords]
234+ ## )
231235
232- if not tokenized :
236+ token_list = [
237+ token for token in tokens if token and token not in self ._stopwords
238+ ]
239+ for i , token in enumerate (token_list ):
240+ if token in self ._text_weights :
241+ token_list [i ] = f"{ token } =>{{weight:{ self ._text_weights [token ]} }}"
242+
243+ if not token_list :
233244 raise ValueError ("text string cannot be empty after removing stopwords" )
234- return tokenized
245+ return " | " .join (token_list )
246+
247+ def _parse_text_weights (
248+ self , weights : Optional [Dict [str , float ]]
249+ ) -> Dict [str , float ]:
250+ parsed_weights : Dict [str , float ] = {}
251+ if not weights :
252+ return parsed_weights
253+ for word , weight in weights .items ():
254+ word = word .strip ().lower ()
255+ if not word or " " in word :
256+ raise ValueError (
257+ f"Only individual words may be weighted. Got {{ { word } :{ weight } }}"
258+ )
259+ if (
260+ not (isinstance (weight , float ) or isinstance (weight , int ))
261+ or weight < 0.0
262+ ):
263+ raise ValueError (
264+ f"Weights must be positive number. Got {{ { word } :{ weight } }}"
265+ )
266+ parsed_weights [word ] = weight
267+ return parsed_weights
268+
269+ def set_text_weights (self , weights : Dict [str , float ]):
270+ """Set or update the text weights for the query.
271+
272+ Args:
273+ text_weights: Dictionary of word:weight mappings
274+ """
275+ self ._text_weights = self ._parse_text_weights (weights )
276+ self ._built_query_string = None
277+
278+ @property
279+ def text_weights (self ) -> Dict [str , float ]:
280+ """Get the text weights.
281+
282+ Returns:
283+ Dictionary of word:weight mappings.
284+ """
285+ return self ._text_weights
235286
236287 def _build_query_string (self ) -> str :
237288 """Build the full query string for text search with optional filtering."""
0 commit comments