diff --git a/vllm_omni/model_executor/stage_configs/qwen3_omni_moe.yaml b/vllm_omni/model_executor/stage_configs/qwen3_omni_moe.yaml index 802f8dd249..eb5bf71012 100644 --- a/vllm_omni/model_executor/stage_configs/qwen3_omni_moe.yaml +++ b/vllm_omni/model_executor/stage_configs/qwen3_omni_moe.yaml @@ -72,7 +72,7 @@ stage_args: stage_type: llm # Use llm stage type to launch OmniLLM runtime: devices: "1" - max_batch_size: 64 + max_batch_size: 32 engine_args: model_stage: code2wav model_arch: Qwen3OmniMoeForConditionalGeneration diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index e332db8b0a..cf261429fd 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -416,6 +416,14 @@ def sample_tokens( if grammar_output is not None: apply_grammar_bitmask(scheduler_output, grammar_output, self.input_batch, logits) + # Clamp prompt_token_ids padding (and any out-of-vocab ids) to the logits vocab size so penalty scatter indexing stays in bounds when input_batch.vocab_size exceeds logits.shape[-1] + if logits is not None and not self.input_batch.sampling_metadata.no_penalties: + smd = self.input_batch.sampling_metadata + if smd.prompt_token_ids is not None: + logits_vocab = logits.shape[-1] + if self.input_batch.vocab_size > logits_vocab: + smd.prompt_token_ids = smd.prompt_token_ids.clamp(max=logits_vocab) + with record_function_or_nullcontext("gpu_model_runner: sample"): sampler_output = self._sample(logits, spec_decode_metadata)