Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
from typing import Any, Iterable, Optional, Sequence, Type
from typing import Iterable, Any, Sequence, Optional, Type


import numpy as np

from fastembed.common import OnnxProvider
from fastembed.rerank.cross_encoder.onnx_text_model import (
OnnxCrossEncoderModel,
TextRerankerWorker,
)
from fastembed.rerank.cross_encoder.text_cross_encoder_base import TextCrossEncoderBase
from fastembed.common.utils import define_cache_dir
from loguru import logger

from fastembed.common import OnnxProvider
Expand All @@ -12,6 +21,7 @@
)
from fastembed.rerank.cross_encoder.text_cross_encoder_base import TextCrossEncoderBase


supported_onnx_models = [
{
"model": "Xenova/ms-marco-MiniLM-L-6-v2",
Expand Down
31 changes: 14 additions & 17 deletions fastembed/rerank/cross_encoder/onnx_text_model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
from multiprocessing import get_all_start_methods
from pathlib import Path
from typing import Any, Iterable, Optional, Sequence
from typing import Sequence, Optional, Iterable, Any, Type

import numpy as np
from tokenizers import Encoding
Expand All @@ -18,9 +18,13 @@
from fastembed.parallel_processor import ParallelWorkerPool


class OnnxCrossEncoderModel(OnnxModel):
class OnnxCrossEncoderModel(OnnxModel[float]):
ONNX_OUTPUT_NAMES: Optional[list[str]] = None

@classmethod
def _get_worker_class(cls) -> Type["TextRerankerWorker"]:
raise NotImplementedError("Subclasses must implement this method")

def _load_onnx_model(
self,
model_dir: Path,
Expand All @@ -40,10 +44,8 @@ def _load_onnx_model(
)
self.tokenizer, _ = load_tokenizer(model_dir=model_dir)

def tokenize(
self, pairs: Iterable[tuple[str, str]], **kwargs: Any
) -> list[Encoding]:
return self.tokenizer.encode_batch([pair for pair in pairs], **kwargs)
def tokenize(self, pairs: list[tuple[str, str]], **kwargs: Any) -> list[Encoding]:
return self.tokenizer.encode_batch(pairs, **kwargs)

def _build_onnx_input(self, tokenized_input):
inputs = {
Expand All @@ -59,10 +61,8 @@ def _build_onnx_input(self, tokenized_input):
)
return inputs

def onnx_embed(
self, query: str, documents: list[str], **kwargs: Any
) -> OnnxOutputContext:
pairs = ((query, doc) for doc in documents)
def onnx_embed(self, query: str, documents: list[str], **kwargs: Any) -> OnnxOutputContext:
pairs = [(query, doc) for doc in documents]
return self.onnx_embed_pairs(pairs, **kwargs)

def onnx_embed_pairs(self, pairs: list[tuple[str, str]], **kwargs: Any):
Expand All @@ -72,15 +72,15 @@ def onnx_embed_pairs(self, pairs: list[tuple[str, str]], **kwargs: Any):
outputs = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input)
relevant_output = outputs[0]
scores = relevant_output[:, 0]
return OnnxOutputContext(model_output=scores.tolist())
return OnnxOutputContext(model_output=scores)

def _rerank_documents(
self, query: str, documents: Iterable[str], batch_size: int, **kwargs: Any
) -> Iterable[float]:
if not hasattr(self, "model") or self.model is None:
self.load_onnx_model()
for batch in iter_batch(documents, batch_size):
yield from self.onnx_embed(query, batch, **kwargs).model_output
yield from self._post_process_onnx_output(self.onnx_embed(query, batch, **kwargs))

def _rerank_pairs(
self,
Expand All @@ -96,9 +96,6 @@ def _rerank_pairs(
) -> Iterable[float]:
is_small = False

if not hasattr(self, "model") or self.model is None:
self.load_onnx_model()

if isinstance(pairs, tuple):
pairs = [pairs]
is_small = True
Expand All @@ -111,7 +108,7 @@ def _rerank_pairs(
if not hasattr(self, "model") or self.model is None:
self.load_onnx_model()
for batch in iter_batch(pairs, batch_size):
yield from self.onnx_embed_pairs(batch, **kwargs).model_output
yield from self._post_process_onnx_output(self.onnx_embed_pairs(batch, **kwargs))
else:
if parallel == 0:
parallel = os.cpu_count()
Expand All @@ -138,7 +135,7 @@ def _rerank_pairs(
self.onnx_embed_pairs(batch, **kwargs)
)

def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[T]:
def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[float]:
return output.model_output

def _preprocess_onnx_input(
Expand Down
4 changes: 2 additions & 2 deletions fastembed/rerank/cross_encoder/text_cross_encoder_base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Iterable, Optional
from typing import Iterable, Optional, Any

from fastembed.common.model_management import ModelManagement

Expand All @@ -23,7 +23,7 @@ def rerank(
batch_size: int = 64,
**kwargs,
) -> Iterable[float]:
"""Reranks a list of documents given a query.
"""Rerank a list of documents given a query.

Args:
query (str): The query to rerank the documents.
Expand Down
Loading