Commit 9098756

fix kernel error for qwen3-omni

Signed-off-by: Rein Yang <ruiruyang2@gmail.com>
1 parent 1ca198e commit 9098756

File tree

2 files changed: +9 −1 lines changed


vllm_omni/model_executor/stage_configs/qwen3_omni_moe.yaml

Lines changed: 1 addition & 1 deletion
```diff
@@ -72,7 +72,7 @@ stage_args:
     stage_type: llm # Use llm stage type to launch OmniLLM
     runtime:
       devices: "1"
-      max_batch_size: 64
+      max_batch_size: 32
     engine_args:
       model_stage: code2wav
       model_arch: Qwen3OmniMoeForConditionalGeneration
```

vllm_omni/worker/gpu_ar_model_runner.py

Lines changed: 8 additions & 0 deletions
```diff
@@ -416,6 +416,14 @@ def sample_tokens(
         if grammar_output is not None:
             apply_grammar_bitmask(scheduler_output, grammar_output, self.input_batch, logits)

+        # Correct padding values of prompt_token_ids to match the logits vocabulary size
+        if logits is not None and not self.input_batch.sampling_metadata.no_penalties:
+            smd = self.input_batch.sampling_metadata
+            if smd.prompt_token_ids is not None:
+                logits_vocab = logits.shape[-1]
+                if self.input_batch.vocab_size > logits_vocab:
+                    smd.prompt_token_ids = smd.prompt_token_ids.clamp(max=logits_vocab)
+
         with record_function_or_nullcontext("gpu_model_runner: sample"):
             sampler_output = self._sample(logits, spec_decode_metadata)
```

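The fix above works because the input batch pads `prompt_token_ids` with `self.input_batch.vocab_size`, and the repetition-penalty path indexes counting buckets by token id; when a stage's logits use a smaller vocabulary (`logits.shape[-1] < self.input_batch.vocab_size`), the padding ids land past the last bucket and trigger a device-side kernel error. Clamping the padding ids to the logits vocabulary size keeps them inside the one extra "discard" bucket. A hypothetical NumPy sketch of that padding-bucket pattern (the function name, shapes, and bucket layout here are illustrative assumptions, not vLLM's actual kernel):

```python
import numpy as np

def prompt_token_counts(prompt_token_ids, vocab_size):
    """Count prompt-token occurrences for penalty application.

    Padding ids are expected to equal vocab_size, so we count with one
    extra bucket and drop it; an id > vocab_size would overflow the
    buckets, which is the analogue of the GPU kernel error this commit
    fixes by clamping the padded ids.
    """
    padded = np.clip(prompt_token_ids, None, vocab_size)  # the commit's clamp
    counts = np.bincount(padded, minlength=vocab_size + 1)
    return counts[:vocab_size]  # discard the padding bucket at index vocab_size

# Padding id 7 (from a larger upstream vocab) gets clamped into the
# discard bucket of a stage whose logits only cover 5 tokens.
print(prompt_token_counts(np.array([1, 1, 3, 7, 7]), vocab_size=5))
```

Running the example prints `[0 2 0 1 0]`: the real tokens 1 and 3 are counted, while both oversized padding ids are clamped to 5 and dropped with the padding bucket.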