Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions fastembed/image/onnx_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,18 @@
},
"model_file": "model.onnx",
},
{
"model": "akshayballal/colpali-v1.2-merged",
"dim": 128,
"description": "",
"license": "mit",
"size_in_GB": 6.08,
"sources": {
"hf": "akshayballal/colpali-v1.2-merged-onnx",
},
"additional_files": ["model.onnx_data"],
"model_file": "model.onnx",
},
]


Expand Down
47 changes: 40 additions & 7 deletions fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import Iterable, Any, Sequence, Optional
from typing import Iterable, Any, Sequence, Optional, Self, Type

from loguru import logger

import numpy as np
from fastembed.common import OnnxProvider
from fastembed.rerank.cross_encoder.onnx_text_model import OnnxCrossEncoderModel
from fastembed.rerank.cross_encoder.onnx_text_model import OnnxCrossEncoderModel, TextRerankerWorker
from fastembed.rerank.cross_encoder.text_cross_encoder_base import TextCrossEncoderBase
from fastembed.common.utils import define_cache_dir
from fastembed.common.onnx_model import OnnxOutputContext

supported_onnx_models = [
{
Expand Down Expand Up @@ -82,7 +83,7 @@ def list_supported_models(cls) -> list[dict[str, Any]]:
return supported_onnx_models

def __init__(
self,
self: Self,
model_name: str,
cache_dir: Optional[str] = None,
threads: Optional[int] = None,
Expand All @@ -91,7 +92,7 @@ def __init__(
device_ids: Optional[list[int]] = None,
lazy_load: bool = False,
device_id: Optional[int] = None,
**kwargs,
**kwargs: Any,
):
"""
Args:
Expand Down Expand Up @@ -155,11 +156,11 @@ def load_onnx_model(self) -> None:
)

def rerank(
self,
self: Self,
query: str,
documents: Iterable[str],
batch_size: int = 64,
**kwargs,
**kwargs: Any,
) -> Iterable[float]:
"""Reranks documents based on their relevance to a given query.

Expand All @@ -175,3 +176,35 @@ def rerank(
yield from self._rerank_documents(
query=query, documents=documents, batch_size=batch_size, **kwargs
)

def rerank_pairs(
    self: Self,
    pairs: Iterable[tuple[str, str]],
    batch_size: int = 64,
    parallel: Optional[int] = None,
    **kwargs: Any,
) -> Iterable[float]:
    """Rerank explicit (query, document) pairs.

    Args:
        pairs: Iterable of (query, document) tuples to score.
        batch_size: Number of pairs scored per ONNX run.
        parallel: If set, number of worker processes to fan out to
            (0 means one worker per CPU); None keeps execution in-process.

    Returns:
        Iterable of relevance scores, one per input pair.
    """
    # Forward the device/provider configuration so the parallel path of
    # _rerank_pairs can spawn correctly configured workers; without these
    # the `parallel` option is effectively inert.
    # NOTE(review): assumes __init__ stores providers/cuda/device_ids on
    # self — confirm against the constructor.
    yield from self._rerank_pairs(
        pairs=pairs,
        batch_size=batch_size,
        parallel=parallel,
        providers=self.providers,
        cuda=self.cuda,
        device_ids=self.device_ids,
        **kwargs,
    )
Comment on lines 185 to 202
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def rerank_pairs(
self,
pairs: Iterable[tuple[str, str]],
batch_size: int = 64,
**kwargs: Any,
) -> Iterable[float]:
yield from self._rerank_pairs(
model_name=self._model_dir,
cache_dir=self.cache_dir,
pairs=pairs,
batch_size=batch_size,
**kwargs,
)
def rerank_pairs(
self,
pairs: Iterable[tuple[str, str]],
batch_size: int = 64,
parallel: Optional[int] = None,
**kwargs: Any,
) -> Iterable[float]:
yield from self._rerank_pairs(
model_name=self.model_name,
cache_dir=str(self.cache_dir),
pairs=pairs,
batch_size=batch_size,
parallel=parallel,
providers=self.providers,
cuda=self.cuda,
device_ids=self.device_ids,
**kwargs,
)


@classmethod
def _get_worker_class(cls) -> Type[TextRerankerWorker]:
    """Worker class spawned by the parallel pool for multi-process reranking."""
    return TextCrossEncoderWorker

def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[np.ndarray]:
    # model_output already holds the final relevance scores; pass through unchanged.
    # NOTE(review): the annotation says np.ndarray, but onnx_embed_pairs stores a
    # plain list of floats in model_output — confirm and tighten the return type.
    return output.model_output


class TextCrossEncoderWorker(TextRerankerWorker):
    """Parallel-pool worker that owns one OnnxTextCrossEncoder instance."""

    def init_embedding(
        self,
        model_name: str,
        cache_dir: str,
        **kwargs,
    ) -> OnnxTextCrossEncoder:
        """Construct the per-process encoder.

        threads=1 because each worker process runs its own ONNX session;
        extra intra-op threads would only oversubscribe the CPUs.
        """
        return OnnxTextCrossEncoder(
            model_name=model_name,
            cache_dir=cache_dir,
            threads=1,
            **kwargs,
        )
92 changes: 80 additions & 12 deletions fastembed/rerank/cross_encoder/onnx_text_model.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
from typing import Sequence, Optional, Iterable
from typing import Sequence, Optional, Iterable, Any, Self
from pathlib import Path

import os
import numpy as np
from tokenizers import Encoding
from multiprocessing import get_all_start_methods

from fastembed.common.onnx_model import OnnxModel, OnnxProvider, OnnxOutputContext
from fastembed.common.onnx_model import OnnxModel, OnnxProvider, OnnxOutputContext, EmbeddingWorker, T
from fastembed.common.preprocessor_utils import load_tokenizer
from fastembed.common.utils import iter_batch
from fastembed.parallel_processor import ParallelWorkerPool


class OnnxCrossEncoderModel(OnnxModel):
ONNX_OUTPUT_NAMES: Optional[list[str]] = None

def _load_onnx_model(
self,
self: Self,
model_dir: Path,
model_file: str,
threads: Optional[int],
Expand All @@ -31,12 +33,10 @@ def _load_onnx_model(
)
self.tokenizer, _ = load_tokenizer(model_dir=model_dir)

def tokenize(self, query: str, documents: list[str], **kwargs) -> list[Encoding]:
return self.tokenizer.encode_batch([(query, doc) for doc in documents])

def onnx_embed(self, query: str, documents: list[str], **kwargs) -> OnnxOutputContext:
tokenized_input = self.tokenize(query, documents, **kwargs)
def tokenize(self: Self, pairs: Iterable[tuple[str, str]], **kwargs: Any) -> list[Encoding]:
    """Batch-encode (query, document) pairs with the loaded tokenizer."""
    # list(pairs) instead of the manual copy-comprehension; pairs may be a
    # generator and must be materialized exactly once.
    # NOTE(review): tokenizers' encode_batch only accepts
    # is_pretokenized/add_special_tokens — confirm callers never forward
    # other keys through **kwargs here.
    return self.tokenizer.encode_batch(list(pairs), **kwargs)

def _build_onnx_input(self, tokenized_input):
inputs = {
"input_ids": np.array([enc.ids for enc in tokenized_input], dtype=np.int64),
"attention_mask": np.array(
Expand All @@ -48,23 +48,91 @@ def onnx_embed(self, query: str, documents: list[str], **kwargs) -> OnnxOutputCo
inputs["token_type_ids"] = np.array(
[enc.type_ids for enc in tokenized_input], dtype=np.int64
)
return inputs

def onnx_embed(self: Self, query: str, documents: list[str], **kwargs: Any) -> OnnxOutputContext:
    """Score the query against every document via the pair-based entry point."""
    paired = [(query, document) for document in documents]
    return self.onnx_embed_pairs(paired, **kwargs)

def onnx_embed_pairs(self: Self, pairs: Iterable[tuple[str, str]], **kwargs: Any) -> OnnxOutputContext:
    """Tokenize (query, document) pairs, run the ONNX session, collect scores.

    Args:
        pairs: (query, document) tuples forming one batch.

    Returns:
        OnnxOutputContext whose model_output is a list of float scores,
        one per input pair.
    """
    tokenized_input = self.tokenize(pairs, **kwargs)
    inputs = self._build_onnx_input(tokenized_input)
    onnx_input = self._preprocess_onnx_input(inputs, **kwargs)
    outputs = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input)
    # First output tensor, first column: the relevance score per pair.
    relevant_output = outputs[0]
    scores = relevant_output[:, 0]
    return OnnxOutputContext(model_output=scores.tolist())

def _rerank_documents(
    self: Self, query: str, documents: Iterable[str], batch_size: int, **kwargs: Any
) -> Iterable[float]:
    """Yield a relevance score for each document against the query."""
    # Lazily load the ONNX model on first use.
    if not hasattr(self, "model") or self.model is None:
        self.load_onnx_model()
    for batch in iter_batch(documents, batch_size):
        yield from self.onnx_embed(query, batch, **kwargs).model_output

def _rerank_pairs(
    self: Self,
    pairs: Iterable[tuple[str, str]],
    batch_size: int,
    parallel: Optional[int] = None,
    providers: Optional[Sequence[OnnxProvider]] = None,
    cuda: bool = False,
    device_ids: Optional[list[int]] = None,
    **kwargs: Any,
) -> Iterable[float]:
    """Yield one relevance score per (query, document) pair.

    Args:
        pairs: Pairs to score; a single bare tuple is treated as one pair.
        batch_size: Pairs per ONNX run.
        parallel: None = run in-process; 0 = one worker per CPU; N = N workers.
        providers: ONNX providers forwarded to worker processes.
        cuda: Whether workers should run on CUDA.
        device_ids: GPU ids assigned to workers.

    Returns:
        Iterable of float scores in input order.
    """
    if not hasattr(self, "model") or self.model is None:
        self.load_onnx_model()

    # Small inputs are scored inline: worker startup would dominate.
    is_small = False
    if isinstance(pairs, tuple):
        # NOTE(review): assumes a bare tuple is a single (query, doc) pair —
        # a tuple OF pairs would be mis-wrapped here; confirm callers.
        pairs = [pairs]
        is_small = True
    elif isinstance(pairs, list) and len(pairs) < batch_size:
        is_small = True

    if parallel is None or is_small:
        for batch in iter_batch(pairs, batch_size):
            yield from self.onnx_embed_pairs(batch, **kwargs).model_output
    else:
        if parallel == 0:
            parallel = os.cpu_count()

        # "fork" is unsafe once native runtime threads exist; prefer forkserver.
        start_method = "forkserver" if "forkserver" in get_all_start_methods() else "spawn"
        params = {
            "model_name": self.model_description["model"],
            # NOTE(review): cache_dir may be a Path — confirm workers accept it.
            "cache_dir": self.cache_dir,
            "providers": providers,
            **kwargs,
        }

        pool = ParallelWorkerPool(
            num_workers=parallel or 1,
            worker=self._get_worker_class(),
            cuda=cuda,
            device_ids=device_ids,
            start_method=start_method,
        )
        for batch in pool.ordered_map(iter_batch(pairs, batch_size), **params):
            yield from self._post_process_onnx_output(batch)

def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[T]:
    """Convert raw ONNX output into final scores; subclasses must override."""
    raise NotImplementedError("Subclasses must implement this method")

def _preprocess_onnx_input(
    self: Self, onnx_input: dict[str, np.ndarray], **kwargs: Any
) -> dict[str, np.ndarray]:
    """Hook for subclasses to adjust the ONNX input; identity by default."""
    return onnx_input


class TextRerankerWorker(EmbeddingWorker):
    """Worker that scores incoming batches of (query, document) pairs."""

    def process(self, items: Iterable[tuple[int, Any]]) -> Iterable[tuple[int, Any]]:
        # Keep each batch's index paired with its output so the pool can
        # restore the original ordering downstream.
        for batch_idx, pair_batch in items:
            scores = self.model.onnx_embed_pairs(pair_batch)
            yield batch_idx, scores
12 changes: 8 additions & 4 deletions fastembed/rerank/cross_encoder/text_cross_encoder.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Iterable, Optional, Sequence, Type
from typing import Any, Iterable, Optional, Sequence, Type, Self

from fastembed.rerank.cross_encoder.text_cross_encoder_base import TextCrossEncoderBase
from fastembed.rerank.cross_encoder.onnx_text_cross_encoder import OnnxTextCrossEncoder
Expand Down Expand Up @@ -39,15 +39,15 @@ def list_supported_models(cls) -> list[dict[str, Any]]:
return result

def __init__(
self,
self: Self,
model_name: str,
cache_dir: Optional[str] = None,
threads: Optional[int] = None,
providers: Optional[Sequence[OnnxProvider]] = None,
cuda: bool = False,
device_ids: Optional[list[int]] = None,
lazy_load: bool = False,
**kwargs,
**kwargs : Any,
):
super().__init__(model_name, cache_dir, threads, **kwargs)

Expand All @@ -72,7 +72,7 @@ def __init__(
)

def rerank(
self, query: str, documents: Iterable[str], batch_size: int = 64, **kwargs
self: Self, query: str, documents: Iterable[str], batch_size: int = 64, **kwargs: Any
) -> Iterable[float]:
"""Rerank a list of documents based on a query.

Expand All @@ -85,3 +85,7 @@ def rerank(
Iterable of scores for each document
"""
yield from self.model.rerank(query, documents, batch_size=batch_size, **kwargs)

def rerank_pairs(
    self: Self,
    pairs: Iterable[tuple[str, str]],
    batch_size: int = 64,
    **kwargs: Any,
) -> Iterable[float]:
    """Rerank explicit (query, document) pairs.

    Args:
        pairs: Iterable of (query, document) tuples to score.
        batch_size: Number of pairs scored per batch.

    Returns:
        Iterable of relevance scores, one per input pair.
    """
    yield from self.model.rerank_pairs(pairs, batch_size=batch_size, **kwargs)
14 changes: 11 additions & 3 deletions fastembed/rerank/cross_encoder/text_cross_encoder_base.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import Iterable, Optional
from abc import abstractclassmethod
from typing import Iterable, Optional, Any, Self

from fastembed.common.model_management import ModelManagement


class TextCrossEncoderBase(ModelManagement):
def __init__(
self,
self: Self,
model_name: str,
cache_dir: Optional[str] = None,
threads: Optional[int] = None,
Expand All @@ -16,8 +17,9 @@ def __init__(
self.threads = threads
self._local_files_only = kwargs.pop("local_files_only", False)

@classmethod
def rerank(
self,
self: Self,
query: str,
documents: Iterable[str],
batch_size: int = 64,
Expand All @@ -35,3 +37,9 @@ def rerank(
Iterable[float]: The scores of reranked the documents.
"""
raise NotImplementedError("This method should be overridden by subclasses")

def rerank_pairs(
    self: Self,
    pairs: Iterable[tuple[str, str]],
    batch_size: int = 64,
    **kwargs: Any,
) -> Iterable[float]:
    """Score (query, document) pairs; subclasses provide the implementation.

    Note: this must be an instance method (subclasses override it as one);
    the previous @classmethod decorator contradicted the `self` parameter.

    Raises:
        NotImplementedError: Always, in this base class.
    """
    raise NotImplementedError("This method should be overridden by subclasses")

12 changes: 12 additions & 0 deletions fastembed/text/onnx_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,18 @@
},
"model_file": "onnx/model.onnx",
},
{
"model": "akshayballal/colpali-v1.2-merged",
"dim": 128,
"description": "",
"license": "mit",
"size_in_GB": 6.08,
"sources": {
"hf": "akshayballal/colpali-v1.2-merged-onnx",
},
"additional_files": ["model.onnx_data"],
"model_file": "model.onnx",
},
]


Expand Down
5 changes: 5 additions & 0 deletions fastembed/text/onnx_text_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ def onnx_embed(
onnx_input["token_type_ids"] = np.array(
[np.zeros(len(e), dtype=np.int64) for e in input_ids], dtype=np.int64
)
if "pixel_values" in input_names:
onnx_input["pixel_values"] = np.zeros(
(np.array(input_ids, dtype=np.int64).shape[0], 3, 448, 448), dtype=np.float32
)

onnx_input = self._preprocess_onnx_input(onnx_input, **kwargs)

Expand Down Expand Up @@ -116,6 +120,7 @@ def _embed_documents(
self.load_onnx_model()
for batch in iter_batch(documents, batch_size):
yield from self._post_process_onnx_output(self.onnx_embed(batch))

else:
if parallel == 0:
parallel = os.cpu_count()
Expand Down