Skip to content

Commit be1da05

Browse files
refactor: Refactored how to disable stemming in bm25
1 parent cf22af3 commit be1da05

File tree

1 file changed

+9
-10
lines changed

1 file changed

+9
-10
lines changed

fastembed/sparse/bm25.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ class Bm25(SparseTextEmbeddingBase):
9292
b (float, optional): The b parameter in the BM25 formula. Defines the importance of the document length.
9393
Defaults to 0.75.
9494
avg_len (float, optional): The average length of the documents in the corpus. Defaults to 256.0.
95+
language (str, optional): Specifies the language for the stemmer. Set to None to disable stemming.
9596
Raises:
9697
ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
9798
"""
@@ -103,14 +104,13 @@ def __init__(
103104
k: float = 1.2,
104105
b: float = 0.75,
105106
avg_len: float = 256.0,
106-
language: str = "english",
107+
language: Optional[str] = "english",
107108
token_max_length: int = 40,
108-
disable_stemmer: bool = False,
109109
**kwargs,
110110
):
111111
super().__init__(model_name, cache_dir, **kwargs)
112112

113-
if language not in supported_languages:
113+
if language is not None and language not in supported_languages:
114114
raise ValueError(f"{language} language is not supported")
115115
else:
116116
self.language = language
@@ -130,8 +130,7 @@ def __init__(
130130
self.punctuation = set(get_all_punctuation())
131131
self.stopwords = set(self._load_stopwords(self._model_dir, self.language))
132132

133-
self.disable_stemmer = disable_stemmer
134-
self.stemmer = SnowballStemmer(language) if not disable_stemmer else None
133+
self.stemmer = SnowballStemmer(language)
135134
self.tokenizer = SimpleTokenizer
136135

137136
@classmethod
@@ -144,7 +143,7 @@ def list_supported_models(cls) -> list[dict[str, Any]]:
144143
return supported_bm25_models
145144

146145
@classmethod
147-
def _load_stopwords(cls, model_dir: Path, language: str) -> list[str]:
146+
def _load_stopwords(cls, model_dir: Path, language: Optional[str]) -> list[str]:
148147
stopwords_path = model_dir / f"{language}.txt"
149148
if not stopwords_path.exists():
150149
return []
@@ -225,9 +224,6 @@ def embed(
225224
)
226225

227226
def _stem(self, tokens: list[str]) -> list[str]:
228-
if self.disable_stemmer:
229-
return tokens
230-
231227
stemmed_tokens = []
232228
for token in tokens:
233229
if token in self.punctuation:
@@ -239,7 +235,10 @@ def _stem(self, tokens: list[str]) -> list[str]:
239235
if len(token) > self.token_max_length:
240236
continue
241237

242-
stemmed_token = self.stemmer.stem_word(token.lower())
238+
if self.stemmer:
239+
stemmed_token = self.stemmer.stem_word(token.lower())
240+
else:
241+
stemmed_token = token.lower()
243242

244243
if stemmed_token:
245244
stemmed_tokens.append(stemmed_token)

0 commit comments

Comments
 (0)