1818
1919
2020class SentenceTransformerRerank (BaseNodePostprocessor ):
21+ """
22+ HuggingFace class for cross encoding two sentences/texts.
23+
24+ Args:
25+ model (str): A model name from Hugging Face Hub that can be loaded with AutoModel, or a path to a local model.
26+ device (str, optional): Device (like “cuda”, “cpu”, “mps”, “npu”) that should be used for computation.
27+ If None, checks if a GPU can be used.
28+ cache_folder (str, Path, optional): Path to the folder where cached files are stored. Defaults to None.
29+ top_n (int): Number of nodes to return sorted by score. Defaults to 2.
30+ keep_retrieval_score (bool, optional): Whether to keep the retrieval score in metadata. Defaults to False.
31+ cross_encoder_kwargs (dict, optional): Additional keyword arguments for CrossEncoder initialization. Defaults to None.
32+
33+ """
34+
2135 model : str = Field (description = "Sentence transformer model name." )
2236 top_n : int = Field (description = "Number of nodes to return sorted by score." )
23- device : str = Field (
24- default = "cpu" ,
25- description = "Device to use for sentence transformer." ,
26- )
2737 keep_retrieval_score : bool = Field (
2838 default = False ,
2939 description = "Whether to keep the retrieval score in metadata." ,
@@ -34,14 +44,15 @@ class SentenceTransformerRerank(BaseNodePostprocessor):
3444 "device and model should not be included here." ,
3545 )
3646 _model : Any = PrivateAttr ()
47+ _device : str = PrivateAttr ()
3748
3849 def __init__ (
3950 self ,
40- top_n : int = 2 ,
4151 model : str = "cross-encoder/stsb-distilroberta-base" ,
4252 device : Optional [str ] = None ,
53+ cache_folder : Optional [Union [str , Path ]] = None ,
54+ top_n : int = 2 ,
4355 keep_retrieval_score : Optional [bool ] = False ,
44- cache_dir : Optional [Union [str , Path ]] = None ,
4556 cross_encoder_kwargs : Optional [dict ] = None ,
4657 ):
4758 try :
@@ -74,11 +85,13 @@ def __init__(
7485 # Explicit arguments from the constructor take precedence over kwargs
7586 resolved_device = infer_torch_device () if device is None else device
7687 init_kwargs ["device" ] = resolved_device
77- if cache_dir :
78- init_kwargs ["cache_dir" ] = cache_dir
88+ self ._device = resolved_device
89+
90+ if cache_folder :
91+ init_kwargs ["cache_folder" ] = cache_folder
7992
8093 self ._model = CrossEncoder (
81- model_name = model ,
94+ model_name_or_path = model ,
8295 ** init_kwargs ,
8396 )
8497
0 commit comments