@@ -92,6 +92,7 @@ class Bm25(SparseTextEmbeddingBase):
9292 b (float, optional): The b parameter in the BM25 formula. Defines the importance of the document length.
9393 Defaults to 0.75.
9494 avg_len (float, optional): The average length of the documents in the corpus. Defaults to 256.0.
95+ language (str, optional): Specifies the language for the stemmer. Set to None to disable stemming.
9596 Raises:
9697 ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
9798 """
@@ -103,14 +104,13 @@ def __init__(
103104 k : float = 1.2 ,
104105 b : float = 0.75 ,
105106 avg_len : float = 256.0 ,
106- language : str = "english" ,
107+ language : Optional [ str ] = "english" ,
107108 token_max_length : int = 40 ,
108- disable_stemmer : bool = False ,
109109 ** kwargs ,
110110 ):
111111 super ().__init__ (model_name , cache_dir , ** kwargs )
112112
113- if language not in supported_languages :
113+ if language is not None and language not in supported_languages :
114114 raise ValueError (f"{ language } language is not supported" )
115115 else :
116116 self .language = language
@@ -130,8 +130,7 @@ def __init__(
130130 self .punctuation = set (get_all_punctuation ())
131131 self .stopwords = set (self ._load_stopwords (self ._model_dir , self .language ))
132132
133- self .disable_stemmer = disable_stemmer
134- self .stemmer = SnowballStemmer (language ) if not disable_stemmer else None
133+ self .stemmer = SnowballStemmer (language )
135134 self .tokenizer = SimpleTokenizer
136135
137136 @classmethod
@@ -144,7 +143,7 @@ def list_supported_models(cls) -> list[dict[str, Any]]:
144143 return supported_bm25_models
145144
146145 @classmethod
147- def _load_stopwords (cls , model_dir : Path , language : str ) -> list [str ]:
146+ def _load_stopwords (cls , model_dir : Path , language : Optional [ str ] ) -> list [str ]:
148147 stopwords_path = model_dir / f"{ language } .txt"
149148 if not stopwords_path .exists ():
150149 return []
@@ -225,9 +224,6 @@ def embed(
225224 )
226225
227226 def _stem (self , tokens : list [str ]) -> list [str ]:
228- if self .disable_stemmer :
229- return tokens
230-
231227 stemmed_tokens = []
232228 for token in tokens :
233229 if token in self .punctuation :
@@ -239,7 +235,10 @@ def _stem(self, tokens: list[str]) -> list[str]:
239235 if len (token ) > self .token_max_length :
240236 continue
241237
242- stemmed_token = self .stemmer .stem_word (token .lower ())
238+ if self .stemmer :
239+ stemmed_token = self .stemmer .stem_word (token .lower ())
240+ else :
241+ stemmed_token = token .lower ()
243242
244243 if stemmed_token :
245244 stemmed_tokens .append (stemmed_token )
0 commit comments