@@ -88,9 +88,11 @@ def _process_logits(self, input_ids, scores):
         elif self.length_penalty is not None:
             tokens_past = self.current_tokens - self.length_penalty[0]
             if tokens_past > 0:
-                scores[:, self.eos_token_id] = scores[:, self.eos_token_id] * pow(
-                    self.length_penalty[1], tokens_past
-                )
+                eos_scores = scores[:, self.eos_token_id]
+                # To support negative logits we compute the penalty of the
+                # absolute value and add to the original logit
+                scores[:, self.eos_token_id] = eos_scores + torch.abs(eos_scores) * (
+                    pow(self.length_penalty[1], tokens_past) - 1)
         self.current_tokens += 1

         # Apply repetition penalty if applicable
@@ -246,9 +248,11 @@ def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor):
             elif length_penalty is not None:
                 tokens_past = current_tokens - length_penalty[0]
                 if tokens_past > 0:
-                    scores[idx, self.eos_token_id] = scores[idx, self.eos_token_id] * pow(
-                        length_penalty[1], tokens_past
-                    )
+                    eos_scores = scores[idx, self.eos_token_id]
+                    # To support negative logits we compute the penalty of the
+                    # absolute value and add to the original logit
+                    scores[idx, self.eos_token_id] = eos_scores + torch.abs(eos_scores) * (
+                        pow(length_penalty[1], tokens_past) - 1)
             self.current_tokens[idx] += 1
             # Apply the repetition penalty if we have one
             if self.repetition_processor is not None:
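
Why the change matters: multiplying the EOS logit by `pow(length_penalty[1], tokens_past)` only pushes it up when the logit is positive; a negative logit gets pushed further below zero, making EOS *less* likely the longer generation runs past the target length. Adding `torch.abs(eos_scores) * (pow(...) - 1)` instead always adds a nonnegative boost when the decay factor is at least 1, and it reduces to the old formula for positive logits. A minimal standalone sketch of both behaviours; the helper names here are made up for illustration, only the two formulas come from this diff:

```python
import torch

def old_penalty(eos_scores: torch.Tensor, decay: float, tokens_past: int) -> torch.Tensor:
    # Original behaviour: multiply the EOS logit directly.
    return eos_scores * pow(decay, tokens_past)

def new_penalty(eos_scores: torch.Tensor, decay: float, tokens_past: int) -> torch.Tensor:
    # Patched behaviour: compute the penalty on the absolute value and
    # add it to the original logit, so negative logits still move up.
    return eos_scores + torch.abs(eos_scores) * (pow(decay, tokens_past) - 1)

eos = torch.tensor([2.0, -2.0])  # one positive and one negative EOS logit
decay, past = 1.5, 2             # e.g. length_penalty = (start, 1.5), 2 tokens past start

print(old_penalty(eos, decay, past))  # tensor([ 4.5000, -4.5000]) -- negative logit sinks
print(new_penalty(eos, decay, past))  # tensor([ 4.5000,  0.5000]) -- both logits rise
```

For a positive logit the two formulas agree exactly, since `s * d**t == s + abs(s) * (d**t - 1)` when `s >= 0`, so the patch only changes behaviour in the negative-logit case it is meant to fix.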