@@ -71,13 +71,18 @@ def forward(
             logits.div_(t.unsqueeze(dim=1))

         # Apply top-p and top-k truncation.
-        top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
+        top_ps, top_ks, min_ps = _get_top_p_top_k_min_p(
+            input_metadata, self.vocab_size)
         assert len(top_ps) == len(top_ks) == logits.shape[0]
         do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
         do_top_k = any(k != self.vocab_size for k in top_ks)
         if do_top_p or do_top_k:
             logits = _apply_top_p_top_k(logits, top_ps, top_ks)

+        do_min_p = any(mp > _SAMPLING_EPS for mp in min_ps)
+        if do_min_p:
+            logits = _apply_min_p(logits, min_ps)
+
         # We use float32 for probabilities and log probabilities.
         # Compute the probabilities.
         probs = torch.softmax(logits, dim=-1, dtype=torch.float)
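Note on ordering (a sketch, not part of the change): min-p runs after temperature scaling and after any top-p/top-k truncation, so its cutoff is computed from the already-truncated distribution, right before the final softmax. The toy values below are made up for illustration.

```python
import torch

# Toy logits: one position over a 5-token vocab.
logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, -3.0]])
temperature, min_p = 0.8, 0.1

# 1) Temperature scaling (matches logits.div_ in forward()).
logits = logits / temperature

# 2) top-p / top-k truncation would run here.

# 3) Min-p: drop tokens whose probability is below min_p * top probability.
probs = torch.softmax(logits, dim=-1)
cutoff = min_p * probs.max(dim=-1, keepdim=True).values
logits = logits.masked_fill(probs < cutoff, -float("inf"))

# 4) Final probabilities used for sampling.
print(torch.softmax(logits, dim=-1, dtype=torch.float))
```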
@@ -261,15 +266,17 @@ def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
     return temperatures


-def _get_top_p_top_k(
+def _get_top_p_top_k_min_p(
     input_metadata: InputMetadata,
     vocab_size: int,
-) -> Tuple[List[float], List[int]]:
+) -> Tuple[List[float], List[int], List[float]]:
     top_ps: List[float] = []
     top_ks: List[int] = []
+    min_ps: List[float] = []
     for i, seq_group in enumerate(input_metadata.seq_groups):
         seq_ids, sampling_params = seq_group
         top_p = sampling_params.top_p
+        min_p = sampling_params.min_p
         # k should not be greater than the vocab size.
         top_k = min(sampling_params.top_k, vocab_size)
         # k=-1 means no truncation.
@@ -279,9 +286,11 @@ def _get_top_p_top_k(
             prompt_len = input_metadata.prompt_lens[i]
             top_ps += [top_p] * (prompt_len - 1)
             top_ks += [top_k] * (prompt_len - 1)
+            min_ps += [min_p] * (prompt_len - 1)
         top_ps += [top_p] * len(seq_ids)
         top_ks += [top_k] * len(seq_ids)
-    return top_ps, top_ks
+        min_ps += [min_p] * len(seq_ids)
+    return top_ps, top_ks, min_ps


 def _apply_top_p_top_k(
@@ -313,6 +322,24 @@ def _apply_top_p_top_k(
     return logits


+def _apply_min_p(
+    logits: torch.Tensor,
+    min_ps: List[float],
+) -> torch.Tensor:
+    """
+    Adapted from
+    https://github.com/oobabooga/text-generation-webui/blob/3146124ec01f02c8fb1650a6517cf1b60b537aaf/modules/sampler_hijack.py#L16C17-L16C17
+    """
+    min_p = torch.tensor(min_ps, dtype=logits.dtype, device=logits.device)
+    probs = torch.softmax(logits, dim=-1)
+    top_probs, _ = probs.max(dim=-1, keepdim=True)
+    scaled_min_p = min_p.unsqueeze(dim=1) * top_probs
+    tokens_to_remove = probs < scaled_min_p
+    logits = logits.masked_fill(tokens_to_remove, -float("inf"))
+
+    return logits
+
+
 def _greedy_sample(
     selected_seq_groups: List[Tuple[List[int], SamplingParams]],
     logprobs: torch.Tensor,
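For intuition, here is a minimal self-contained sketch (not part of the diff) that copies the `_apply_min_p` logic above and applies it to a toy distribution; the function name and the numbers are made up for illustration.

```python
from typing import List

import torch


def apply_min_p(logits: torch.Tensor, min_ps: List[float]) -> torch.Tensor:
    # Same masking rule as _apply_min_p in the diff above.
    min_p = torch.tensor(min_ps, dtype=logits.dtype, device=logits.device)
    probs = torch.softmax(logits, dim=-1)
    top_probs, _ = probs.max(dim=-1, keepdim=True)
    scaled_min_p = min_p.unsqueeze(dim=1) * top_probs
    return logits.masked_fill(probs < scaled_min_p, -float("inf"))


# One sequence over a 4-token vocab with probabilities [0.5, 0.3, 0.15, 0.05].
logits = torch.log(torch.tensor([[0.5, 0.3, 0.15, 0.05]]))
out = apply_min_p(logits, min_ps=[0.4])
# Cutoff = 0.4 * 0.5 = 0.2, so the 0.15 and 0.05 tokens are masked to -inf.
# Renormalizing the survivors gives [0.625, 0.375, 0.0, 0.0].
print(torch.softmax(out, dim=-1))
```

With min_p of 0.0 the scaled cutoff is 0, nothing is masked, and the `do_min_p` guard in `forward()` skips the step entirely.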