@@ -10,38 +10,79 @@ class LMScorer(ABC):
1010 def __init__ (self , model_name : str , ** kwargs : Any ) -> None :
1111 self ._build (model_name , kwargs )
1212
13+ @overload
1314 def sentence_score (
1415 self , text : str , log : bool = False , reduce : str = "prod"
1516 ) -> float :
16- log_probs , _ , _ = self ._tokens_log_prob (text )
17- tlen = log_probs .shape [0 ]
17+ ...
1818
19- if reduce == "prod" :
20- score = log_probs .sum ()
21- elif reduce == "mean" :
22- score = log_probs .logsumexp (0 ) - math .log (tlen )
23- elif reduce == "gmean" :
24- score = log_probs .mean (0 )
25- elif reduce == "hmean" :
26- score = log_probs .neg ().logsumexp (0 ).neg () + math .log (tlen )
27- else :
28- raise ValueError ("Unrecognized scoring strategy: %s" % reduce )
19+ @overload
20+ def sentence_score (
21+ self , text : List [str ], log : bool = False , reduce : str = "prod"
22+ ) -> List [float ]:
23+ ...
24+
25+ def sentence_score (
26+ self , text : Union [str , List [str ]], log : bool = False , reduce : str = "prod" ,
27+ ) -> Union [float , List [float ]]:
28+ sentences = [text ] if isinstance (text , str ) else text
29+ if len (sentences ) == 0 :
30+ return []
31+
32+ outputs = self ._tokens_log_prob (sentences )
33+
34+ scores = []
35+ for output in outputs :
36+ log_probs = output [0 ]
37+ tlen = log_probs .shape [0 ]
2938
30- if not log :
31- score = score .exp ()
39+ if reduce == "prod" :
40+ score = log_probs .sum ()
41+ elif reduce == "mean" :
42+ score = log_probs .logsumexp (0 ) - math .log (tlen )
43+ elif reduce == "gmean" :
44+ score = log_probs .mean (0 )
45+ elif reduce == "hmean" :
46+ score = log_probs .neg ().logsumexp (0 ).neg () + math .log (tlen )
47+ else :
48+ raise ValueError ("Unrecognized scoring strategy: %s" % reduce )
3249
33- return score .item ()
50+ if not log :
51+ score = score .exp ()
3452
53+ scores .append (score .item ())
54+
55+ return scores [0 ] if isinstance (text , str ) else scores
56+
57+ @overload
3558 def tokens_score (
3659 self , text : str , log : bool = False
3760 ) -> Tuple [List [float ], List [int ], List [str ]]:
38- log_probs , ids , tokens = self . _tokens_log_prob ( text )
61+ ...
3962
40- scores = log_probs # type: torch.Tensor # type: ignore
41- if not log :
42- scores = scores .exp ()
63+ @overload
64+ def tokens_score (
65+ self , text : List [str ], log : bool = False
66+ ) -> List [Tuple [List [float ], List [int ], List [str ]]]:
67+ ...
68+
69+ def tokens_score (
70+ self , text : Union [str , List [str ]], log : bool = False
71+ ) -> Union [
72+ Tuple [List [float ], List [int ], List [str ]],
73+ List [Tuple [List [float ], List [int ], List [str ]]],
74+ ]:
75+ sentences = [text ] if isinstance (text , str ) else text
76+ if len (sentences ) == 0 :
77+ return []
78+ outputs = []
79+ for log_probs , ids , tokens in self ._tokens_log_prob (sentences ):
80+ scores = log_probs # type: torch.Tensor # type: ignore
81+ if not log :
82+ scores = scores .exp ()
83+ outputs .append ((scores .tolist (), ids .tolist (), tokens ))
4384
44- return scores . tolist (), ids . tolist (), tokens
85+ return outputs [ 0 ] if isinstance ( text , str ) else outputs
4586
4687 @classmethod
4788 def supported_model_names (cls ) -> Iterable [str ]:
@@ -53,8 +94,8 @@ def _build(self, model_name: str, options: Dict[str, Any]) -> None:
5394
5495 @abstractmethod
5596 def _tokens_log_prob (
56- self , text : str
57- ) -> Tuple [torch .FloatTensor , torch .LongTensor , List [str ]]:
97+ self , text : List [ str ]
98+ ) -> List [ Tuple [torch .FloatTensor , torch .LongTensor , List [str ] ]]:
5899 ... # pragma: no cover
59100
60101 @classmethod
0 commit comments