Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
ede288c
opt
LJH-LBJ Feb 11, 2026
94766e1
Merge remote-tracking branch 'origin/main' into Supports-configurable…
LJH-LBJ Feb 19, 2026
b5596a5
change save_async
LJH-LBJ Feb 19, 2026
bc21ac1
fix bug
LJH-LBJ Feb 20, 2026
e330708
transfer left_context_size by additional_information
LJH-LBJ Feb 20, 2026
9fe559c
del async_chunk_config in OmniEngineArgs
LJH-LBJ Feb 21, 2026
d2a1c28
Merge branch 'main' into Supports-configurable-chunk_size-and-left_co…
LJH-LBJ Feb 21, 2026
3c73e35
fix pre-commit
LJH-LBJ Feb 21, 2026
3d23a72
opt
LJH-LBJ Feb 21, 2026
ecd6e53
opt
LJH-LBJ Feb 22, 2026
9303949
opt
LJH-LBJ Feb 22, 2026
395c286
fix pre-commit
LJH-LBJ Feb 22, 2026
2be2eff
Merge branch 'main' into Supports-configurable-chunk_size-and-left_co…
LJH-LBJ Feb 22, 2026
e11bc9a
add batch left_context_size
LJH-LBJ Feb 24, 2026
38d74b6
Merge remote-tracking branch 'origin/main' into Supports-configurable…
LJH-LBJ Feb 24, 2026
c6bd3b0
Merge branch 'main' into Supports-configurable-chunk_size-and-left_co…
LJH-LBJ Feb 24, 2026
af3d28d
update_additional_information in code2wav
LJH-LBJ Feb 26, 2026
e654c58
opt
LJH-LBJ Feb 26, 2026
74ed57c
Merge branch 'main' into Supports-configurable-chunk_size-and-left_co…
LJH-LBJ Feb 26, 2026
9ce0554
Merge branch 'main' into Supports-configurable-chunk_size-and-left_co…
LJH-LBJ Feb 26, 2026
ea607bd
add left_context_size for qwen3 tts
LJH-LBJ Feb 27, 2026
4eb97f0
Obtain the parameter from the transfer_manager
LJH-LBJ Feb 27, 2026
602975d
remove async_chunk_config
LJH-LBJ Feb 27, 2026
f5889a0
Merge remote-tracking branch 'origin/main' into Supports-configurable…
LJH-LBJ Feb 27, 2026
4985891
fix bug
LJH-LBJ Feb 27, 2026
c8b0e76
Merge branch 'main' into Supports-configurable-chunk_size-and-left_co…
LJH-LBJ Feb 27, 2026
54f31ae
opt
LJH-LBJ Feb 28, 2026
7e77fed
Merge branch 'Supports-configurable-chunk_size-and-left_context_size'…
LJH-LBJ Feb 28, 2026
2af512e
opt
LJH-LBJ Feb 28, 2026
fb04697
opt
LJH-LBJ Feb 28, 2026
b1b441c
opt
LJH-LBJ Feb 28, 2026
3692b96
Merge branch 'main' into Supports-configurable-chunk_size-and-left_co…
LJH-LBJ Feb 28, 2026
1359e59
opt
LJH-LBJ Feb 28, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions vllm_omni/core/sched/omni_generation_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def schedule(self) -> SchedulerOutput:
scheduled_spec_decode_tokens: dict[str, list[int]] = {}
scheduled_encoder_inputs: dict[str, list[int]] = {}
cached_prompt_token_ids: dict[str, list[int]] = {}
cached_additional_information: dict[str, dict | None] = {}

# Temporary queue: preserve waiting order, do not disturb non-diffusion requests
skipped_waiting_requests = create_request_queue(self.policy)
Expand Down Expand Up @@ -105,6 +106,7 @@ def schedule(self) -> SchedulerOutput:
req_to_new_blocks[request.request_id] = new_blocks
num_scheduled_tokens[request.request_id] = num_new_tokens
cached_prompt_token_ids[request.request_id] = request.prompt_token_ids
cached_additional_information[request.request_id] = getattr(request, "additional_information", None)
token_budget -= num_new_tokens
scheduled_running_reqs.append(request)
req_index += 1
Expand Down Expand Up @@ -225,6 +227,7 @@ def schedule(self) -> SchedulerOutput:
num_computed_tokens=cached_reqs_data.num_computed_tokens,
num_output_tokens=cached_reqs_data.num_output_tokens,
prompt_token_ids=cached_prompt_token_ids,
additional_information=cached_additional_information,
)

