Skip to content

Commit af11b02

Browse files
[Minor][Refactor] Pass seq_token_counts explicitly (#1425)
Signed-off-by: gcanlin <canlinguosdu@gmail.com> Co-authored-by: Hongsheng Liu <liuhongsheng4@huawei.com>
1 parent 1589931 commit af11b02

File tree

3 files changed

+29
-24
lines changed

3 files changed

+29
-24
lines changed

vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
Qwen3OmniMoeThinkerConfig,
1616
)
1717
from vllm.config import VllmConfig
18-
from vllm.forward_context import get_forward_context
1918
from vllm.logger import init_logger
2019
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
2120
from vllm.model_executor.models.interfaces import SupportsMRoPE, SupportsMultiModal, SupportsPP
@@ -355,13 +354,14 @@ def forward(
355354

356355
# ========== Stage 3: Code2Wav ==========
357356
elif self.model_stage == "code2wav":
357+
seq_token_counts: list[int] | None = kwargs.get("seq_token_counts")
358+
358359
# Extract codec codes from input
359360
if input_ids.shape[0] % 16 == 0:
360-
ubatch_slices = get_forward_context().ubatch_slices
361-
if ubatch_slices is not None:
362-
max_seq_len = max(ubatch_slices) // 16
363-
batch_size = len(ubatch_slices)
364-
split_codes = torch.split(input_ids, ubatch_slices, dim=0)
361+
if seq_token_counts is not None:
362+
max_seq_len = max(seq_token_counts) // 16
363+
batch_size = len(seq_token_counts)
364+
split_codes = torch.split(input_ids, seq_token_counts, dim=0)
365365
codes = torch.zeros((batch_size, 16, max_seq_len), device=input_ids.device, dtype=input_ids.dtype)
366366
for idx, code in enumerate(split_codes):
367367
seq_len = code.shape[0] // 16
@@ -386,7 +386,7 @@ def forward(
386386
codes = input_ids_flatten.reshape(1, 16, -1)
387387

388388
# Generate audio from codec codes
389-
audio_tensors = self.generate_audio(codes, voice_type)
389+
audio_tensors = self.generate_audio(codes, voice_type, seq_token_counts)
390390

391391
return audio_tensors
392392

@@ -458,16 +458,22 @@ def make_omni_output(self, model_outputs: torch.Tensor | OmniOutput, **kwargs) -
458458

459459
# ==================== Audio Generation ====================
460460

461-
def generate_audio(self, code: torch.Tensor, voice_type: str) -> list[torch.Tensor]:
461+
def generate_audio(
462+
self,
463+
code: torch.Tensor,
464+
voice_type: str,
465+
seq_token_counts: list[int] | None = None,
466+
) -> list[torch.Tensor]:
462467
"""
463468
Generate audio waveform from codec codes.
464469
465470
Args:
466-
code: [8, T] - 8-layer RVQ codec codes
471+
code: [batch, num_quantizers, T] - RVQ codec codes
467472
voice_type: Voice type (not used in Qwen3, kept for compatibility)
473+
seq_token_counts: Token count for each request in batch
468474
469475
Returns:
470-
audio_tensor: [1, waveform_len] - Audio waveform
476+
list of audio waveforms
471477
"""
472478
code2wav_dev = self._module_device(self.code2wav)
473479

@@ -491,13 +497,15 @@ def generate_audio(self, code: torch.Tensor, voice_type: str) -> list[torch.Tens
491497
talker_codes,
492498
chunk_size=25,
493499
left_context_size=25,
500+
seq_token_counts=seq_token_counts,
494501
)
495502
else:
496503
# Use chunked decode for memory efficiency
497504
audio_tensors = self.code2wav.chunked_decode(
498505
talker_codes,
499506
chunk_size=300,
500507
left_context_size=25,
508+
seq_token_counts=seq_token_counts,
501509
)
502510

503511
return audio_tensors

vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_code2wav.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
SnakeBeta,
2323
)
2424
from vllm.config import VllmConfig # type: ignore
25-
from vllm.forward_context import get_forward_context
2625
from vllm.logger import init_logger # type: ignore
2726
from vllm.model_executor.models.utils import ( # type: ignore
2827
AutoWeightsLoader,
@@ -163,6 +162,7 @@ def chunked_decode(
163162
codes: torch.Tensor,
164163
chunk_size: int = 300,
165164
left_context_size: int = 25,
165+
seq_token_counts: list[int] | None = None,
166166
) -> list[torch.Tensor]:
167167
"""
168168
Decode long sequences in chunks to avoid OOM.
@@ -173,6 +173,7 @@ def chunked_decode(
173173
codes: [batch, num_quantizers, seq_len] - num_quantizers-layer RVQ codes
174174
chunk_size: Number of codec frames per chunk
175175
left_context_size: Number of overlapping frames for context
176+
seq_token_counts: Token count for each request in batch
176177
177178
Returns:
178179
list[torch.Tensor]: Complete waveform decoded from the input
@@ -197,12 +198,10 @@ def chunked_decode(
197198

198199
start_index = end_index
199200

200-
ubatch_slices = get_forward_context().ubatch_slices
201-
if ubatch_slices is not None:
202-
code_seq_lens = [seq_len // self.config.num_quantizers for seq_len in ubatch_slices]
201+
if seq_token_counts is not None:
202+
code_seq_lens = [seq_len // self.config.num_quantizers for seq_len in seq_token_counts]
203203
else:
204204
# Fallback: assume all batch elements share the same sequence length.
205-
# Create one entry per batch so that each element is processed.
206205
code_seq_lens = [codes.shape[-1]] * codes.shape[0]
207206
batch_wav = torch.cat(wavs, dim=-1)
208207
wavs = []
@@ -216,6 +215,7 @@ def chunked_decode_streaming(
216215
codes: torch.Tensor,
217216
chunk_size: int = 25,
218217
left_context_size: int = 25,
218+
seq_token_counts: list[int] | None = None,
219219
) -> list[torch.Tensor]:
220220
"""
221221
Decode long sequences in chunks to avoid OOM.
@@ -226,21 +226,19 @@ def chunked_decode_streaming(
226226
codes: [batch, num_quantizers, seq_len] - num_quantizers-layer RVQ codes
227227
chunk_size: Number of codec frames per chunk
228228
left_context_size: Number of overlapping frames for context
229+
seq_token_counts: Token count for each request in batch
229230
230231
Returns:
231232
list[torch.Tensor]: Complete waveform decoded from the input
232233
codes. For ``batch_size == 1``, this is a list containing a
233234
single tensor with shape ``[1, waveform_len]``.
234235
"""
235-
# Decode chunk
236236
wavs = []
237237
batch_wav = self(codes)
238-
ubatch_slices = get_forward_context().ubatch_slices
239-
if ubatch_slices is not None:
240-
code_seq_lens = [seq_len // self.config.num_quantizers for seq_len in ubatch_slices]
238+
if seq_token_counts is not None:
239+
code_seq_lens = [n // self.config.num_quantizers for n in seq_token_counts]
241240
else:
242241
# Fallback: assume all batch elements share the same sequence length.
243-
# Create one entry per batch so that each element is processed.
244242
code_seq_lens = [codes.shape[-1]] * codes.shape[0]
245243
for idx, code_seq_len in enumerate(code_seq_lens):
246244
# TODO: need to optimize algorithms, current only support

vllm_omni/worker/gpu_generation_model_runner.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,9 @@ def execute_model(
250250
intermediate_tensors,
251251
)
252252

253+
# [Omni] Pass token counts per request for code2wav output slicing
254+
model_kwargs["seq_token_counts"] = tokens
255+
253256
# Set cudagraph mode to none if calc_kv_scales is true.
254257
# KV scales calculation involves dynamic operations that are incompatible
255258
# with CUDA graph capture.
@@ -258,10 +261,6 @@ def execute_model(
258261
# Mark KV scales as calculated after the first forward pass
259262
self.calculate_kv_scales = False
260263

261-
if ubatch_slices_padded is None:
262-
# reuse ubatch_slices_padded for code2wav batching
263-
ubatch_slices_padded = tokens
264-
265264
# Run the model.
266265
# Use persistent buffers for CUDA graphs.
267266
with (

0 commit comments

Comments
 (0)