Skip to content
Open
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
ede288c
opt
LJH-LBJ Feb 11, 2026
94766e1
Merge remote-tracking branch 'origin/main' into Supports-configurable…
LJH-LBJ Feb 19, 2026
b5596a5
change save_async
LJH-LBJ Feb 19, 2026
bc21ac1
fix bug
LJH-LBJ Feb 20, 2026
e330708
transfer left_context_size by additional_information
LJH-LBJ Feb 20, 2026
9fe559c
del async_chunk_config in OmniEngineArgs
LJH-LBJ Feb 21, 2026
d2a1c28
Merge branch 'main' into Supports-configurable-chunk_size-and-left_co…
LJH-LBJ Feb 21, 2026
3c73e35
fix pre-commit
LJH-LBJ Feb 21, 2026
3d23a72
opt
LJH-LBJ Feb 21, 2026
ecd6e53
opt
LJH-LBJ Feb 22, 2026
9303949
opt
LJH-LBJ Feb 22, 2026
395c286
fix pre-commit
LJH-LBJ Feb 22, 2026
2be2eff
Merge branch 'main' into Supports-configurable-chunk_size-and-left_co…
LJH-LBJ Feb 22, 2026
e11bc9a
add batch left_context_size
LJH-LBJ Feb 24, 2026
38d74b6
Merge remote-tracking branch 'origin/main' into Supports-configurable…
LJH-LBJ Feb 24, 2026
c6bd3b0
Merge branch 'main' into Supports-configurable-chunk_size-and-left_co…
LJH-LBJ Feb 24, 2026
af3d28d
update_additional_information in code2wave
LJH-LBJ Feb 26, 2026
e654c58
opt
LJH-LBJ Feb 26, 2026
74ed57c
Merge branch 'main' into Supports-configurable-chunk_size-and-left_co…
LJH-LBJ Feb 26, 2026
9ce0554
Merge branch 'main' into Supports-configurable-chunk_size-and-left_co…
LJH-LBJ Feb 26, 2026
ea607bd
add left_context_size for qwen3 tts
LJH-LBJ Feb 27, 2026
4eb97f0
Obtain the parameter from the transfer_manager
LJH-LBJ Feb 27, 2026
602975d
remove async_chunk_config
LJH-LBJ Feb 27, 2026
f5889a0
Merge remote-tracking branch 'origin/main' into Supports-configurable…
LJH-LBJ Feb 27, 2026
4985891
fix bug
LJH-LBJ Feb 27, 2026
c8b0e76
Merge branch 'main' into Supports-configurable-chunk_size-and-left_co…
LJH-LBJ Feb 27, 2026
54f31ae
opt
LJH-LBJ Feb 28, 2026
7e77fed
Merge branch 'Supports-configurable-chunk_size-and-left_context_size'…
LJH-LBJ Feb 28, 2026
2af512e
opt
LJH-LBJ Feb 28, 2026
fb04697
opt
LJH-LBJ Feb 28, 2026
b1b441c
opt
LJH-LBJ Feb 28, 2026
3692b96
Merge branch 'main' into Supports-configurable-chunk_size-and-left_co…
LJH-LBJ Feb 28, 2026
1359e59
opt
LJH-LBJ Feb 28, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions vllm_omni/config/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ class OmniModelConfig(ModelConfig):
Attributes:
stage_id: Identifier for the stage in a multi-stage pipeline (default: 0)
async_chunk: If set to True, perform async chunk
async_chunk_config: Configuration dictionary for async chunk processing,
including keys like "chunk_size" and "left_context_size".
chunk_size: code2wav decode chunk size
left_context_size: code2wav left context size
(default: {"chunk_size": 25, "left_context_size": 25})
model_stage: Stage type identifier, e.g., "thinker" or "talker"
(default: "thinker")
model_arch: Model architecture name
Expand All @@ -45,6 +50,12 @@ class OmniModelConfig(ModelConfig):

stage_id: int = 0
async_chunk: bool = False
async_chunk_config: dict[str, Any] = field(
default_factory=lambda: {
"chunk_size": 25, # code2wav decode chunk size
"left_context_size": 25, # code2wav left context size
}
)
model_stage: str = "thinker"
model_arch: str = "Qwen2_5OmniForConditionalGeneration"
worker_type: str | None = None
Expand Down
3 changes: 2 additions & 1 deletion vllm_omni/core/sched/omni_ar_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def __init__(self, *args, **kwargs):
self.chunk_transfer_adapter = None
if getattr(model_config, "async_chunk", False):
self.chunk_transfer_adapter = OmniChunkTransferAdapter(self.vllm_config)
self.async_chunk_config = getattr(model_config, "async_chunk_config", None)