total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
Expand Down
1 change: 1 addition & 0 deletions vllm_omni/core/sched/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ class OmniCachedRequestData(CachedRequestData):
"""

prompt_token_ids: dict[str, list[int]]
additional_information: dict[str, dict | None]


@dataclass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,11 @@ def _poll_single_request(self, request: Request):

new_ids = payload_data.get("code_predictor_codes", [])
request.prompt_token_ids = new_ids
# Pass additional fields (like left_context_size) to the request
# Only pass chunk context metadata in additional_information
request.additional_information = {}
if "left_context_size" in payload_data:
request.additional_information["left_context_size"] = payload_data["left_context_size"]
request.num_computed_tokens = 0

# Empty chunk with more data expected: keep polling.
Expand Down
20 changes: 16 additions & 4 deletions vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from collections.abc import Iterable
from functools import cached_property
from typing import Any

import torch
import torch.nn as nn
Expand Down Expand Up @@ -160,6 +161,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.requires_raw_input_tokens = True

elif self.model_stage == "code2wav":
self.enable_update_additional_information = True
self.thinker = None
self.talker = None
# Initialize code2wav (codec codes → audio waveform)
Expand Down Expand Up @@ -254,7 +256,7 @@ def forward(
codec: torch.Tensor | None = None,
sampling_metadata: SamplingMetadata | None = None,
logits_index: int | None = None,
additional_information: dict[str, object] | None = None,
runtime_additional_information: list[dict[str, Any]] | None = None,
**kwargs: object,
) -> torch.Tensor | IntermediateTensors | OmniOutput:
"""
Expand Down Expand Up @@ -359,7 +361,15 @@ def forward(
codes = input_ids_flatten.reshape(1, 16, -1)

# Generate audio from codec codes
audio_tensors = self.generate_audio(codes, voice_type, seq_token_counts)
# Get every request's left_context_size from runtime_additional_information (passed via kwargs)
left_context_size = []
if runtime_additional_information is not None:
for info in runtime_additional_information:
if "left_context_size" in info:
left_context_size.append(info["left_context_size"])
else:
logger.debug("No additional_information provided to code2wav stage.")
audio_tensors = self.generate_audio(codes, voice_type, left_context_size, seq_token_counts)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This warning fires for every non-async-chunk call (or whenever additional_information is None), which could be very noisy in production. Is this intentional for debugging only, or should it be logger.debug instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I will use logger.debug instead

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

return audio_tensors

Expand Down Expand Up @@ -435,6 +445,7 @@ def generate_audio(
self,
code: torch.Tensor,
voice_type: str,
left_context_size: list[int] | None = None,
seq_token_counts: list[int] | None = None,
) -> list[torch.Tensor]:
"""
Expand All @@ -443,6 +454,7 @@ def generate_audio(
Args:
code: [batch, num_quantizers, T] - RVQ codec codes
voice_type: Voice type (not used in Qwen3, kept for compatibility)
left_context_size: Left context size for streaming decode
seq_token_counts: Token count for each request in batch

Returns:
Expand All @@ -466,10 +478,10 @@ def generate_audio(
talker_codes = talker_codes.expand(1, 16, -1)

if self.vllm_config.model_config.async_chunk:
# Only use left_context_size from additional information
audio_tensors = self.code2wav.chunked_decode_streaming(
talker_codes,
chunk_size=25,
left_context_size=25,
left_context_size=left_context_size,
seq_token_counts=seq_token_counts,
)
else:
Expand Down
25 changes: 14 additions & 11 deletions vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_code2wav.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,18 +213,18 @@ def chunked_decode(
def chunked_decode_streaming(
self,
codes: torch.Tensor,
chunk_size: int = 25,
left_context_size: int = 25,
left_context_size: list[int] | None = None,
seq_token_counts: list[int] | None = None,
) -> list[torch.Tensor]:
"""
Decode long sequences in chunks to avoid OOM.

Uses overlapping chunks with left context to avoid boundary artifacts.

No longer need chunk size here, which is different from chunked_decode

Args:
codes: [batch, num_quantizers, seq_len] - num_quantizers-layer RVQ codes
chunk_size: Number of codec frames per chunk
left_context_size: Number of overlapping frames for context
seq_token_counts: Token count for each request in batch

Expand All @@ -233,6 +233,13 @@ def chunked_decode_streaming(
codes. For ``batch_size == 1``, this is a list containing a
single tensor with shape ``[1, waveform_len]``.
"""
if not (left_context_size and seq_token_counts and len(left_context_size) == len(seq_token_counts)):
logger.warning(
"chunked_decode_streaming: missing/invalid left_context_size or seq_token_counts; "
"defaulting to left_context_size=zeros(len=codes.shape[0])."
)
left_context_size = [0] * codes.shape[0]
# Decode chunk
wavs = []
batch_wav = self(codes)
if seq_token_counts is not None:
Expand All @@ -241,14 +248,10 @@ def chunked_decode_streaming(
# Fallback: assume all batch elements share the same sequence length.
code_seq_lens = [codes.shape[-1]] * codes.shape[0]
for idx, code_seq_len in enumerate(code_seq_lens):
# TODO: need to optimize algorithms, current only support
# chunk_size = left_context_size = 25
if code_seq_len <= chunk_size:
context_size = 0
else:
context_size = left_context_size
# Remove context from output (context_size * total_upsample samples)
wav_chunk = batch_wav[idx, :, context_size * self.total_upsample : code_seq_len * self.total_upsample]
# Remove context from output (left_context_size * total_upsample samples)
wav_chunk = batch_wav[
idx, :, left_context_size[idx] * self.total_upsample : code_seq_len * self.total_upsample
]
wavs.append(wav_chunk)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

chunk_size is declared as a parameter of chunked_decode_streaming but is no longer used anywhere in the method body after this change. Is that intentional? If it is only there for API compatibility, might be worth a comment or removing it entirely.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, I will remove chunk_size entirely

return wavs

Expand Down
24 changes: 18 additions & 6 deletions vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.have_multimodal_outputs = True
self.has_preprocess = False
self.has_postprocess = False
self.enable_update_additional_information = True
# Generation-only stage (no logits / sampling).
self.requires_raw_input_tokens = True

Expand Down Expand Up @@ -146,6 +147,7 @@ def forward(
positions: torch.Tensor | None = None,
intermediate_tensors: Any = None,
inputs_embeds: torch.Tensor | None = None,
runtime_additional_information: list[dict[str, Any]] | None = None,
**kwargs: Any,
) -> OmniOutput:
"""Decode codec codes into audio waveform.
Expand Down Expand Up @@ -177,18 +179,28 @@ def forward(
ids = input_ids.reshape(-1).to(dtype=torch.long)
request_ids_list = self._split_request_ids(ids, kwargs.get("seq_token_counts"))

# Parse each request: extract ctx_frames, validate, reshape codes.
# input_ids layout per request: [codec_context_frames, *flat_codes]
# Parse each request: extract left_context_size, validate, reshape codes.
# input_ids layout per request: [*flat_codes]
# where flat_codes is codebook-major [q*F].
parsed = [] # (ctx_frames, actual_frames)
parsed = [] # (left_context_size, actual_frames)
valid_codes = []
valid_indices = []
# Get left_context_size from runtime_additional_information (passed via kwargs)
left_context_size = [0] * len(request_ids_list)
if runtime_additional_information is not None:
for i, info in enumerate(runtime_additional_information):
if i >= len(left_context_size):
break
if "left_context_size" in info:
left_context_size[i] = info["left_context_size"]
else:
logger.debug("No additional_information provided to code2wav stage.")
for i, req_ids in enumerate(request_ids_list):
if req_ids.numel() < 2:
if req_ids.numel() < 1:
parsed.append((0, 0))
continue
ctx_frames = int(req_ids[0].item())
flat = req_ids[1:]
ctx_frames = left_context_size[i]
flat = req_ids
n = flat.numel()
# Warmup / dummy_run: not divisible by num_quantizers.
if n == 0 or n % q != 0:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ stage_args:
final_output: true
final_output_type: text
is_comprehension: true
# Use named connector to apply runtime.connectors.extra.
output_connectors:
to_stage_1: connector_of_shared_memory
default_sampling_params:
temperature: 0.4
top_p: 0.9
Expand Down Expand Up @@ -60,6 +63,9 @@ stage_args:
engine_input_source: [0]
# final_output: true
# final_output_type: text
# Distributed connector configuration
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we have to make changes to yaml? @LJH-LBJ @amy-why-3459 @linyueqian

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From the Qwen3-TTS side: existing YAMLs won't break since the stage input processor falls back to codec_chunk_frames=25 and codec_left_context_frames=25 by default. That said, the YAML changes are good to have. They wire up named connectors so these values are explicit and configurable rather than relying on hardcoded defaults, which is the whole point of this PR.

input_connectors:
from_stage_0: connector_of_shared_memory
default_sampling_params:
temperature: 0.9
top_k: 50
Expand Down Expand Up @@ -99,3 +105,13 @@ stage_args:
seed: 42
detokenize: True
repetition_penalty: 1.1

runtime:

connectors:
connector_of_shared_memory:
name: SharedMemoryConnector
extra:
# Align with Omni: small chunks with sufficient context overlap.
codec_chunk_frames: 25 # code2wav decode chunk size
codec_left_context_frames: 25 # code2wav left context size
28 changes: 19 additions & 9 deletions vllm_omni/model_executor/stage_input_processors/qwen3_omni.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,12 @@ def talker2code2wav_async_chunk(
if "code_predictor_codes" not in pooling_output:
return None

connector = getattr(transfer_manager, "connector", None)
raw_cfg = getattr(connector, "config", {}) or {}
cfg = raw_cfg.get("extra", raw_cfg) if isinstance(raw_cfg, dict) else {}
chunk_size_config = int(cfg.get("codec_chunk_frames", 25))
left_context_size_config = int(cfg.get("codec_left_context_frames", 25))

code_predictor_codes = pooling_output["code_predictor_codes"]

if code_predictor_codes is None:
Expand All @@ -244,23 +250,27 @@ def talker2code2wav_async_chunk(
return None

request_id = request.external_req_id
chunk_size = left_context_size = 25
transfer_manager.code_prompt_token_ids[request_id].append(codec_codes)
length = len(transfer_manager.code_prompt_token_ids[request_id])
chunk_length = length % chunk_size
chunk_length = length % chunk_size_config
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens when length == context_length (i.e., the first chunk)? min(0, left_context_size_config) gives 0, so end_index = 0 + chunk_size_config. That looks correct. But for the very last chunk when chunk_length != 0, context_length = chunk_length which could be small. Then left_context_size = min(length - chunk_length, left_context_size_config). Have you verified this against the table in the PR description for the final chunk case?

Copy link
Contributor Author

@LJH-LBJ LJH-LBJ Feb 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when the first chunk gets length == context_length, end_index = min(context_length, left_context_size + context_length) = min(4, 0+4), i.e. context_length=4.

when the last chunk is smaller than chunk_size_config, left_context_size = min(length - chunk_length, left_context_size_config) = min(37-1, 25) = 25 in the table in PR.

I think it is correct now.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah got it, that tracks. Thanks for walking through it.

if chunk_length != 0 and not is_finished:
return None

context_length = chunk_length if chunk_length != 0 else chunk_size
context_length = chunk_length if chunk_length != 0 else chunk_size_config
# ensure left context does not exceed available length
left_context_size = max(0, min(length - context_length, left_context_size_config))
end_index = min(length, left_context_size + context_length)

codes = (
torch.tensor(transfer_manager.code_prompt_token_ids[request_id][-end_index:])
.transpose(0, 1)
.reshape(-1)
.tolist()
)

info = {
"code_predictor_codes": (
torch.tensor(transfer_manager.code_prompt_token_ids[request_id][-end_index:])
.transpose(0, 1)
.reshape(-1)
.tolist()
),
"code_predictor_codes": codes,
"left_context_size": left_context_size,
"finished": torch.tensor(is_finished, dtype=torch.bool),
}
return info
Expand Down
25 changes: 13 additions & 12 deletions vllm_omni/model_executor/stage_input_processors/qwen3_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,12 @@ def talker2code2wav_async_chunk(
connector = getattr(transfer_manager, "connector", None)
raw_cfg = getattr(connector, "config", {}) or {}
cfg = raw_cfg.get("extra", raw_cfg) if isinstance(raw_cfg, dict) else {}
chunk_size = int(cfg.get("codec_chunk_frames", 25))
left_context_size = int(cfg.get("codec_left_context_frames", 25))
if chunk_size <= 0 or left_context_size < 0:
chunk_size_config = int(cfg.get("codec_chunk_frames", 25))
left_context_size_config = int(cfg.get("codec_left_context_frames", 25))
if chunk_size_config <= 0 or left_context_size_config < 0:
raise ValueError(
f"Invalid codec chunk config: codec_chunk_frames={chunk_size}, "
f"codec_left_context_frames={left_context_size}"
f"Invalid codec chunk config: codec_chunk_frames={chunk_size_config}, "
f"codec_left_context_frames={left_context_size_config}"
)
length = len(transfer_manager.code_prompt_token_ids[request_id])

Expand All @@ -59,22 +59,23 @@ def talker2code2wav_async_chunk(
}
return None

chunk_length = length % chunk_size
chunk_length = length % chunk_size_config

if chunk_length != 0 and not finished:
return None

context_length = chunk_length if chunk_length != 0 else chunk_size
end_index = min(length, left_context_size + context_length)
ctx_frames = max(0, int(end_index - context_length))
context_length = chunk_length if chunk_length != 0 else chunk_size_config
end_index = min(length, left_context_size_config + context_length)
left_context_size = max(0, int(end_index - context_length))
window_frames = transfer_manager.code_prompt_token_ids[request_id][-end_index:]

# Pack context + chunk into codebook-major flat codes for adapter.
code_predictor_codes = torch.tensor(window_frames).transpose(0, 1).reshape(-1).tolist()

# Build final prompt_token_ids with ctx_frames header for Qwen3-TTS Code2Wav.
# The model expects input_ids layout: [ctx_frames, *flat_codes].
# Build final prompt_token_ids and left_context_size header for Qwen3-TTS Code2Wav.
# The model expects input_ids layout: [*flat_codes].
return {
"code_predictor_codes": [int(ctx_frames)] + code_predictor_codes,
"code_predictor_codes": code_predictor_codes,
"left_context_size": left_context_size,
"finished": torch.tensor(finished, dtype=torch.bool),
}
10 changes: 8 additions & 2 deletions vllm_omni/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1151,12 +1151,18 @@ def _preprocess(
# Prefill: overlay prompt_embeds and collect additional_information
self._collect_additional_information_for_prefill(num_scheduled_tokens_np)

# Keep per-request additional_information in sync for both new and
# cached requests. This is required for stages without preprocess
# (e.g., code2wav) so runtime_additional_information can be refreshed
# from scheduler cached infos on every step.
if hasattr(self.model, "has_preprocess") or hasattr(self.model, "enable_update_additional_information"):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why check has_preprocess here when enable_update_additional_information is the intended gate? Models with has_preprocess=True but no enable_update_additional_information would hit this path unnecessarily.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the original logic, when has_preprocess=True, update_additional_information can be called. Now I don’t want to change the original logic; I just want the code2wav stage to also be able to update additional_information.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair enough, makes sense.

if self.vllm_config.model_config.async_chunk:
self._update_additional_information(scheduler_output)

if hasattr(self.model, "has_preprocess") and self.model.has_preprocess:
# Overlay custom prompt_embeds per request for the prompt portion;
# collect additional_information (tensor/list) for prefill portion only
decode_req_ids = []
if self.vllm_config.model_config.async_chunk:
self._update_additional_information(scheduler_output)
for req_index, req_id in enumerate(self.input_batch.req_ids):
req_state = self.requests.get(req_id)
req_infos = getattr(req_state, "additional_information_cpu", None) if req_state is not None else None
Expand Down