     LogitsProcessor,
     TemperatureLogitsWarper,
     TopKLogitsWarper,
-    # TopPLogitsWarper,
+    TopPLogitsWarper,
     #TypicalLogitsWarper,
 )
 
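(Aside, not part of the diff: with TopPLogitsWarper now imported from transformers rather than redefined locally, the warpers above can be chained through LogitsProcessorList. A minimal sketch, with made-up logits and hypothetical parameter values:)

import torch
from transformers import (
    LogitsProcessorList,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
)

# Chain the warpers in the usual order: temperature, then top-k, then top-p.
warpers = LogitsProcessorList(
    [
        TemperatureLogitsWarper(0.7),
        TopKLogitsWarper(top_k=50),
        TopPLogitsWarper(top_p=0.9),
    ]
)

input_ids = torch.tensor([[1, 2, 3]])  # dummy prompt token ids
scores = torch.randn(1, 32000)         # dummy next-token logits (batch=1, vocab=32000)
warped = warpers(input_ids, scores)    # filtered logits; dropped tokens are set to -inf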
@@ -419,46 +419,7 @@ def filter(self, indices):
         return None
 
 
-
-
-# These are fixed versions of the classese in transformers. Remove them after upgrading to transformers >= 4.30.
-# See https://github.com/huggingface/transformers/pull/24111
-class TopPLogitsWarper(LogitsWarper):
-    """
-    [`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
-
-    Args:
-        top_p (`float`):
-            If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
-            higher are kept for generation.
-        filter_value (`float`, *optional*, defaults to `-float("Inf")`):
-            All filtered values will be set to this float value.
-        min_tokens_to_keep (`int`, *optional*, defaults to 1):
-            Minimum number of tokens that cannot be filtered.
-    """
-
-    def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
-        top_p = float(top_p)
-        if top_p < 0 or top_p > 1.0:
-            raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")
-
-        self.top_p = top_p
-        self.filter_value = filter_value
-        self.min_tokens_to_keep = min_tokens_to_keep
-
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-        sorted_logits, sorted_indices = torch.sort(scores, descending=False)
-        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
-
-        # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
-        sorted_indices_to_remove = cumulative_probs <= (1 - self.top_p)
-        sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
-
-        # scatter sorted tensors to original indexing
-        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
-        scores = scores.masked_fill(indices_to_remove, self.filter_value)
-        return scores
-
+# This is a fixed version of the class in transformers. Can be moved once we contribute back the fix and upgrade.
 class TypicalLogitsWarper(LogitsWarper):
     r"""
     [`LogitsWarper`] that performs typical decoding. See [Typical Decoding for Natural Language
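(Aside, not part of the diff: the nucleus filtering that the removed class, and its upstream replacement, performs can be traced on a toy example. Assumptions: a 5-token vocabulary and made-up logits.)

import torch

scores = torch.tensor([[2.0, 1.0, 0.5, 0.1, -1.0]])  # toy next-token logits
top_p = 0.9

# Sort ascending and accumulate probabilities from the low-probability tail upward.
sorted_logits, sorted_indices = torch.sort(scores, descending=False)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

# Drop tokens whose cumulative mass stays at or below 1 - top_p; always keep the
# highest-probability token (min_tokens_to_keep == 1).
sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
sorted_indices_to_remove[..., -1:] = False

# Map the mask back to the original token order and mask the filtered logits.
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
print(scores.masked_fill(indices_to_remove, -float("Inf")))
# -> only the lowest-probability token (logit -1.0) is masked to -inf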
@@ -475,7 +436,7 @@ class TypicalLogitsWarper(LogitsWarper):
 
     def __init__(self, mass: float = 0.9, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
         mass = float(mass)
-        if not (mass > 0 and mass < 1):
+        if not (0 < mass < 1):
             raise ValueError(f"`typical_p` has to be a float > 0 and < 1, but is {mass}")
 
         self.filter_value = filter_value
@@ -496,11 +457,11 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
 
         # Remove tokens with cumulative mass above the threshold
         last_ind = (cumulative_probs < self.mass).sum(dim=1)
-        last_ind[last_ind < 0] = 0
+        last_ind.clamp_(0, sorted_scores.shape[-1] - 1)
         sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
         # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
         sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
         indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
 
         scores = scores.masked_fill(indices_to_remove, self.filter_value)
-        return scores
+        return scores
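(Aside, not part of the diff: the clamp_ replaces a guard that only handled the lower bound. If every cumulative probability stays below mass, counting them yields an index equal to the vocabulary size, and the gather on the next line would then read out of range. A toy sketch with made-up tensors:)

import torch

sorted_scores = torch.tensor([[0.1, 0.3, 0.6]])        # toy sorted scores, vocab of 3
cumulative_probs = torch.tensor([[0.1, 0.4, 0.9995]])  # never reaches `mass` below
mass = 0.9999

last_ind = (cumulative_probs < mass).sum(dim=1)        # tensor([3]) == vocab size
last_ind.clamp_(0, sorted_scores.shape[-1] - 1)        # tensor([2]); gather is now safe
threshold = sorted_scores.gather(1, last_ind.view(-1, 1))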