def _get_kv_transfer_criteria(self) -> dict | None:
# Note: vllm_config is available in Scheduler after super().__init__
Expand Down Expand Up @@ -353,7 +354,7 @@ def update_from_output(
)
)
if self.chunk_transfer_adapter is not None:
self.chunk_transfer_adapter.save_async(pooler_output, request)
self.chunk_transfer_adapter.save_async(self.async_chunk_config, pooler_output, request)
else:
# Invariant: EngineCore returns no partial prefill outputs.
assert not prompt_logprobs_tensors
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,12 @@ def load_async(self, request: Request):
with self.lock:
self._pending_load_reqs[request_id] = request

def save_async(self, pooling_output: torch.Tensor | None = None, request: Request | None = None):
def save_async(
self,
async_chunk_config: dict[str, Any],
pooling_output: torch.Tensor | None = None,
request: Request | None = None,
) -> None:
"""Build and enqueue one chunk for asynchronous sending.

Payload extraction is executed in the caller thread via
Expand All @@ -122,6 +127,7 @@ def save_async(self, pooling_output: torch.Tensor | None = None, request: Reques
transfer_manager=self,
pooling_output=pooling_output,
request=request,
async_chunk_config=async_chunk_config,
)

except Exception as e:
Expand Down Expand Up @@ -184,6 +190,11 @@ def _poll_single_request(self, req_id: str):
# req.num_computed_tokens = 0
new_ids = payload_data.get("code_predictor_codes", [])
req.prompt_token_ids = new_ids
# Pass additional fields (like left_context_size) to the request
# Only pass chunk context metadata in additional_information
req.additional_information = {}
if "left_context_size" in payload_data:
req.additional_information["left_context_size"] = payload_data["left_context_size"]
req.num_computed_tokens = 0

# Empty chunk with more data expected: keep polling.
Expand Down
7 changes: 7 additions & 0 deletions vllm_omni/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,12 @@ class AsyncOmniEngineArgs(AsyncEngineArgs):
custom_process_next_stage_input_func: str | None = None
stage_connector_spec: dict[str, Any] = field(default_factory=dict)
async_chunk: bool = False
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we remove async_chunk field here? It seems that it's not consistent with model.py

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will keep async_chunk as before and take it out of async_chunk_config.

async_chunk_config: dict[str, Any] = field(
default_factory=lambda: {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nesting async_chunk (a boolean) inside async_chunk_config (a dict) alongside chunk_size/left_context_size (integers) feels like mixing concerns. OmniModelConfig keeps async_chunk: bool and async_chunk_config: dict as separate fields. Would it be cleaner to keep async_chunk as its own field here too and have async_chunk_config only hold chunk_size and left_context_size?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I will keep async_chunk as before.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good.

"chunk_size": 25,
"left_context_size": 25,
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The old async_chunk: bool = False field was replaced by this dict, but create_model_config on line 351 still does async_chunk=self.async_chunk. Since async_chunk is no longer a standalone field on AsyncOmniEngineArgs, this will raise AttributeError at runtime. (gcanlin flagged something similar.) It seems like async_chunk should either stay as a separate field or be derived here, e.g. self.async_chunk_config.get("async_chunk", False).

)
omni_kv_config: dict | None = None
worker_type: str | None = None

Expand Down Expand Up @@ -342,6 +348,7 @@ def create_model_config(self) -> OmniModelConfig:
logits_processors=self.logits_processors,
video_pruning_rate=video_pruning_rate,
io_processor_plugin=self.io_processor_plugin,
async_chunk_config=self.async_chunk_config,
# Omni-specific fields
stage_id=self.stage_id,
async_chunk=self.async_chunk,
Expand Down
21 changes: 17 additions & 4 deletions vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,14 @@ def forward(
codes = input_ids_flatten.reshape(1, 16, -1)

# Generate audio from codec codes
audio_tensors = self.generate_audio(codes, voice_type)
# Get left_context_size from runtime_additional_information (passed via kwargs)
# or additional_information parameter
left_context_size = None
if additional_information is not None:
left_context_size = additional_information.get("left_context_size")
else:
logger.debug("No additional_information provided to code2wav stage.")
audio_tensors = self.generate_audio(codes, voice_type, left_context_size=left_context_size)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This warning fires for every non-async-chunk call (or whenever additional_information is None), which could be very noisy in production. Is this intentional for debugging only, or should it be logger.debug instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I will use logger.debug instead

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

return audio_tensors

Expand Down Expand Up @@ -458,13 +465,19 @@ def make_omni_output(self, model_outputs: torch.Tensor | OmniOutput, **kwargs) -

# ==================== Audio Generation ====================

def generate_audio(self, code: torch.Tensor, voice_type: str) -> list[torch.Tensor]:
def generate_audio(
self,
code: torch.Tensor,
voice_type: str,
left_context_size: int | None = None,
) -> list[torch.Tensor]:
"""
Generate audio waveform from codec codes.

