
Commit ea85d72

Calculate the length penalty in the same way as the transformers library (#75)
Signed-off-by: Max de Bayser <[email protected]>
Parent: ce6d68b

File tree: 1 file changed (+10 −6)

server/text_generation_server/utils/tokens.py

Lines changed: 10 additions & 6 deletions
@@ -88,9 +88,11 @@ def _process_logits(self, input_ids, scores):
         elif self.length_penalty is not None:
             tokens_past = self.current_tokens - self.length_penalty[0]
             if tokens_past > 0:
-                scores[:, self.eos_token_id] = scores[:, self.eos_token_id] * pow(
-                    self.length_penalty[1], tokens_past
-                )
+                eos_scores = scores[:, self.eos_token_id]
+                # To support negative logits we compute the penalty of the
+                # absolute value and add to the original logit
+                scores[:, self.eos_token_id] = eos_scores + torch.abs(eos_scores) * (
+                    pow(self.length_penalty[1], tokens_past) - 1)
         self.current_tokens += 1

         # Apply repetition penalty if applicable
@@ -246,9 +248,11 @@ def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor):
         elif length_penalty is not None:
             tokens_past = current_tokens - length_penalty[0]
             if tokens_past > 0:
-                scores[idx, self.eos_token_id] = scores[idx, self.eos_token_id] * pow(
-                    length_penalty[1], tokens_past
-                )
+                eos_scores = scores[idx, self.eos_token_id]
+                # To support negative logits we compute the penalty of the
+                # absolute value and add to the original logit
+                scores[idx, self.eos_token_id] = eos_scores + torch.abs(eos_scores) * (
+                    pow(length_penalty[1], tokens_past) - 1)
             self.current_tokens[idx] += 1
             # Apply the repetition penalty if we have one
             if self.repetition_processor is not None:
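
Why the new formula: with the old multiplicative form, a penalty greater than 1 boosts the EOS logit only when that logit is positive; if the logit is negative, multiplication drives it further negative and actually suppresses EOS. The new form adds abs(logit) * (penalty^tokens_past − 1) to the original logit, so the adjustment moves EOS upward regardless of sign. A minimal illustrative sketch (not part of the commit; the helper names and values below are made up):

import torch

def old_penalty(eos_logit, penalty, tokens_past):
    # Old behaviour: plain multiplication. For eos_logit < 0 and
    # penalty > 1 this makes the logit *more* negative.
    return eos_logit * pow(penalty, tokens_past)

def new_penalty(eos_logit, penalty, tokens_past):
    # New behaviour (matches the diff): scale the magnitude and add it
    # back, so the logit always moves upward when penalty > 1.
    return eos_logit + torch.abs(eos_logit) * (pow(penalty, tokens_past) - 1)

pos, neg = torch.tensor(2.0), torch.tensor(-2.0)
penalty, tokens_past = 1.5, 4  # pow(1.5, 4) = 5.0625

print(old_penalty(pos, penalty, tokens_past))  # tensor(10.1250) — EOS boosted
print(old_penalty(neg, penalty, tokens_past))  # tensor(-10.1250) — EOS suppressed!
print(new_penalty(pos, penalty, tokens_past))  # tensor(10.1250) — EOS boosted
print(new_penalty(neg, penalty, tokens_past))  # tensor(6.1250) — EOS still boosted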
