Commit e87557b

Support Min P Sampler (#1642)
1 parent dcc543a · commit e87557b

File tree

2 files changed: +40 / -4 lines

vllm/model_executor/layers/sampler.py

Lines changed: 31 additions & 4 deletions
@@ -71,13 +71,18 @@ def forward(
             logits.div_(t.unsqueeze(dim=1))
 
         # Apply top-p and top-k truncation.
-        top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
+        top_ps, top_ks, min_ps = _get_top_p_top_k_min_p(
+            input_metadata, self.vocab_size)
         assert len(top_ps) == len(top_ks) == logits.shape[0]
         do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
         do_top_k = any(k != self.vocab_size for k in top_ks)
         if do_top_p or do_top_k:
             logits = _apply_top_p_top_k(logits, top_ps, top_ks)
 
+        do_min_p = any(mp > _SAMPLING_EPS for mp in min_ps)
+        if do_min_p:
+            logits = _apply_min_p(logits, min_ps)
+
         # We use float32 for probabilities and log probabilities.
         # Compute the probabilities.
         probs = torch.softmax(logits, dim=-1, dtype=torch.float)
@@ -261,15 +266,17 @@ def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
     return temperatures
 
 
-def _get_top_p_top_k(
+def _get_top_p_top_k_min_p(
     input_metadata: InputMetadata,
     vocab_size: int,
-) -> Tuple[List[float], List[int]]:
+) -> Tuple[List[float], List[int], List[float]]:
     top_ps: List[float] = []
     top_ks: List[int] = []
+    min_ps: List[float] = []
     for i, seq_group in enumerate(input_metadata.seq_groups):
         seq_ids, sampling_params = seq_group
         top_p = sampling_params.top_p
+        min_p = sampling_params.min_p
         # k should not be greater than the vocab size.
         top_k = min(sampling_params.top_k, vocab_size)
         # k=-1 means no truncation.
@@ -279,9 +286,11 @@ def _get_top_p_top_k(
             prompt_len = input_metadata.prompt_lens[i]
             top_ps += [top_p] * (prompt_len - 1)
             top_ks += [top_k] * (prompt_len - 1)
+            min_ps += [min_p] * (prompt_len - 1)
         top_ps += [top_p] * len(seq_ids)
         top_ks += [top_k] * len(seq_ids)
-    return top_ps, top_ks
+        min_ps += [min_p] * len(seq_ids)
+    return top_ps, top_ks, min_ps
 
 
 def _apply_top_p_top_k(
@@ -313,6 +322,24 @@ def _apply_top_p_top_k(
     return logits
 
 
+def _apply_min_p(
+    logits: torch.Tensor,
+    min_ps: List[float],
+) -> torch.Tensor:
+    """
+    Adapted from
+    https://github.com/oobabooga/text-generation-webui/blob/3146124ec01f02c8fb1650a6517cf1b60b537aaf/modules/sampler_hijack.py#L16C17-L16C17
+    """
+    min_p = torch.tensor(min_ps, dtype=logits.dtype, device=logits.device)
+    probs = torch.softmax(logits, dim=-1)
+    top_probs, _ = probs.max(dim=-1, keepdim=True)
+    scaled_min_p = min_p.unsqueeze(dim=1) * top_probs
+    tokens_to_remove = probs < scaled_min_p
+    logits = logits.masked_fill(tokens_to_remove, -float("inf"))
+
+    return logits
+
+
 def _greedy_sample(
     selected_seq_groups: List[Tuple[List[int], SamplingParams]],
     logprobs: torch.Tensor,
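
The new _apply_min_p helper keeps a token only if its probability is at least min_p times the probability of the most likely token in that row. A small standalone sketch of the same rule on toy logits (not part of the commit; the numbers are made up for illustration):

import torch

# One sequence with four candidate tokens; min_p = 0.1 for that sequence.
logits = torch.tensor([[2.0, 1.0, 0.0, -3.0]])
min_ps = [0.1]

min_p = torch.tensor(min_ps, dtype=logits.dtype, device=logits.device)
probs = torch.softmax(logits, dim=-1)              # ~[0.66, 0.24, 0.09, 0.004]
top_probs, _ = probs.max(dim=-1, keepdim=True)     # ~0.66
scaled_min_p = min_p.unsqueeze(dim=1) * top_probs  # per-row threshold ~0.066
masked = logits.masked_fill(probs < scaled_min_p, -float("inf"))
print(masked)  # only the last token (prob ~0.004 < 0.066) is masked to -inf

Because the threshold scales with the top token's probability, confident distributions prune aggressively while flat distributions keep more candidates.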

vllm/sampling_params.py

Lines changed: 9 additions & 0 deletions
@@ -52,6 +52,9 @@ class SamplingParams:
            to consider. Must be in (0, 1]. Set to 1 to consider all tokens.
        top_k: Integer that controls the number of top tokens to consider. Set
            to -1 to consider all tokens.
+        min_p: Float that represents the minimum probability for a token to be
+            considered, relative to the probability of the most likely token.
+            Must be in [0, 1]. Set to 0 to disable this.
        use_beam_search: Whether to use beam search instead of sampling.
        length_penalty: Float that penalizes sequences based on their length.
            Used in beam search.
@@ -94,6 +97,7 @@ def __init__(
         temperature: float = 1.0,
         top_p: float = 1.0,
         top_k: int = -1,
+        min_p: float = 0.0,
         use_beam_search: bool = False,
         length_penalty: float = 1.0,
         early_stopping: Union[bool, str] = False,
@@ -115,6 +119,7 @@ def __init__(
         self.temperature = temperature
         self.top_p = top_p
         self.top_k = top_k
+        self.min_p = min_p
         self.use_beam_search = use_beam_search
         self.length_penalty = length_penalty
         self.early_stopping = early_stopping
@@ -167,6 +172,9 @@ def _verify_args(self) -> None:
         if self.top_k < -1 or self.top_k == 0:
             raise ValueError(f"top_k must be -1 (disable), or at least 1, "
                              f"got {self.top_k}.")
+        if not 0.0 <= self.min_p <= 1.0:
+            raise ValueError("min_p must be in [0, 1], got "
+                             f"{self.min_p}.")
         if self.max_tokens < 1:
             raise ValueError(
                 f"max_tokens must be at least 1, got {self.max_tokens}.")
@@ -228,6 +236,7 @@ def __repr__(self) -> str:
                 f"temperature={self.temperature}, "
                 f"top_p={self.top_p}, "
                 f"top_k={self.top_k}, "
+                f"min_p={self.min_p}, "
                 f"use_beam_search={self.use_beam_search}, "
                 f"length_penalty={self.length_penalty}, "
                 f"early_stopping={self.early_stopping}, "
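
With these changes, min_p can be passed like any other sampling knob. A minimal usage sketch against the public vLLM API (the model name is only a placeholder):

from vllm import LLM, SamplingParams

# Keep only tokens whose probability is at least 5% of the top token's probability.
params = SamplingParams(temperature=0.8, min_p=0.05)

llm = LLM(model="facebook/opt-125m")  # placeholder model
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)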
