Skip to content

Commit 6efe06b

Browse files
authored
new: decouple colbert query and document tokenizer (#556)
1 parent 8872392 commit 6efe06b

File tree

1 file changed

+14
-18
lines changed

1 file changed

+14
-18
lines changed

fastembed/late_interaction/colbert.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22
from typing import Any, Iterable, Optional, Sequence, Type, Union
33

44
import numpy as np
5-
from tokenizers import Encoding
5+
from tokenizers import Encoding, Tokenizer
66

7+
from fastembed.common.preprocessor_utils import load_tokenizer
78
from fastembed.common.types import NumpyArray
89
from fastembed.common import OnnxProvider
910
from fastembed.common.onnx_model import OnnxOutputContext
@@ -87,23 +88,8 @@ def tokenize(self, documents: list[str], is_doc: bool = True, **kwargs: Any) ->
8788
)
8889

8990
def _tokenize_query(self, query: str) -> list[Encoding]:
90-
assert self.tokenizer is not None
91-
encoded = self.tokenizer.encode_batch([query])
92-
# colbert authors recommend to pad queries with [MASK] tokens for query augmentation to improve performance
93-
if len(encoded[0].ids) < self.MIN_QUERY_LENGTH:
94-
prev_padding = None
95-
if self.tokenizer.padding:
96-
prev_padding = self.tokenizer.padding
97-
self.tokenizer.enable_padding(
98-
pad_token=self.MASK_TOKEN,
99-
pad_id=self.mask_token_id,
100-
length=self.MIN_QUERY_LENGTH,
101-
)
102-
encoded = self.tokenizer.encode_batch([query])
103-
if prev_padding is None:
104-
self.tokenizer.no_padding()
105-
else:
106-
self.tokenizer.enable_padding(**prev_padding)
91+
assert self.query_tokenizer is not None
92+
encoded = self.query_tokenizer.encode_batch([query])
10793
return encoded
10894

10995
def _tokenize_documents(self, documents: list[str]) -> list[Encoding]:
@@ -183,6 +169,8 @@ def __init__(
183169
self.pad_token_id: Optional[int] = None
184170
self.skip_list: set[int] = set()
185171

172+
self.query_tokenizer: Optional[Tokenizer] = None
173+
186174
if not self.lazy_load:
187175
self.load_onnx_model()
188176

@@ -195,6 +183,8 @@ def load_onnx_model(self) -> None:
195183
cuda=self.cuda,
196184
device_id=self.device_id,
197185
)
186+
self.query_tokenizer, _ = load_tokenizer(model_dir=self._model_dir)
187+
198188
assert self.tokenizer is not None
199189
self.mask_token_id = self.special_token_to_id[self.MASK_TOKEN]
200190
self.pad_token_id = self.tokenizer.padding["pad_id"]
@@ -205,6 +195,12 @@ def load_onnx_model(self) -> None:
205195
current_max_length = self.tokenizer.truncation["max_length"]
206196
# ensure not to overflow after adding document-marker
207197
self.tokenizer.enable_truncation(max_length=current_max_length - 1)
198+
self.query_tokenizer.enable_truncation(max_length=current_max_length - 1)
199+
self.query_tokenizer.enable_padding(
200+
pad_token=self.MASK_TOKEN,
201+
pad_id=self.mask_token_id,
202+
length=self.MIN_QUERY_LENGTH,
203+
)
208204

209205
def embed(
210206
self,

0 commit comments

Comments (0)