Skip to content

Commit 5531890

Browse files
authored
Support multiple sentences as input (#4)
1 parent cf0057b commit 5531890

File tree

3 files changed

+87
-23
lines changed

3 files changed

+87
-23
lines changed

lm_scorer/models/abc/base.py

Lines changed: 63 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -10,38 +10,79 @@ class LMScorer(ABC):
1010
def __init__(self, model_name: str, **kwargs: Any) -> None:
    """Create a scorer for the given model.

    Args:
        model_name: Name of the language model to load.
        **kwargs: Extra options forwarded verbatim to the subclass
            `_build` hook (e.g. device selection — see subclasses).
    """
    # All construction work is delegated to the subclass-specific hook.
    self._build(model_name, kwargs)
1212

13+
@overload
def sentence_score(
    self, text: str, log: bool = False, reduce: str = "prod"
) -> float:
    ...

@overload
def sentence_score(
    self, text: List[str], log: bool = False, reduce: str = "prod"
) -> List[float]:
    ...

def sentence_score(
    self, text: Union[str, List[str]], log: bool = False, reduce: str = "prod",
) -> Union[float, List[float]]:
    """Reduce each sentence's token log-probabilities to a single score.

    Args:
        text: One sentence, or a list of sentences.
        log: When True, return log-probabilities instead of probabilities.
        reduce: Reduction strategy over the token log-probs — "prod"
            (joint probability), "mean" (log of the arithmetic mean),
            "gmean" (log of the geometric mean) or "hmean" (log of the
            harmonic mean).

    Returns:
        A single float when `text` is a string, otherwise one float per
        sentence. An empty input list yields an empty list.

    Raises:
        ValueError: If `reduce` names an unknown strategy.
    """
    single_input = isinstance(text, str)
    sentences = [text] if single_input else text
    if not sentences:
        return []

    def reduce_log_probs(log_probs) -> float:
        # Number of scored tokens in this sentence.
        tlen = log_probs.shape[0]
        if reduce == "prod":
            reduced = log_probs.sum()
        elif reduce == "mean":
            # log of the arithmetic mean of the token probabilities.
            reduced = log_probs.logsumexp(0) - math.log(tlen)
        elif reduce == "gmean":
            # log of the geometric mean of the token probabilities.
            reduced = log_probs.mean(0)
        elif reduce == "hmean":
            # log of the harmonic mean of the token probabilities.
            reduced = log_probs.neg().logsumexp(0).neg() + math.log(tlen)
        else:
            raise ValueError("Unrecognized scoring strategy: %s" % reduce)
        # Exponentiate unless the caller asked for the log score.
        return (reduced if log else reduced.exp()).item()

    scores = [
        reduce_log_probs(output[0]) for output in self._tokens_log_prob(sentences)
    ]
    return scores[0] if single_input else scores
56+
57+
@overload
def tokens_score(
    self, text: str, log: bool = False
) -> Tuple[List[float], List[int], List[str]]:
    ...

@overload
def tokens_score(
    self, text: List[str], log: bool = False
) -> List[Tuple[List[float], List[int], List[str]]]:
    ...

def tokens_score(
    self, text: Union[str, List[str]], log: bool = False
) -> Union[
    Tuple[List[float], List[int], List[str]],
    List[Tuple[List[float], List[int], List[str]]],
]:
    """Return per-token scores, ids and token strings for each sentence.

    Args:
        text: One sentence, or a list of sentences.
        log: When True, token scores are log-probabilities; otherwise
            they are probabilities.

    Returns:
        A `(scores, ids, tokens)` tuple when `text` is a string,
        otherwise one such tuple per sentence. An empty input list
        yields an empty list.
    """
    single_input = isinstance(text, str)
    sentences = [text] if single_input else text
    if not sentences:
        return []

    results = [
        (
            # Keep log-probs as-is, or exponentiate into probabilities.
            (log_probs if log else log_probs.exp()).tolist(),
            ids.tolist(),
            tokens,
        )
        for log_probs, ids, tokens in self._tokens_log_prob(sentences)
    ]
    return results[0] if single_input else results
4586

4687
@classmethod
4788
def supported_model_names(cls) -> Iterable[str]:
@@ -53,8 +94,8 @@ def _build(self, model_name: str, options: Dict[str, Any]) -> None:
5394

5495
@abstractmethod
def _tokens_log_prob(
    self, text: List[str]
) -> List[Tuple[torch.FloatTensor, torch.LongTensor, List[str]]]:
    """Compute per-token log-probabilities for each sentence.

    Subclass hook used by `sentence_score` and `tokens_score`.

    Args:
        text: A list of sentences to score.

    Returns:
        One `(log_probs, ids, tokens)` tuple per input sentence, where
        `log_probs` holds the token log-probabilities, `ids` the token
        ids and `tokens` the token strings.
    """
    ...  # pragma: no cover
59100

60101
@classmethod

lm_scorer/models/gpt2.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def _build(self, model_name: str, options: Dict[str, Any]) -> None:
1919
self.model.to(options["device"])
2020

2121
# @overrides
22-
def _tokens_log_prob(
22+
def _tokens_log_prob_single_sentence(
2323
self, text: str
2424
) -> Tuple[torch.FloatTensor, torch.LongTensor, List[str]]:
2525
device = self.model.device
@@ -57,6 +57,11 @@ def _tokens_log_prob(
5757

5858
return log_probs[0], ids[0], tokens # type: ignore
5959

60+
def _tokens_log_prob(
61+
self, text: List[str]
62+
) -> List[Tuple[torch.FloatTensor, torch.LongTensor, List[str]]]:
63+
return list(map(self._tokens_log_prob_single_sentence, text))
64+
6065
# @overrides
6166
@classmethod
6267
def _supported_model_names(cls) -> Iterable[str]:

tests/models/test_gpt2.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,21 @@ def should_work_on_an_empty_sentence():
5151
score = scorer.sentence_score("", reduce="hmean", log=True)
5252
assert score <= 0.0
5353

54+
def should_work_on_an_empty_list():
    # No sentences in — no scores out.
    empty_scores = scorer.sentence_score([])
    assert empty_scores == []
56+
57+
def should_give_same_results_independently_of_input_type():
    # Batch scoring must match scoring each sentence on its own.
    sentences = [
        "I have a big amount of money.",
        "This is the best day of my life.",
        "I think this game is easier than the one we played yesterday.",
    ]

    sentences_scores = scorer.sentence_score(sentences)

    for sentence, expected_score in zip(sentences, sentences_scores):
        assert scorer.sentence_score(sentence) == expected_score
68+
5469
# TODO: Test the various reducing strategies by mocking the _tokens_log_prob call.
5570

5671

@@ -70,6 +85,9 @@ def should_work_on_an_empty_sentence():
7085
assert len(tokens) == 1, tokens
7186
assert scores[0] <= 0.0
7287

88+
def should_work_on_an_empty_list():
    # No sentences in — no per-token outputs out.
    empty_outputs = scorer.tokens_score([])
    assert empty_outputs == []
90+
7391

7492
def describe_sentence_score_for_english():
7593
scorer = GPT2LMScorer("gpt2")

0 commit comments

Comments
 (0)