Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 53 additions & 6 deletions fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
from typing import Iterable, Any, Sequence, Optional
from typing import Any, Iterable, Optional, Sequence, Type

from loguru import logger

from fastembed.common import OnnxProvider
from fastembed.rerank.cross_encoder.onnx_text_model import OnnxCrossEncoderModel
from fastembed.rerank.cross_encoder.text_cross_encoder_base import TextCrossEncoderBase
from fastembed.common.onnx_model import OnnxOutputContext
from fastembed.common.utils import define_cache_dir
from fastembed.rerank.cross_encoder.onnx_text_model import (
OnnxCrossEncoderModel,
TextRerankerWorker,
)
from fastembed.rerank.cross_encoder.text_cross_encoder_base import TextCrossEncoderBase

supported_onnx_models = [
{
Expand Down Expand Up @@ -91,7 +95,7 @@ def __init__(
device_ids: Optional[list[int]] = None,
lazy_load: bool = False,
device_id: Optional[int] = None,
**kwargs,
**kwargs: Any,
):
"""
Args:
Expand Down Expand Up @@ -138,7 +142,9 @@ def __init__(
self.model_description = self._get_model_description(model_name)
self.cache_dir = define_cache_dir(cache_dir)
self._model_dir = self.download_model(
self.model_description, self.cache_dir, local_files_only=self._local_files_only
self.model_description,
self.cache_dir,
local_files_only=self._local_files_only,
)

if not self.lazy_load:
Expand All @@ -159,7 +165,7 @@ def rerank(
query: str,
documents: Iterable[str],
batch_size: int = 64,
**kwargs,
**kwargs: Any,
) -> Iterable[float]:
"""Reranks documents based on their relevance to a given query.

Expand All @@ -175,3 +181,44 @@ def rerank(
yield from self._rerank_documents(
query=query, documents=documents, batch_size=batch_size, **kwargs
)

def rerank_pairs(
    self,
    pairs: Iterable[tuple[str, str]],
    batch_size: int = 64,
    parallel: Optional[int] = None,
    **kwargs: Any,
) -> Iterable[float]:
    """Score (query, document) pairs with this encoder's ONNX model.

    Thin wrapper: forwards this instance's model location and runtime
    configuration to the shared `_rerank_pairs` implementation, which
    decides between in-process and data-parallel execution.

    Args:
        pairs: Query-document tuples to score together.
        batch_size: Number of pairs scored per model invocation.
        parallel: Worker-process count; 0 means all cores, None disables
            data-parallel processing.
        **kwargs: Extra options passed through to the implementation.

    Yields:
        One relevance score per input pair.
    """
    runtime_config = dict(
        model_name=self.model_name,
        cache_dir=str(self.cache_dir),
        providers=self.providers,
        cuda=self.cuda,
        device_ids=self.device_ids,
    )
    yield from self._rerank_pairs(
        pairs=pairs,
        batch_size=batch_size,
        parallel=parallel,
        **runtime_config,
        **kwargs,
    )

@classmethod
def _get_worker_class(cls) -> Type[TextRerankerWorker]:
    # Hook used by the base class to spawn per-process workers for
    # data-parallel reranking (see _rerank_pairs).
    return TextCrossEncoderWorker

def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[float]:
    """Lazily convert each raw ONNX score into a plain Python float."""
    return map(float, output.model_output)


class TextCrossEncoderWorker(TextRerankerWorker):
    """Worker process for data-parallel reranking.

    Instantiated once per process by the parallel pool; builds the ONNX
    cross-encoder that `TextRerankerWorker.process` will call.
    """

    def init_embedding(
        self,
        model_name: str,
        cache_dir: str,
        # Annotated as Any for consistency with the other **kwargs in this
        # module (fix: the annotation was missing here).
        **kwargs: Any,
    ) -> OnnxTextCrossEncoder:
        """Create the encoder this worker runs.

        Args:
            model_name: Name of the cross-encoder model to load.
            cache_dir: Directory holding (or receiving) the model files.
            **kwargs: Extra options forwarded to the encoder constructor.

        Returns:
            A ready `OnnxTextCrossEncoder` limited to one ONNX thread, so
            parallelism comes from the process pool rather than onnxruntime.
        """
        return OnnxTextCrossEncoder(
            model_name=model_name,
            cache_dir=cache_dir,
            threads=1,
            **kwargs,
        )
112 changes: 96 additions & 16 deletions fastembed/rerank/cross_encoder/onnx_text_model.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,29 @@
from typing import Sequence, Optional, Iterable
import os
from multiprocessing import get_all_start_methods
from pathlib import Path
from typing import Any, Iterable, Optional, Sequence, Type

import numpy as np
from tokenizers import Encoding

from fastembed.common.onnx_model import OnnxModel, OnnxProvider, OnnxOutputContext
from fastembed.common.onnx_model import (
EmbeddingWorker,
OnnxModel,
OnnxOutputContext,
OnnxProvider,
)
from fastembed.common.preprocessor_utils import load_tokenizer
from fastembed.common.utils import iter_batch
from fastembed.parallel_processor import ParallelWorkerPool


class OnnxCrossEncoderModel(OnnxModel):
class OnnxCrossEncoderModel(OnnxModel[float]):
ONNX_OUTPUT_NAMES: Optional[list[str]] = None

@classmethod
def _get_worker_class(cls) -> Type["TextRerankerWorker"]:
    # Abstract hook: concrete encoders return the worker type used for
    # data-parallel reranking.
    raise NotImplementedError("Subclasses must implement this method")

def _load_onnx_model(
self,
model_dir: Path,
Expand All @@ -31,40 +43,108 @@ def _load_onnx_model(
)
self.tokenizer, _ = load_tokenizer(model_dir=model_dir)

def tokenize(self, query: str, documents: list[str], **kwargs) -> list[Encoding]:
return self.tokenizer.encode_batch([(query, doc) for doc in documents])

def onnx_embed(self, query: str, documents: list[str], **kwargs) -> OnnxOutputContext:
tokenized_input = self.tokenize(query, documents, **kwargs)
def tokenize(self, pairs: list[tuple[str, str]], **_: Any) -> list[Encoding]:
    """Batch-encode (query, document) pairs with the loaded tokenizer.

    Extra keyword arguments are accepted for interface compatibility and
    ignored.
    """
    encodings: list[Encoding] = self.tokenizer.encode_batch(pairs)
    return encodings

def _build_onnx_input(self, tokenized_input):
input_names = {node.name for node in self.model.get_inputs()}
inputs = {
"input_ids": np.array([enc.ids for enc in tokenized_input], dtype=np.int64),
"attention_mask": np.array(
[enc.attention_mask for enc in tokenized_input], dtype=np.int64
),
}
input_names = {node.name for node in self.model.get_inputs()}
if "token_type_ids" in input_names:
inputs["token_type_ids"] = np.array(
[enc.type_ids for enc in tokenized_input], dtype=np.int64
)
if "attention_mask" in input_names:
inputs["attention_mask"] = np.array(
[enc.attention_mask for enc in tokenized_input], dtype=np.int64
)
return inputs

def onnx_embed(self, query: str, documents: list[str], **kwargs: Any) -> OnnxOutputContext:
    """Score one query against each document.

    Builds (query, document) pairs and delegates to `onnx_embed_pairs`.
    """
    return self.onnx_embed_pairs([(query, doc) for doc in documents], **kwargs)

def onnx_embed_pairs(self, pairs: list[tuple[str, str]], **kwargs: Any) -> OnnxOutputContext:
tokenized_input = self.tokenize(pairs, **kwargs)
inputs = self._build_onnx_input(tokenized_input)
onnx_input = self._preprocess_onnx_input(inputs, **kwargs)
outputs = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input)
return OnnxOutputContext(model_output=outputs[0][:, 0].tolist())
relevant_output = outputs[0]
scores = relevant_output[:, 0]
return OnnxOutputContext(model_output=scores)

def _rerank_documents(
self, query: str, documents: Iterable[str], batch_size: int, **kwargs
self, query: str, documents: Iterable[str], batch_size: int, **kwargs: Any
) -> Iterable[float]:
if not hasattr(self, "model") or self.model is None:
self.load_onnx_model()
for batch in iter_batch(documents, batch_size):
yield from self.onnx_embed(query, batch, **kwargs).model_output
yield from self._post_process_onnx_output(self.onnx_embed(query, batch, **kwargs))

def _rerank_pairs(
    self,
    model_name: str,
    cache_dir: str,
    pairs: Iterable[tuple[str, str]],
    batch_size: int,
    parallel: Optional[int] = None,
    providers: Optional[Sequence[OnnxProvider]] = None,
    cuda: bool = False,
    device_ids: Optional[list[int]] = None,
    **kwargs: Any,
) -> Iterable[float]:
    """Score query-document pairs, in-process or via a worker pool.

    Args:
        model_name: Model identifier passed to worker processes.
        cache_dir: Model cache directory passed to worker processes.
        pairs: (query, document) tuples to score.
        batch_size: Pairs per model invocation.
        parallel: Number of worker processes; 0 means all cores, None
            disables data-parallel processing.
        providers: ONNX execution providers forwarded to workers.
        cuda: Whether workers should use CUDA.
        device_ids: GPU ids for the worker pool.
        **kwargs: Extra options forwarded to tokenization/inference.

    Yields:
        One float score per input pair, in input order.
    """
    # Small inputs are not worth the cost of spawning a process pool.
    is_small = False

    # A bare (query, doc) tuple is treated as a single pair.
    if isinstance(pairs, tuple):
        pairs = [pairs]
        is_small = True

    # A list smaller than one batch is also handled serially.
    if isinstance(pairs, list):
        if len(pairs) < batch_size:
            is_small = True

    if parallel is None or is_small:
        # Serial path: lazily load the ONNX model on first use.
        if not hasattr(self, "model") or self.model is None:
            self.load_onnx_model()
        for batch in iter_batch(pairs, batch_size):
            yield from self._post_process_onnx_output(self.onnx_embed_pairs(batch, **kwargs))
    else:
        # parallel == 0 means "use every available core".
        if parallel == 0:
            parallel = os.cpu_count()

        # Prefer forkserver where available; fall back to spawn.
        start_method = "forkserver" if "forkserver" in get_all_start_methods() else "spawn"
        # Each worker re-creates the model from these parameters.
        params = {
            "model_name": model_name,
            "cache_dir": cache_dir,
            "providers": providers,
            **kwargs,
        }

        pool = ParallelWorkerPool(
            num_workers=parallel or 1,
            worker=self._get_worker_class(),
            cuda=cuda,
            device_ids=device_ids,
            start_method=start_method,
        )
        # NOTE(review): ordered_map presumably preserves input order —
        # confirm against ParallelWorkerPool's contract.
        for batch in pool.ordered_map(iter_batch(pairs, batch_size), **params):
            yield from self._post_process_onnx_output(batch)

def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[float]:
    # Abstract hook: map raw model output to per-pair float scores.
    raise NotImplementedError("Subclasses must implement this method")

def _preprocess_onnx_input(
self, onnx_input: dict[str, np.ndarray], **kwargs
self, onnx_input: dict[str, np.ndarray], **kwargs: Any
) -> dict[str, np.ndarray]:
"""
Preprocess the onnx input.
"""
return onnx_input


class TextRerankerWorker(EmbeddingWorker):
    """Parallel-pool worker that scores batches of (query, document) pairs."""

    def process(self, items: Iterable[tuple[int, Any]]) -> Iterable[tuple[int, Any]]:
        """Score each indexed batch, keeping its index for ordered merging."""
        for index, pair_batch in items:
            yield index, self.model.onnx_embed_pairs(pair_batch)
41 changes: 37 additions & 4 deletions fastembed/rerank/cross_encoder/text_cross_encoder.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import Any, Iterable, Optional, Sequence, Type

from fastembed.rerank.cross_encoder.text_cross_encoder_base import TextCrossEncoderBase
from fastembed.rerank.cross_encoder.onnx_text_cross_encoder import OnnxTextCrossEncoder
from fastembed.common import OnnxProvider
from fastembed.rerank.cross_encoder.onnx_text_cross_encoder import OnnxTextCrossEncoder
from fastembed.rerank.cross_encoder.text_cross_encoder_base import TextCrossEncoderBase


class TextCrossEncoder(TextCrossEncoderBase):
Expand Down Expand Up @@ -47,7 +47,7 @@ def __init__(
cuda: bool = False,
device_ids: Optional[list[int]] = None,
lazy_load: bool = False,
**kwargs,
**kwargs: Any,
):
super().__init__(model_name, cache_dir, threads, **kwargs)

Expand All @@ -72,7 +72,7 @@ def __init__(
)

def rerank(
self, query: str, documents: Iterable[str], batch_size: int = 64, **kwargs
self, query: str, documents: Iterable[str], batch_size: int = 64, **kwargs: Any
) -> Iterable[float]:
"""Rerank a list of documents based on a query.

Expand All @@ -85,3 +85,36 @@ def rerank(
Iterable of scores for each document
"""
yield from self.model.rerank(query, documents, batch_size=batch_size, **kwargs)

def rerank_pairs(
    self,
    pairs: Iterable[tuple[str, str]],
    batch_size: int = 64,
    parallel: Optional[int] = None,
    **kwargs: Any,
) -> Iterable[float]:
    """Score each (query, document) pair with the underlying cross-encoder.

    Args:
        pairs (Iterable[tuple[str, str]]): An iterable of tuples, where each tuple contains a query and a document
            to be scored together.
        batch_size (int, optional): The number of query-document pairs to process in a single batch. Defaults to 64.
        parallel (Optional[int], optional): The number of parallel processes to use for reranking.
            If None, parallelization is disabled. Defaults to None.
        **kwargs (Any): Additional arguments to pass to the underlying reranking model.

    Returns:
        Iterable[float]: An iterable of scores corresponding to each query-document pair in the input.
            Higher scores indicate a stronger match between the query and the document.

    Example:
        >>> encoder = TextCrossEncoder("Xenova/ms-marco-MiniLM-L-6-v2")
        >>> pairs = [("What is AI?", "Artificial intelligence is ..."), ("What is ML?", "Machine learning is ...")]
        >>> scores = list(encoder.rerank_pairs(pairs))
        >>> print(list(map(lambda x: round(x, 2), scores)))
        [-1.24, -10.6]
    """
    scores = self.model.rerank_pairs(
        pairs,
        batch_size=batch_size,
        parallel=parallel,
        **kwargs,
    )
    yield from scores
27 changes: 24 additions & 3 deletions fastembed/rerank/cross_encoder/text_cross_encoder_base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Iterable, Optional
from typing import Any, Iterable, Optional

from fastembed.common.model_management import ModelManagement

Expand All @@ -23,7 +23,7 @@ def rerank(
batch_size: int = 64,
**kwargs,
) -> Iterable[float]:
"""Reranks a list of documents given a query.
"""Rerank a list of documents given a query.

Args:
query (str): The query to rerank the documents.
Expand All @@ -32,6 +32,27 @@ def rerank(
**kwargs: Additional keyword argument to pass to the rerank method.

Yields:
Iterable[float]: The scores of reranked the documents.
Iterable[float]: The scores of the reranked documents.
"""
raise NotImplementedError("This method should be overridden by subclasses")

def rerank_pairs(
    self,
    pairs: Iterable[tuple[str, str]],
    batch_size: int = 64,
    parallel: Optional[int] = None,
    **kwargs: Any,
) -> Iterable[float]:
    """Rerank query-document pairs.

    Args:
        pairs (Iterable[tuple[str, str]]): Query-document pairs to rerank.
        batch_size (int): The batch size to use for reranking.
        parallel:
            If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
            If 0, use all available cores.
            If None, don't use data-parallel processing, use default onnxruntime threading instead.
        **kwargs: Additional keyword arguments to pass to the rerank method.

    Yields:
        Iterable[float]: Scores for each individual pair.
    """
    raise NotImplementedError("This method should be overridden by subclasses")
Loading
Loading