Skip to content

Commit 5dc7374

Browse files
authored
fix: fix SentenceTransformerRerank init (#19756)
* fix: fix SentenceTransformerRerank init * style: fixed linting
1 parent 8e9395c commit 5dc7374

File tree

4 files changed

+659
-16
lines changed

4 files changed

+659
-16
lines changed

llama-index-integrations/postprocessor/llama-index-postprocessor-sbert-rerank/llama_index/postprocessor/sbert_rerank/base.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,22 @@
1818

1919

2020
class SentenceTransformerRerank(BaseNodePostprocessor):
21+
"""
22+
HuggingFace class for cross encoding two sentences/texts.
23+
24+
Args:
25+
model (str): A model name from Hugging Face Hub that can be loaded with AutoModel, or a path to a local model.
26+
device (str, optional): Device (like “cuda”, “cpu”, “mps”, “npu”) that should be used for computation.
27+
If None, checks if a GPU can be used.
28+
cache_folder (str, Path, optional): Path to the folder where cached files are stored. Defaults to None.
29+
top_n (int): Number of nodes to return sorted by score. Defaults to 2.
30+
keep_retrieval_score (bool, optional): Whether to keep the retrieval score in metadata. Defaults to False.
31+
cross_encoder_kwargs (dict, optional): Additional keyword arguments for CrossEncoder initialization. Defaults to None.
32+
33+
"""
34+
2135
model: str = Field(description="Sentence transformer model name.")
2236
top_n: int = Field(description="Number of nodes to return sorted by score.")
23-
device: str = Field(
24-
default="cpu",
25-
description="Device to use for sentence transformer.",
26-
)
2737
keep_retrieval_score: bool = Field(
2838
default=False,
2939
description="Whether to keep the retrieval score in metadata.",
@@ -34,14 +44,15 @@ class SentenceTransformerRerank(BaseNodePostprocessor):
3444
"device and model should not be included here.",
3545
)
3646
_model: Any = PrivateAttr()
47+
_device: str = PrivateAttr()
3748

3849
def __init__(
3950
self,
40-
top_n: int = 2,
4151
model: str = "cross-encoder/stsb-distilroberta-base",
4252
device: Optional[str] = None,
53+
cache_folder: Optional[Union[str, Path]] = None,
54+
top_n: int = 2,
4355
keep_retrieval_score: Optional[bool] = False,
44-
cache_dir: Optional[Union[str, Path]] = None,
4556
cross_encoder_kwargs: Optional[dict] = None,
4657
):
4758
try:
@@ -74,11 +85,13 @@ def __init__(
7485
# Explicit arguments from the constructor take precedence over kwargs
7586
resolved_device = infer_torch_device() if device is None else device
7687
init_kwargs["device"] = resolved_device
77-
if cache_dir:
78-
init_kwargs["cache_dir"] = cache_dir
88+
self._device = resolved_device
89+
90+
if cache_folder:
91+
init_kwargs["cache_folder"] = cache_folder
7992

8093
self._model = CrossEncoder(
81-
model_name=model,
94+
model_name_or_path=model,
8295
**init_kwargs,
8396
)
8497

llama-index-integrations/postprocessor/llama-index-postprocessor-sbert-rerank/pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,12 @@ dev = [
2222
"codespell[toml]>=v2.2.6",
2323
"diff-cover>=9.2.0",
2424
"pytest-cov>=6.1.1",
25+
"sentence-transformers>=5.1.0",
2526
]
2627

2728
[project]
2829
name = "llama-index-postprocessor-sbert-rerank"
29-
version = "0.4.0"
30+
version = "0.4.1"
3031
description = "llama-index postprocessor sbert rerank integration"
3132
authors = [{name = "Your Name", email = "[email protected]"}]
3233
requires-python = ">=3.9,<4.0"
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,17 @@
11
from llama_index.core.postprocessor.types import BaseNodePostprocessor
22
from llama_index.postprocessor.sbert_rerank import SentenceTransformerRerank
3+
from llama_index.core.utils import infer_torch_device
34

45

56
def test_class():
67
names_of_base_classes = [b.__name__ for b in SentenceTransformerRerank.__mro__]
78
assert BaseNodePostprocessor.__name__ in names_of_base_classes
9+
10+
11+
def test_init():
12+
assert SentenceTransformerRerank()
13+
14+
15+
def test_device():
16+
device = infer_torch_device() or "cpu"
17+
assert SentenceTransformerRerank()._device == device

0 commit comments

Comments
 (0)