@@ -2,8 +2,9 @@
 from typing import Any, Iterable, Optional, Sequence, Type, Union
 
 import numpy as np
-from tokenizers import Encoding
+from tokenizers import Encoding, Tokenizer
 
+from fastembed.common.preprocessor_utils import load_tokenizer
 from fastembed.common.types import NumpyArray
 from fastembed.common import OnnxProvider
 from fastembed.common.onnx_model import OnnxOutputContext
@@ -87,23 +88,8 @@ def tokenize(self, documents: list[str], is_doc: bool = True, **kwargs: Any) ->
         )
 
     def _tokenize_query(self, query: str) -> list[Encoding]:
-        assert self.tokenizer is not None
-        encoded = self.tokenizer.encode_batch([query])
-        # colbert authors recommend to pad queries with [MASK] tokens for query augmentation to improve performance
-        if len(encoded[0].ids) < self.MIN_QUERY_LENGTH:
-            prev_padding = None
-            if self.tokenizer.padding:
-                prev_padding = self.tokenizer.padding
-            self.tokenizer.enable_padding(
-                pad_token=self.MASK_TOKEN,
-                pad_id=self.mask_token_id,
-                length=self.MIN_QUERY_LENGTH,
-            )
-            encoded = self.tokenizer.encode_batch([query])
-            if prev_padding is None:
-                self.tokenizer.no_padding()
-            else:
-                self.tokenizer.enable_padding(**prev_padding)
+        assert self.query_tokenizer is not None
+        encoded = self.query_tokenizer.encode_batch([query])
         return encoded
 
     def _tokenize_documents(self, documents: list[str]) -> list[Encoding]:
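The deleted branch mutated the shared `self.tokenizer` for every short query (enable `[MASK]` padding, re-encode, then restore the previous padding state), while the replacement routes queries through a dedicated `query_tokenizer` configured once at load time. A minimal sketch of the same `[MASK]`-padding query augmentation using the `tokenizers` API; the checkpoint name and the query-length constant are illustrative, not taken from this commit:

```python
from tokenizers import Tokenizer

MASK_TOKEN = "[MASK]"
MIN_QUERY_LENGTH = 31  # illustrative value; fastembed defines its own constant

# Any tokenizer that defines a [MASK] token will do for this sketch.
query_tokenizer = Tokenizer.from_pretrained("bert-base-uncased")
mask_token_id = query_tokenizer.token_to_id(MASK_TOKEN)

# Configure padding once: short queries are padded out with [MASK] tokens,
# the query augmentation recommended by the ColBERT authors.
query_tokenizer.enable_padding(
    pad_token=MASK_TOKEN,
    pad_id=mask_token_id,
    length=MIN_QUERY_LENGTH,
)

encoded = query_tokenizer.encode_batch(["what is late interaction?"])
print(encoded[0].tokens)  # tail of the sequence is filled with [MASK]
```

Baking the padding into a separate tokenizer removes the double encode and the save/restore dance around shared state.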
@@ -183,6 +169,8 @@ def __init__(
         self.pad_token_id: Optional[int] = None
         self.skip_list: set[int] = set()
 
+        self.query_tokenizer: Optional[Tokenizer] = None
+
         if not self.lazy_load:
             self.load_onnx_model()
 
@@ -195,6 +183,8 @@ def load_onnx_model(self) -> None:
             cuda=self.cuda,
             device_id=self.device_id,
         )
+        self.query_tokenizer, _ = load_tokenizer(model_dir=self._model_dir)
+
         assert self.tokenizer is not None
         self.mask_token_id = self.special_token_to_id[self.MASK_TOKEN]
         self.pad_token_id = self.tokenizer.padding["pad_id"]
@@ -205,6 +195,12 @@ def load_onnx_model(self) -> None:
         current_max_length = self.tokenizer.truncation["max_length"]
         # ensure not to overflow after adding document-marker
         self.tokenizer.enable_truncation(max_length=current_max_length - 1)
+        self.query_tokenizer.enable_truncation(max_length=current_max_length - 1)
+        self.query_tokenizer.enable_padding(
+            pad_token=self.MASK_TOKEN,
+            pad_id=self.mask_token_id,
+            length=self.MIN_QUERY_LENGTH,
+        )
 
     def embed(
         self,
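Both tokenizers truncate to `current_max_length - 1`, reserving one position for the marker token that ColBERT prepends (the document marker, and presumably the query marker on the query side). A hypothetical end-to-end check of the new behavior; `model` and the input strings are placeholders, not part of this commit:

```python
# Assumes a loaded fastembed ColBERT-style model exposing the
# tokenize() method shown in the first hunk above.
doc_encodings = model.tokenize(["a short passage"], is_doc=True)
query_encodings = model.tokenize(["a short query"], is_doc=False)

# Documents keep their ordinary padding, while queries come back
# [MASK]-padded to at least MIN_QUERY_LENGTH positions.
assert len(query_encodings[0].ids) >= model.MIN_QUERY_LENGTH
```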