Commit 9b1dc18

makes methods private
1 parent f32067a commit 9b1dc18

3 files changed: +23 −40 lines changed


redisvl/query/aggregate.py

Lines changed: 5 additions & 6 deletions
@@ -109,7 +109,7 @@ def __init__(
         self._alpha = alpha
         self._dtype = dtype
         self._num_results = num_results
-        self.set_stopwords(stopwords)
+        self._set_stopwords(stopwords)

         query_string = self._build_query_string()
         super().__init__(query_string)
@@ -149,7 +149,7 @@ def stopwords(self) -> Set[str]:
         """
         return self._stopwords.copy() if self._stopwords else set()

-    def set_stopwords(self, stopwords: Optional[Union[str, Set[str]]] = "english"):
+    def _set_stopwords(self, stopwords: Optional[Union[str, Set[str]]] = "english"):
         """Set the stopwords to use in the query.
         Args:
             stopwords (Optional[Union[str, Set[str]]]): The stopwords to use. If a string
@@ -164,7 +164,7 @@ def set_stopwords(self, stopwords: Optional[Union[str, Set[str]]] = "english"):
             self._stopwords = set()
         elif isinstance(stopwords, str):
             try:
-                nltk.download("stopwords")
+                nltk.download("stopwords", quiet=True)
                 self._stopwords = set(nltk_stopwords.words(stopwords))
             except Exception as e:
                 raise ValueError(f"Error trying to load {stopwords} from nltk. {e}")
@@ -175,7 +175,7 @@ def set_stopwords(self, stopwords: Optional[Union[str, Set[str]]] = "english"):
         else:
             raise TypeError("stopwords must be a set, list, or tuple of strings")

-    def tokenize_and_escape_query(self, user_query: str) -> str:
+    def _tokenize_and_escape_query(self, user_query: str) -> str:
         """Convert a raw user query to a redis full text query joined by ORs
         Args:
             user_query (str): The user query to tokenize and escape.
@@ -185,7 +185,6 @@ def tokenize_and_escape_query(self, user_query: str) -> str:
         Raises:
             ValueError: If the text string becomes empty after stopwords are removed.
         """
-
         escaper = TokenEscaper()

         tokens = [
@@ -212,7 +211,7 @@ def _build_query_string(self) -> str:
         # base KNN query
         knn_query = f"KNN {self._num_results} @{self._vector_field} ${self.VECTOR_PARAM} AS {self.DISTANCE_ID}"

-        text = f"(~@{self._text_field}:({self.tokenize_and_escape_query(self._text)})"
+        text = f"(~@{self._text_field}:({self._tokenize_and_escape_query(self._text)})"

         if filter_expression and filter_expression != "*":
             text += f" AND {filter_expression}"

redisvl/query/query.py

Lines changed: 18 additions & 13 deletions
@@ -1,6 +1,8 @@
 from enum import Enum
 from typing import Any, Dict, List, Optional, Set, Tuple, Union

+import nltk
+from nltk.corpus import stopwords as nltk_stopwords
 from redis.commands.search.query import Query as RedisQuery

 from redisvl.query.filter import FilterExpression
@@ -741,22 +743,21 @@ def __init__(
         """
         self._text = text
         self._text_field = text_field_name
-        self._text_scorer = text_scorer
         self._num_results = num_results

-        self.set_stopwords(stopwords)
+        self._set_stopwords(stopwords)
         self.set_filter(filter_expression)

         if params:
             self._params = params

-        self._num_results = num_results
-
         # initialize the base query with the full query string and filter expression
         query_string = self._build_query_string()
         super().__init__(query_string)

-        # Handle query settings
+        # handle query settings
+        self.scorer(text_scorer)
+
         if return_fields:
             self.return_fields(*return_fields)
         self.paging(0, self._num_results).dialect(dialect)
@@ -774,15 +775,12 @@
     def stopwords(self):
         return self._stopwords

-    def set_stopwords(self, stopwords: Optional[Union[str, Set[str]]] = "english"):
+    def _set_stopwords(self, stopwords: Optional[Union[str, Set[str]]] = "english"):
         if not stopwords:
             self._stopwords = set()
         elif isinstance(stopwords, str):
             try:
-                import nltk
-                from nltk.corpus import stopwords as nltk_stopwords
-
-                nltk.download("stopwords")
+                nltk.download("stopwords", quiet=True)
                 self._stopwords = set(nltk_stopwords.words(stopwords))
             except Exception as e:
                 raise ValueError(f"Error trying to load {stopwords} from nltk. {e}")
@@ -793,9 +791,16 @@ def set_stopwords(self, stopwords: Optional[Union[str, Set[str]]] = "english"):
         else:
             raise TypeError("stopwords must be a set, list, or tuple of strings")

-    def tokenize_and_escape_query(self, user_query: str) -> str:
-        """Convert a raw user query to a redis full text query joined by ORs"""
+    def _tokenize_and_escape_query(self, user_query: str) -> str:
+        """Convert a raw user query to a redis full text query joined by ORs
+        Args:
+            user_query (str): The user query to tokenize and escape.

+        Returns:
+            str: The tokenized and escaped query string.
+        Raises:
+            ValueError: If the text string becomes empty after stopwords are removed.
+        """
         escaper = TokenEscaper()

         tokens = [
@@ -816,7 +821,7 @@ def _build_query_string(self) -> str:
         else:
             filter_expression = ""

-        text = f"@{self._text_field}:({self.tokenize_and_escape_query(self._text)})"
+        text = f"@{self._text_field}:({self._tokenize_and_escape_query(self._text)})"
         if filter_expression and filter_expression != "*":
             text += f" AND {filter_expression}"
         return text
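
With `set_stopwords` and `tokenize_and_escape_query` now private on `TextQuery`, callers are expected to configure everything through the constructor, and the scorer is applied internally via `self.scorer(text_scorer)`. A usage sketch based on the constructor arguments visible in this diff follows; the import path, field name, and query text are illustrative assumptions.

from redisvl.query import TextQuery

# All text-search options are passed at construction time; the query string
# is assembled internally by the now-private helpers.
query = TextQuery(
    text="the quick brown fox",        # raw user text; stopwords are stripped
    text_field_name="description",     # hypothetical full-text field
    text_scorer="BM25STD",             # default scorer per the removed test assertion
    stopwords=["the", "a", "of"],      # or "english" to pull the nltk stopword list
    num_results=10,
)

# query_string() is inherited from redis-py's search Query; the clause should
# look roughly like: @description:(quick | brown | fox)
print(query.query_string())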

tests/unit/test_query_types.py

Lines changed: 0 additions & 21 deletions
@@ -203,14 +203,9 @@ def test_text_query():
     assert text_query._return_fields == return_fields
     assert text_query._num_results == 10

-    assert (
-        text_query._build_query_string()
-        == f"@{text_field_name}:({text_query.tokenize_and_escape_query(text_string)})"
-    )
     assert isinstance(text_query, Query)
     assert isinstance(text_query.query, Query)
     assert isinstance(text_query.params, dict)
-    assert text_query._text_scorer == "BM25STD"
     assert text_query.params == {}
     assert text_query._dialect == 2
     assert text_query._in_order == False
@@ -250,17 +245,9 @@ def test_text_query():
     # Test stopwords are configurable
     text_query = TextQuery(text_string, text_field_name, stopwords=None)
     assert text_query.stopwords == set([])
-    assert (
-        text_query._build_query_string()
-        == f"@{text_field_name}:({text_query.tokenize_and_escape_query(text_string)})"
-    )

     text_query = TextQuery(text_string, text_field_name, stopwords=["the", "a", "of"])
     assert text_query.stopwords == set(["the", "a", "of"])
-    assert (
-        text_query._build_query_string()
-        == f"@{text_field_name}:({text_query.tokenize_and_escape_query(text_string)})"
-    )

     text_query = TextQuery(text_string, text_field_name, stopwords="german")
     assert text_query.stopwords != set([])
@@ -273,21 +260,13 @@ def test_text_query():

     text_query = TextQuery(text_string, text_field_name, stopwords=["the", "a", "of"])
     assert text_query.stopwords == set(["the", "a", "of"])
-    assert (
-        text_query._build_query_string()
-        == f"@{text_field_name}:({text_query.tokenize_and_escape_query(text_string)})"
-    )

     text_query = TextQuery(text_string, text_field_name, stopwords="german")
     assert text_query.stopwords != set([])

     # test that filter expression is set correctly
     text_query.set_filter(filter_expression)
     assert text_query.filter == filter_expression
-    assert (
-        text_query._build_query_string()
-        == f"@{text_field_name}:({text_query.tokenize_and_escape_query(text_string)}) AND {filter_expression}"
-    )

     with pytest.raises(ValueError):
         text_query = TextQuery(text_string, text_field_name, stopwords="gibberish")

0 commit comments
