32 changes: 14 additions & 18 deletions fastembed/late_interaction/colbert.py
@@ -2,8 +2,9 @@
 from typing import Any, Iterable, Optional, Sequence, Type, Union
 
 import numpy as np
-from tokenizers import Encoding
+from tokenizers import Encoding, Tokenizer
 
+from fastembed.common.preprocessor_utils import load_tokenizer
 from fastembed.common.types import NumpyArray
 from fastembed.common import OnnxProvider
 from fastembed.common.onnx_model import OnnxOutputContext
@@ -87,23 +88,8 @@ def tokenize(self, documents: list[str], is_doc: bool = True, **kwargs: Any) ->
         )
 
     def _tokenize_query(self, query: str) -> list[Encoding]:
-        assert self.tokenizer is not None
-        encoded = self.tokenizer.encode_batch([query])
-        # colbert authors recommend to pad queries with [MASK] tokens for query augmentation to improve performance
-        if len(encoded[0].ids) < self.MIN_QUERY_LENGTH:
-            prev_padding = None
-            if self.tokenizer.padding:
-                prev_padding = self.tokenizer.padding
-            self.tokenizer.enable_padding(
-                pad_token=self.MASK_TOKEN,
-                pad_id=self.mask_token_id,
-                length=self.MIN_QUERY_LENGTH,
-            )
-            encoded = self.tokenizer.encode_batch([query])
-            if prev_padding is None:
-                self.tokenizer.no_padding()
-            else:
-                self.tokenizer.enable_padding(**prev_padding)
+        assert self.query_tokenizer is not None
+        encoded = self.query_tokenizer.encode_batch([query])
         return encoded
 
     def _tokenize_documents(self, documents: list[str]) -> list[Encoding]:
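
Reviewer note: the removed branch toggled padding on the shared tokenizer around every query encode and then restored the previous state; the replacement relies on a dedicated query tokenizer configured once at load time (see the `load_onnx_model` hunks below). Here is a minimal sketch, not the PR's code, of that ColBERT query-augmentation behavior with the HuggingFace `tokenizers` API; the checkpoint name and minimum query length are illustrative:

```python
# Sketch of ColBERT query augmentation: short queries are padded with [MASK]
# tokens up to a fixed minimum length, and the padded positions act as soft
# wildcards at scoring time.
from tokenizers import Tokenizer

MIN_QUERY_LENGTH = 32  # illustrative; fastembed defines its own constant
MASK_TOKEN = "[MASK]"

query_tokenizer = Tokenizer.from_pretrained("bert-base-uncased")
mask_token_id = query_tokenizer.token_to_id(MASK_TOKEN)

# Configured once, instead of being toggled around every encode_batch call:
query_tokenizer.enable_padding(
    pad_token=MASK_TOKEN,
    pad_id=mask_token_id,
    length=MIN_QUERY_LENGTH,
)

encoded = query_tokenizer.encode_batch(["what is late interaction?"])
print(encoded[0].tokens)  # query tokens followed by [MASK] padding to length 32
```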
@@ -183,6 +169,8 @@ def __init__(
         self.pad_token_id: Optional[int] = None
         self.skip_list: set[int] = set()
 
+        self.query_tokenizer: Optional[Tokenizer] = None
+
         if not self.lazy_load:
             self.load_onnx_model()
 
@@ -195,6 +183,8 @@ def load_onnx_model(self) -> None:
             cuda=self.cuda,
             device_id=self.device_id,
         )
+        self.query_tokenizer, _ = load_tokenizer(model_dir=self._model_dir)
+
         assert self.tokenizer is not None
         self.mask_token_id = self.special_token_to_id[self.MASK_TOKEN]
         self.pad_token_id = self.tokenizer.padding["pad_id"]
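
Reviewer note: `load_tokenizer` is called a second time here so that `query_tokenizer` is a separate `Tokenizer` instance; padding and truncation state set on one instance cannot leak into the other, which is what makes the old save/restore dance unnecessary. A sketch under the assumption of the HuggingFace `tokenizers` API, with an illustrative checkpoint name:

```python
# Two Tokenizer instances keep fully independent padding/truncation state,
# which is why a second tokenizer is loaded instead of reconfiguring the
# shared one per call.
from tokenizers import Tokenizer

doc_tokenizer = Tokenizer.from_pretrained("bert-base-uncased")
query_tokenizer = Tokenizer.from_pretrained("bert-base-uncased")

query_tokenizer.enable_padding(
    pad_token="[MASK]",
    pad_id=query_tokenizer.token_to_id("[MASK]"),
    length=32,
)

print(doc_tokenizer.padding)    # unaffected by the query tokenizer's config
print(query_tokenizer.padding)  # includes pad_token='[MASK]', length=32
```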
@@ -205,6 +195,12 @@ def load_onnx_model(self) -> None:
         current_max_length = self.tokenizer.truncation["max_length"]
         # ensure not to overflow after adding document-marker
         self.tokenizer.enable_truncation(max_length=current_max_length - 1)
+        self.query_tokenizer.enable_truncation(max_length=current_max_length - 1)
+        self.query_tokenizer.enable_padding(
+            pad_token=self.MASK_TOKEN,
+            pad_id=self.mask_token_id,
+            length=self.MIN_QUERY_LENGTH,
+        )
 
     def embed(
         self,
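
Reviewer note: the `- 1` headroom exists because ColBERT inserts a marker token ([Q] for queries, [D] for documents) into each sequence after tokenization, so truncation must leave one slot free to stay within the model's sequence limit. A sketch of that invariant, with an illustrative length limit and a hypothetical marker id (not fastembed's actual insertion code):

```python
# Truncate to max_length - 1 so a marker token can be added without overflow.
from tokenizers import Tokenizer

MODEL_MAX_LENGTH = 512     # illustrative model limit
QUERY_MARKER_TOKEN_ID = 1  # hypothetical id for the [Q] marker

tokenizer = Tokenizer.from_pretrained("bert-base-uncased")
tokenizer.enable_truncation(max_length=MODEL_MAX_LENGTH - 1)

encoding = tokenizer.encode("late interaction " * 400)  # overlong on purpose
ids = [encoding.ids[0], QUERY_MARKER_TOKEN_ID, *encoding.ids[1:]]
assert len(ids) <= MODEL_MAX_LENGTH  # marker fits without overflowing
```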