@@ -110,7 +110,7 @@ def __init__(
         self.device_id = None

         self.model_description = self._get_model_description(model_name)
-        self.cache_dir = define_cache_dir(cache_dir)
+        self.cache_dir = str(define_cache_dir(cache_dir))

         self._model_dir = self.download_model(
             self.model_description,
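The new str(...) wrapper suggests define_cache_dir returns a pathlib.Path rather than a plain string; converting at assignment time keeps self.cache_dir consistently a str. A minimal sketch of the distinction, using Path directly since the helper's actual signature is not shown in this diff:

    from pathlib import Path

    cache_path = Path.home() / ".cache" / "fastembed"  # what a Path-returning helper would yield
    cache_dir = str(cache_path)                        # what str-typed downstream code expects
    assert isinstance(cache_dir, str)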
@@ -119,10 +119,10 @@ def __init__(
             specific_model_path=specific_model_path,
         )

-        self.invert_vocab = {}
+        self.invert_vocab: dict[int, str] = {}

-        self.special_tokens = set()
-        self.special_tokens_ids = set()
+        self.special_tokens: set[str] = set()
+        self.special_tokens_ids: set[int] = set()
         self.punctuation = set(string.punctuation)
         self.stopwords = set(self._load_stopwords(self._model_dir))
         self.stemmer = SnowballStemmer(MODEL_TO_LANGUAGE[model_name])
@@ -147,15 +147,15 @@ def load_onnx_model(self) -> None:
         self.stopwords = set(self._load_stopwords(self._model_dir))

     def _filter_pair_tokens(self, tokens: list[tuple[str, Any]]) -> list[tuple[str, Any]]:
-        result = []
+        result: list[tuple[str, Any]] = []
         for token, value in tokens:
             if token in self.stopwords or token in self.punctuation:
                 continue
             result.append((token, value))
         return result

     def _stem_pair_tokens(self, tokens: list[tuple[str, Any]]) -> list[tuple[str, Any]]:
-        result = []
+        result: list[tuple[str, Any]] = []
         for token, value in tokens:
             processed_token = self.stemmer.stem_word(token)
             result.append((processed_token, value))
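Both helpers walk the (token, value) pairs and rebuild the list: the filter drops stopwords and punctuation before the stemming pass. A self-contained sketch of the filtering step, with toy stand-ins for the stopword and punctuation sets loaded by the class (names here are illustrative, not fastembed's API):

    from typing import Any

    stopwords = {"the", "is"}      # stand-in for the stopwords loaded from the model dir
    punctuation = {".", ",", "!"}  # stand-in for set(string.punctuation)

    def filter_pair_tokens(tokens: list[tuple[str, Any]]) -> list[tuple[str, Any]]:
        # Keep only pairs whose token is neither a stopword nor punctuation.
        return [(t, v) for t, v in tokens if t not in stopwords and t not in punctuation]

    print(filter_pair_tokens([("the", [0]), ("fox", [1]), (".", [2])]))
    # [('fox', [1])]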
@@ -165,7 +165,7 @@ def _stem_pair_tokens(self, tokens: list[tuple[str, Any]]) -> list[tuple[str, An
     def _aggregate_weights(
         cls, tokens: list[tuple[str, list[int]]], weights: list[float]
     ) -> list[tuple[str, float]]:
-        result = []
+        result: list[tuple[str, float]] = []
         for token, idxs in tokens:
             sum_weight = sum(weights[idx] for idx in idxs)
             result.append((token, sum_weight))
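_aggregate_weights sums the per-subword attention weights collected for each reconstructed token. The same logic as a standalone function, with a worked example:

    def aggregate_weights(tokens: list[tuple[str, list[int]]], weights: list[float]) -> list[tuple[str, float]]:
        # Each token carries the indices of the subword pieces it was built
        # from; its weight is the sum of those pieces' attention weights.
        return [(token, sum(weights[i] for i in idxs)) for token, idxs in tokens]

    print(aggregate_weights([("hello", [0, 1]), ("world", [2])], [0.4, 0.1, 0.5]))
    # [('hello', 0.5), ('world', 0.5)]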
@@ -174,9 +174,9 @@ def _aggregate_weights(
     def _reconstruct_bpe(
         self, bpe_tokens: Iterable[tuple[int, str]]
     ) -> list[tuple[str, list[int]]]:
-        result = []
-        acc = ""
-        acc_idx = []
+        result: list[tuple[str, list[int]]] = []
+        acc: str = ""
+        acc_idx: list[int] = []

         continuing_subword_prefix = self.tokenizer.model.continuing_subword_prefix
         continuing_subword_prefix_len = len(continuing_subword_prefix)
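The hunk shows only the accumulator setup, but the signature and the continuing_subword_prefix lookup suggest the method merges continuation pieces back into whole words while remembering which piece indices contributed to each. A hypothetical standalone sketch of that pattern, assuming a WordPiece-style "##" prefix (the real method also handles special tokens, which this omits):

    def reconstruct_bpe(bpe_tokens: list[tuple[int, str]], prefix: str = "##") -> list[tuple[str, list[int]]]:
        result: list[tuple[str, list[int]]] = []
        acc: str = ""
        acc_idx: list[int] = []
        for idx, token in bpe_tokens:
            if token.startswith(prefix):
                # Continuation piece: append it to the current word.
                acc += token[len(prefix):]
                acc_idx.append(idx)
            else:
                # A new word starts: flush the accumulator first.
                if acc:
                    result.append((acc, acc_idx))
                acc, acc_idx = token, [idx]
        if acc:
            result.append((acc, acc_idx))
        return result

    print(reconstruct_bpe([(0, "em"), (1, "##bed"), (2, "##ding"), (3, "works")]))
    # [('embedding', [0, 1, 2]), ('works', [3])]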
@@ -206,7 +206,7 @@ def _rescore_vector(self, vector: dict[str, float]) -> dict[int, float]:
         So that the scoring doesn't depend on absolute values assigned by the model, but on the relative importance.
         """

-        new_vector = {}
+        new_vector: dict[int, float] = {}

         for token, value in vector.items():
             token_id = abs(mmh3.hash(token))
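The docstring is explicit that rescoring should depend on relative importance rather than the model's raw scores; the loop body beyond the mmh3 hashing is not shown in this hunk. One hypothetical way to get that property is rank-based rescoring, sketched below purely for illustration (not necessarily what Bm42 actually computes):

    import mmh3

    def rescore_vector(vector: dict[str, float]) -> dict[int, float]:
        # Order tokens by raw weight, then store a value derived from the
        # rank alone, so the absolute model scores drop out.
        ranked = sorted(vector, key=vector.get, reverse=True)
        return {abs(mmh3.hash(token)): 1.0 / (rank + 1) for rank, token in enumerate(ranked)}

    print(rescore_vector({"quick": 0.9, "fox": 0.3, "jump": 0.6}))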
@@ -241,7 +241,7 @@ def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[Spars

         weighted = self._aggregate_weights(stemmed, attention_value)

-        max_token_weight = {}
+        max_token_weight: dict[str, float] = {}

         for token, weight in weighted:
             max_token_weight[token] = max(max_token_weight.get(token, 0), weight)
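After stemming and aggregation, the same token can still occur more than once; keeping the max collapses duplicates to their strongest weight. The same three lines run standalone on toy data:

    weighted = [("fox", 0.3), ("fox", 0.7), ("dog", 0.2)]
    max_token_weight: dict[str, float] = {}
    for token, weight in weighted:
        max_token_weight[token] = max(max_token_weight.get(token, 0), weight)
    print(max_token_weight)  # {'fox': 0.7, 'dog': 0.2}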
@@ -304,7 +304,7 @@ def embed(

     @classmethod
     def _query_rehash(cls, tokens: Iterable[str]) -> dict[int, float]:
-        result = {}
+        result: dict[int, float] = {}
         for token in tokens:
             token_id = abs(mmh3.hash(token))
             result[token_id] = 1.0
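_query_rehash gives every query token a flat weight of 1.0 under the same murmur3-derived id used for documents, so query and document vectors land in the same sparse id space. The same logic condensed to a comprehension (mmh3.hash returns a signed 32-bit integer, hence the abs):

    import mmh3

    def query_rehash(tokens):
        # One sparse dimension per distinct token, all weighted equally.
        return {abs(mmh3.hash(token)): 1.0 for token in tokens}

    print(query_rehash(["hello", "world"]))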
@@ -334,11 +334,11 @@ def query_embed(
         yield SparseEmbedding.from_dict(self._query_rehash(token for token, _ in stemmed))

     @classmethod
-    def _get_worker_class(cls) -> Type[TextEmbeddingWorker]:
+    def _get_worker_class(cls) -> Type[TextEmbeddingWorker[SparseEmbedding]]:
         return Bm42TextEmbeddingWorker


-class Bm42TextEmbeddingWorker(TextEmbeddingWorker):
+class Bm42TextEmbeddingWorker(TextEmbeddingWorker[SparseEmbedding]):
     def init_embedding(self, model_name: str, cache_dir: str, **kwargs: Any) -> Bm42:
         return Bm42(
             model_name=model_name,
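Parameterizing TextEmbeddingWorker with SparseEmbedding lets the type checker tie the worker to the embedding type it produces. A minimal sketch of the underlying Generic pattern with stand-in names (not fastembed's actual classes):

    from typing import Generic, TypeVar

    T = TypeVar("T")

    class EmbeddingWorker(Generic[T]):
        # The type parameter records what kind of embedding the worker yields.
        def process(self, text: str) -> T:
            raise NotImplementedError

    class SparseEmbedding:  # stand-in for fastembed's SparseEmbedding
        ...

    class SparseWorker(EmbeddingWorker[SparseEmbedding]):
        def process(self, text: str) -> SparseEmbedding:
            return SparseEmbedding()

    def get_worker_class() -> type[EmbeddingWorker[SparseEmbedding]]:
        # Mirrors _get_worker_class: the annotation now carries the embedding
        # type, so callers know statically what the worker will emit.
        return SparseWorker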