 from typing import Any, Iterable, Optional, Sequence, Type, Union

 import numpy as np
+from numpy.typing import NDArray
 from tokenizers import Encoding

 from fastembed.common.types import NumpyArray, OnnxProvider
@@ -23,14 +24,14 @@ def _get_worker_class(cls) -> Type["TextEmbeddingWorker[T]"]:
     def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[T]:
         raise NotImplementedError("Subclasses must implement this method")

-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self.tokenizer = None
         self.special_token_to_id: dict[str, int] = {}

     def _preprocess_onnx_input(
         self, onnx_input: dict[str, NumpyArray], **kwargs: Any
-    ) -> dict[str, NumpyArray]:
+    ) -> dict[str, Union[NumpyArray, NDArray[np.int64]]]:
         """
         Preprocess the onnx input.
         """
@@ -60,7 +61,7 @@ def load_onnx_model(self) -> None:
         raise NotImplementedError("Subclasses must implement this method")

     def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]:
-        return self.tokenizer.encode_batch(documents)
+        return self.tokenizer.encode_batch(documents)  # type: ignore

     def onnx_embed(
         self,
@@ -70,7 +71,7 @@ def onnx_embed(
         encoded = self.tokenize(documents, **kwargs)
         input_ids = np.array([e.ids for e in encoded])
         attention_mask = np.array([e.attention_mask for e in encoded])
-        input_names = {node.name for node in self.model.get_inputs()}
+        input_names = {node.name for node in self.model.get_inputs()}  # type: ignore
         onnx_input: dict[str, NumpyArray] = {
             "input_ids": np.array(input_ids, dtype=np.int64),
         }
@@ -82,7 +83,7 @@ def onnx_embed(
             )
         onnx_input = self._preprocess_onnx_input(onnx_input, **kwargs)

-        model_output = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input)
+        model_output = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input)  # type: ignore
         return OnnxOutputContext(
             model_output=model_output[0],
             attention_mask=onnx_input.get("attention_mask", attention_mask),
@@ -136,7 +137,7 @@ def _embed_documents(
                 start_method=start_method,
             )
             for batch in pool.ordered_map(iter_batch(documents, batch_size), **params):
-                yield from self._post_process_onnx_output(batch)
+                yield from self._post_process_onnx_output(batch)  # type: ignore


 class TextEmbeddingWorker(EmbeddingWorker[T]):
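As a side note on the widened `_preprocess_onnx_input` return type: below is a minimal sketch of a subclass override that this signature now accommodates. The class name `CustomOnnxTextModel` and the `token_type_ids` handling are illustrative assumptions, not part of this change.

```python
from typing import Any, Union

import numpy as np
from numpy.typing import NDArray

from fastembed.common.types import NumpyArray


class CustomOnnxTextModel:  # hypothetical subclass, name not from this diff
    def _preprocess_onnx_input(
        self, onnx_input: dict[str, NumpyArray], **kwargs: Any
    ) -> dict[str, Union[NumpyArray, NDArray[np.int64]]]:
        # Illustrative only: inject an all-zeros token_type_ids array shaped
        # like input_ids; the int64 array is what the widened type now allows.
        onnx_input["token_type_ids"] = np.zeros_like(
            onnx_input["input_ids"], dtype=np.int64
        )
        return onnx_input
```

The Union lets overrides return extra int64 inputs alongside the existing `NumpyArray` values without a cast.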