Args:
code: [8, T] - 8-layer RVQ codec codes
voice_type: Voice type (not used in Qwen3, kept for compatibility)
left_context_size: Context size for streaming decode (from async_chunk_config)

Returns:
audio_tensor: [1, waveform_len] - Audio waveform
Expand All @@ -487,10 +500,10 @@ def generate_audio(self, code: torch.Tensor, voice_type: str) -> list[torch.Tens
talker_codes = talker_codes.expand(1, 16, -1)

if self.vllm_config.model_config.async_chunk:
# Only use left_context_size from additional information
audio_tensors = self.code2wav.chunked_decode_streaming(
talker_codes,
chunk_size=25,
left_context_size=25,
left_context_size=left_context_size,
)
else:
# Use chunked decode for memory efficiency
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,24 +214,28 @@ def chunked_decode(
def chunked_decode_streaming(
self,
codes: torch.Tensor,
chunk_size: int = 25,
left_context_size: int = 25,
left_context_size: int,
) -> list[torch.Tensor]:
"""
Decode long sequences in chunks to avoid OOM.

Uses overlapping chunks with left context to avoid boundary artifacts.

No longer need chunk size here, which is different from chunked_decode

Args:
codes: [batch, num_quantizers, seq_len] - num_quantizers-layer RVQ codes
chunk_size: Number of codec frames per chunk
left_context_size: Number of overlapping frames for context

Returns:
list[torch.Tensor]: Complete waveform decoded from the input
codes. For ``batch_size == 1``, this is a list containing a
single tensor with shape ``[1, waveform_len]``.
"""
if left_context_size is None:
logger.warning(
"left_context_size is None in chunked_decode_streaming; this may cause incorrect output shape."
)
# Decode chunk
wavs = []
batch_wav = self(codes)
Expand All @@ -243,14 +247,8 @@ def chunked_decode_streaming(
# Create one entry per batch so that each element is processed.
code_seq_lens = [codes.shape[-1]] * codes.shape[0]
for idx, code_seq_len in enumerate(code_seq_lens):
# TODO: need to optimize algorithms, current only support
# chunk_size = left_context_size = 25
if code_seq_len <= chunk_size:
context_size = 0
else:
context_size = left_context_size
# Remove context from output (context_size * total_upsample samples)
wav_chunk = batch_wav[idx, :, context_size * self.total_upsample : code_seq_len * self.total_upsample]
# Remove context from output (left_context_size * total_upsample samples)
wav_chunk = batch_wav[idx, :, left_context_size * self.total_upsample : code_seq_len * self.total_upsample]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In concurrent scenarios, the left_context_size of each request in code2wav may not be equal, so it is necessary to obtain the left_context_size of each request separately.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

wavs.append(wav_chunk)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

chunk_size is declared as a parameter of chunked_decode_streaming but is no longer used anywhere in the method body after this change. Is that intentional? If it is only there for API compatibility, might be worth a comment or removing it entirely.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, I will remove chunk_size entirely

return wavs

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ stage_args:
distributed_executor_backend: "mp"
hf_config_name: talker_config
custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav_async_chunk
async_chunk_config:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This config is under stage 1 (talker), which is where chunks are produced. But generate_audio in the code2wav stage also reads async_chunk_config from its own model_config (as a fallback when left_context_size is None). Since stage 2 does not have async_chunk_config in its YAML, it will silently use the default (25). Should async_chunk_config also be set under stage 2, or should the fallback in generate_audio be removed to make the additional_information path the only source of truth?

Copy link
Contributor Author

@LJH-LBJ LJH-LBJ Feb 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, chunk_size is no longer needed here. I will remove the fallback in generate_audio and only use the left_context_size from additional_information. So we needn't set async_chunk_config in stage 2

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The configurations for chunk_size and left_context_size should be consistent with those for TTS.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

chunk_size: 25 # code2wav decode chunk size
left_context_size: 25 # code2wav left context size
engine_input_source: [0]
# final_output: true
# final_output_type: text
Expand Down
31 changes: 21 additions & 10 deletions vllm_omni/model_executor/stage_input_processors/qwen3_omni.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from vllm.inputs import TextPrompt
from vllm.platforms import current_platform

from vllm_omni.distributed.omni_connectors.transfer_adapter.chunk_transfer_adapter import OmniChunkTransferAdapter
from vllm_omni.engine import OmniEngineCoreRequest
from vllm_omni.inputs.data import OmniTokensPrompt

Expand Down Expand Up @@ -87,6 +88,7 @@ def thinker2talker_async_chunk(
transfer_manager: Any,
pooling_output: dict[str, Any],
request: OmniEngineCoreRequest,
**kwargs,
) -> list[dict[str, Any]]:
"""
Process thinker outputs to create talker inputs.
Expand Down Expand Up @@ -206,16 +208,21 @@ def thinker2talker(


def talker2code2wav_async_chunk(
transfer_manager: Any,
transfer_manager: OmniChunkTransferAdapter,
pooling_output: dict[str, Any],
request: OmniEngineCoreRequest,
**kwargs,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The chunk_size configuration should be placed within the transfer_manager; do not pass it to the talker2code2wav_async_chunk function. Obtain the parameter from the transfer_manager.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

):
"""
Pooling version.
"""
if "code_predictor_codes" not in pooling_output:
return None

async_chunk_config = kwargs.get("async_chunk_config", {})
chunk_size_config = async_chunk_config.get("chunk_size", 25)
left_context_size_config = async_chunk_config.get("left_context_size", 25)

code_predictor_codes = pooling_output["code_predictor_codes"]

if code_predictor_codes is None:
Expand All @@ -240,23 +247,27 @@ def talker2code2wav_async_chunk(
return None

request_id = request.external_req_id
chunk_size = left_context_size = 25
transfer_manager.code_prompt_token_ids[request_id].append(codec_codes)
length = len(transfer_manager.code_prompt_token_ids[request_id])
chunk_length = length % chunk_size
chunk_length = length % chunk_size_config
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens when length == context_length (i.e., the first chunk)? min(0, left_context_size_config) gives 0, so end_index = 0 + chunk_size_config. That looks correct. But for the very last chunk when chunk_length != 0, context_length = chunk_length which could be small. Then left_context_size = min(length - chunk_length, left_context_size_config). Have you verified this against the table in the PR description for the final chunk case?

Copy link
Contributor Author

@LJH-LBJ LJH-LBJ Feb 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when the first chunk gets length == context_length, end_index = min(context_length, left_context_size + context_length) = min(4, 0+4), i.e. context_length=4.

when the last chunk is smaller than chunk_size_config, left_context_size = min(length - chunk_length, left_context_size_config) = min(37-1, 25) = 25 in the table in PR.

I think it is correct now.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah got it, that tracks. Thanks for walking through it.

if chunk_length != 0 and not request.is_finished():
return None

context_length = chunk_length if chunk_length != 0 else chunk_size
context_length = chunk_length if chunk_length != 0 else chunk_size_config
# ensure left context does not exceed available length
left_context_size = min(length - context_length, left_context_size_config)
end_index = min(length, left_context_size + context_length)

codes = (
torch.tensor(transfer_manager.code_prompt_token_ids[request_id][-end_index:])
.transpose(0, 1)
.reshape(-1)
.tolist()
)

info = {
"code_predictor_codes": (
torch.tensor(transfer_manager.code_prompt_token_ids[request_id][-end_index:])
.transpose(0, 1)
.reshape(-1)
.tolist()
),
"code_predictor_codes": codes,
"left_context_size": left_context_size,
"finished": torch.tensor(request.is_finished(), dtype=torch.bool),
}
return info
Expand Down