 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput,
                            SequenceData, SequenceGroupOutput, SequenceOutput)
+from vllm.utils import is_neuron


 class Sampler(nn.Module):
@@ -32,6 +33,8 @@ def __init__(self,
                  org_vocab_size: Optional[int] = None) -> None:
         super().__init__()
         self.vocab_size = vocab_size
+        # Transformers-neuronx generates outputs as logits directly.
+        self.logits_as_hidden_states = is_neuron()
         # original vocabulary size (without LoRA).
         self.org_vocab_size = org_vocab_size or vocab_size
@@ -55,10 +58,14 @@ def forward(
         embedding_bias: Optional[torch.Tensor] = None,
     ) -> Optional[SamplerOutput]:
         # Get the hidden states that we use for sampling.
-        hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)
+        if self.logits_as_hidden_states:
+            logits = hidden_states
+        else:
+            hidden_states = _prune_hidden_states(hidden_states,
+                                                 sampling_metadata)

-        # Get the logits for the next tokens.
-        logits = self._get_logits(hidden_states, embedding, embedding_bias)
+            # Get the logits for the next tokens.
+            logits = self._get_logits(hidden_states, embedding, embedding_bias)

         # Only perform sampling in the driver worker.
         # Note: `_get_logits` is still distributed across TP workers because
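For context, a minimal standalone sketch of the branch added above (not vLLM's actual code): when the backend is transformers-neuronx, the tensor handed to the sampler is assumed to already contain logits, so pruning and the vocabulary projection are skipped. The helper names prune_fn and to_logits_fn are hypothetical stand-ins for _prune_hidden_states and Sampler._get_logits.

import torch

def compute_logits(hidden_states: torch.Tensor,
                   logits_as_hidden_states: bool,
                   prune_fn, to_logits_fn) -> torch.Tensor:
    if logits_as_hidden_states:
        # Neuron path: the model output is already the logits tensor.
        return hidden_states
    # GPU path: keep only the rows needed for sampling, then project to vocab.
    hidden_states = prune_fn(hidden_states)
    return to_logits_fn(hidden_states)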
@@ -395,7 +402,8 @@ def _sample(
         sample_metadata[sampling_type] = (seq_group_ids, seq_groups,
                                           is_prompts, sample_indices)
         if sampling_type == SamplingType.GREEDY:
-            greedy_samples = torch.argmax(logprobs[sample_indices], dim=-1)
+            greedy_samples = torch.argmax(logprobs[sample_indices.long()],
+                                          dim=-1)
         elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
             max_best_of = 1
             for seq_group, is_prompt in zip(seq_groups, is_prompts):
@@ -407,7 +415,7 @@ def _sample(
                 "generators": sampling_metadata.generators,
             }
             multinomial_samples[sampling_type] = _multinomial(
-                probs[sample_indices], max_best_of, **seeded_args)
+                probs[sample_indices.long()], max_best_of, **seeded_args)
         elif sampling_type == SamplingType.BEAM:
             beam_search_logprobs = logprobs[sample_indices]
         else:
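The .long() casts in the two hunks above can be illustrated in isolation. The snippet below is a small assumption-level sketch, not vLLM code: some backends can hand the sampler index tensors that are not int64, and casting them to long keeps PyTorch's advanced indexing well-defined regardless of the indices' original dtype.

import torch

logprobs = torch.randn(4, 8)  # [num_tokens, vocab_size]
sample_indices = torch.tensor([0, 2], dtype=torch.int32)  # e.g. int32 indices

# Cast to int64 before indexing, mirroring logprobs[sample_indices.long()].
greedy_samples = torch.argmax(logprobs[sample_indices.long()], dim=-1)
print(greedy_samples.shape)  # torch.Size([2])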