Skip to content

Commit addc714

Browse files
committed
fix: Fix typical_p value edge case causing tensor OOB errors
In TypicalLogitsWarper we need to ensure that last_ind is a valid index for the vocab size. In low precision situations, it can end up that the cumulative mass covers the entire vocab even when it is < 1 (specifically observed for typical_p = 0.99 with bfloat16). Also clean up patched version of TopPLogitsWarper now that the fix is in transformers.
1 parent 7c539a8 commit addc714

File tree

2 files changed

+7
-47
lines changed

2 files changed

+7
-47
lines changed

router/src/queue.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -411,9 +411,8 @@ impl From<&GenerateParameters> for NextTokenChooserParameters {
411411
theta => Some(theta),
412412
},
413413
length_penalty: parameters.length_penalty
414-
.map(|lp| LengthPenalty {
415-
start_index: lp.0,
416-
decay_factor: lp.1,
414+
.map(|(start_index, decay_factor)| LengthPenalty {
415+
start_index, decay_factor
417416
}),
418417
}
419418
}

server/text_generation_server/utils/logits_process.py

Lines changed: 5 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
LogitsProcessor,
1212
TemperatureLogitsWarper,
1313
TopKLogitsWarper,
14-
#TopPLogitsWarper,
14+
TopPLogitsWarper,
1515
#TypicalLogitsWarper,
1616
)
1717

@@ -419,46 +419,7 @@ def filter(self, indices):
419419
return None
420420

421421

422-
423-
424-
# These are fixed versions of the classes in transformers. Remove them after upgrading to transformers >= 4.30.
425-
# See https://github.com/huggingface/transformers/pull/24111
426-
class TopPLogitsWarper(LogitsWarper):
427-
"""
428-
[`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
429-
430-
Args:
431-
top_p (`float`):
432-
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
433-
higher are kept for generation.
434-
filter_value (`float`, *optional*, defaults to `-float("Inf")`):
435-
All filtered values will be set to this float value.
436-
min_tokens_to_keep (`int`, *optional*, defaults to 1):
437-
Minimum number of tokens that cannot be filtered.
438-
"""
439-
440-
def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
441-
top_p = float(top_p)
442-
if top_p < 0 or top_p > 1.0:
443-
raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")
444-
445-
self.top_p = top_p
446-
self.filter_value = filter_value
447-
self.min_tokens_to_keep = min_tokens_to_keep
448-
449-
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
450-
sorted_logits, sorted_indices = torch.sort(scores, descending=False)
451-
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
452-
453-
# Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
454-
sorted_indices_to_remove = cumulative_probs <= (1 - self.top_p)
455-
sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
456-
457-
# scatter sorted tensors to original indexing
458-
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
459-
scores = scores.masked_fill(indices_to_remove, self.filter_value)
460-
return scores
461-
422+
# This is a fixed version of the class in transformers. Can be moved once we contribute back the fix and upgrade.
462423
class TypicalLogitsWarper(LogitsWarper):
463424
r"""
464425
[`LogitsWarper`] that performs typical decoding. See [Typical Decoding for Natural Language
@@ -475,7 +436,7 @@ class TypicalLogitsWarper(LogitsWarper):
475436

476437
def __init__(self, mass: float = 0.9, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
477438
mass = float(mass)
478-
if not (mass > 0 and mass < 1):
439+
if not (0 < mass < 1):
479440
raise ValueError(f"`typical_p` has to be a float > 0 and < 1, but is {mass}")
480441

481442
self.filter_value = filter_value
@@ -496,11 +457,11 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
496457

497458
# Remove tokens with cumulative mass above the threshold
498459
last_ind = (cumulative_probs < self.mass).sum(dim=1)
499-
last_ind[last_ind < 0] = 0
460+
last_ind.clamp_(0, sorted_scores.shape[-1] - 1)
500461
sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
501462
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
502463
sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
503464
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
504465

505466
scores = scores.masked_fill(indices_to_remove, self.filter_value)
506-
return scores
467+
return scores

0 commit comments

Comments
 (0)