@@ -2,8 +2,7 @@

 import torch

-from vllm.model_executor.layers.utils import (
-    apply_penalties as _apply_penalties)
+from vllm.model_executor.layers.utils import apply_penalties
 from vllm.utils import is_pin_memory_available, make_tensor_with_pad


@@ -17,27 +16,30 @@ def apply_min_token_penalties(logits: torch.Tensor,
     """
     min_tokens_logits_to_penalize: List[Tuple[int, int]] = []
     for index, min_token in enumerate(min_tokens):
-        if (len(output_token_ids[index]) < min_token):
+        if len(output_token_ids[index]) < min_token:
             for stop_token_id in stop_token_ids[index]:
                 min_tokens_logits_to_penalize.append((index, stop_token_id))
     if min_tokens_logits_to_penalize:
         logits[tuple(zip(*min_tokens_logits_to_penalize))] = -float("inf")


-def apply_penalties(logits: torch.Tensor, prompt_token_ids: torch.Tensor,
-                    presence_penalties: torch.Tensor,
-                    frequency_penalties: torch.Tensor,
-                    repetition_penalties: torch.Tensor,
-                    output_token_ids: List[List[int]]) -> torch.Tensor:
+def apply_all_penalties(
+    logits: torch.Tensor,
+    prompt_token_ids: torch.Tensor,
+    presence_penalties: torch.Tensor,
+    frequency_penalties: torch.Tensor,
+    repetition_penalties: torch.Tensor,
+    output_token_ids: List[List[int]],
+) -> torch.Tensor:
     """
     Applies presence, frequency and repetition penalties to the logits.
     """
     _, vocab_size = logits.shape
     output_tokens_t = _convert_to_tensors(output_token_ids, vocab_size,
                                           logits.device)
-    return _apply_penalties(logits, prompt_token_ids, output_tokens_t,
-                            presence_penalties, frequency_penalties,
-                            repetition_penalties)
+    return apply_penalties(logits, prompt_token_ids, output_tokens_t,
+                           presence_penalties, frequency_penalties,
+                           repetition_penalties)


 def _convert_to_tensors(output_token_ids: List[List[int]], vocab_size: int,
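
For orientation, here is a minimal usage sketch of the renamed wrapper. The shapes and penalty values are illustrative assumptions, not taken from this diff, and apply_all_penalties is assumed to be importable from the module edited above.

import torch

# Sketch with assumed shapes: two sequences over a toy vocabulary of
# 8 tokens. apply_all_penalties is assumed importable from the module
# edited in the diff above.
logits = torch.randn(2, 8)
prompt_token_ids = torch.tensor([[1, 2, 3], [4, 5, 6]])

# One penalty value per sequence, matching the signature above.
presence_penalties = torch.tensor([0.5, 0.0])
frequency_penalties = torch.tensor([0.1, 0.0])
repetition_penalties = torch.tensor([1.2, 1.0])

# Ragged per-sequence lists of tokens generated so far; the wrapper
# converts these to a padded tensor via _convert_to_tensors before
# delegating to the imported apply_penalties helper.
output_token_ids = [[2, 7], [5]]

penalized = apply_all_penalties(logits, prompt_token_ids,
                                presence_penalties, frequency_penalties,
                                repetition_penalties, output_token_ids)
print(penalized.shape)  # torch.Size([2, 8])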