diff --git a/vllm_omni/core/sched/omni_generation_scheduler.py b/vllm_omni/core/sched/omni_generation_scheduler.py
index 75ffa80831..96f9f24021 100644
--- a/vllm_omni/core/sched/omni_generation_scheduler.py
+++ b/vllm_omni/core/sched/omni_generation_scheduler.py
@@ -49,6 +49,7 @@ def schedule(self) -> SchedulerOutput:
         scheduled_spec_decode_tokens: dict[str, list[int]] = {}
         scheduled_encoder_inputs: dict[str, list[int]] = {}
         cached_prompt_token_ids: dict[str, list[int]] = {}
+        cached_additional_information: dict[str, dict | None] = {}
 
         # Temporary queue: preserve waiting order, do not disturb non-diffusion requests
         skipped_waiting_requests = create_request_queue(self.policy)
@@ -105,6 +106,7 @@ def schedule(self) -> SchedulerOutput:
             req_to_new_blocks[request.request_id] = new_blocks
             num_scheduled_tokens[request.request_id] = num_new_tokens
             cached_prompt_token_ids[request.request_id] = request.prompt_token_ids
+            cached_additional_information[request.request_id] = getattr(request, "additional_information", None)
             token_budget -= num_new_tokens
             scheduled_running_reqs.append(request)
             req_index += 1
@@ -225,6 +227,7 @@ def schedule(self) -> SchedulerOutput:
             num_computed_tokens=cached_reqs_data.num_computed_tokens,
             num_output_tokens=cached_reqs_data.num_output_tokens,
             prompt_token_ids=cached_prompt_token_ids,
+            additional_information=cached_additional_information,
         )
 
         total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
diff --git a/vllm_omni/core/sched/output.py b/vllm_omni/core/sched/output.py
index b86a97534f..f09e75c2bd 100644
--- a/vllm_omni/core/sched/output.py
+++ b/vllm_omni/core/sched/output.py
@@ -71,6 +71,7 @@ class OmniCachedRequestData(CachedRequestData):
     """
 
     prompt_token_ids: dict[str, list[int]]
+    additional_information: dict[str, dict | None]
 
 
 @dataclass
diff --git a/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py b/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py
index a7f069bebc..06aecec0cd 100644
--- a/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py
+++ b/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py
@@ -153,6 +153,11 @@ def _poll_single_request(self, request: Request):
 
             new_ids = payload_data.get("code_predictor_codes", [])
             request.prompt_token_ids = new_ids
+            # Pass additional fields (like left_context_size) to the request.
+            # Only pass chunk context metadata in additional_information.
+            request.additional_information = {}
+            if "left_context_size" in payload_data:
+                request.additional_information["left_context_size"] = payload_data["left_context_size"]
             request.num_computed_tokens = 0
 
             # Empty chunk with more data expected: keep polling.
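A minimal, self-contained sketch of the metadata hand-off introduced above, using hypothetical stand-in types (SimpleRequest, attach_chunk_metadata, and snapshot_additional_information are illustrative names, not vllm_omni APIs): the transfer adapter copies left_context_size from the connector payload into request.additional_information, and the scheduler then snapshots that mapping per request id for the cached-request data.

from dataclasses import dataclass


@dataclass
class SimpleRequest:
    # Hypothetical stand-in for the real Request object.
    request_id: str
    prompt_token_ids: list[int]
    additional_information: dict | None = None


def attach_chunk_metadata(request: SimpleRequest, payload_data: dict) -> None:
    # Mirrors the adapter change: forward only chunk-context metadata.
    request.prompt_token_ids = payload_data.get("code_predictor_codes", [])
    request.additional_information = {}
    if "left_context_size" in payload_data:
        request.additional_information["left_context_size"] = payload_data["left_context_size"]


def snapshot_additional_information(requests: list[SimpleRequest]) -> dict[str, dict | None]:
    # Mirrors the scheduler change: cache per request id for the cached-request data.
    return {r.request_id: getattr(r, "additional_information", None) for r in requests}


req = SimpleRequest("req-0", [])
attach_chunk_metadata(req, {"code_predictor_codes": [7, 8, 9], "left_context_size": 25})
print(snapshot_additional_information([req]))  # {'req-0': {'left_context_size': 25}}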
diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py
index 46b07f9deb..bd333d1b24 100644
--- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py
+++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py
@@ -5,6 +5,7 @@
 
 from collections.abc import Iterable
 from functools import cached_property
+from typing import Any
 
 import torch
 import torch.nn as nn
@@ -160,6 +161,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             self.requires_raw_input_tokens = True
 
         elif self.model_stage == "code2wav":
+            self.enable_update_additional_information = True
            self.thinker = None
             self.talker = None
             # Initialize code2wav (codec codes → audio waveform)
@@ -254,7 +256,7 @@ def forward(
         codec: torch.Tensor | None = None,
         sampling_metadata: SamplingMetadata | None = None,
         logits_index: int | None = None,
-        additional_information: dict[str, object] | None = None,
+        runtime_additional_information: list[dict[str, Any]] | None = None,
         **kwargs: object,
     ) -> torch.Tensor | IntermediateTensors | OmniOutput:
         """
@@ -359,7 +361,15 @@ def forward(
             codes = input_ids_flatten.reshape(1, 16, -1)
 
             # Generate audio from codec codes
-            audio_tensors = self.generate_audio(codes, voice_type, seq_token_counts)
+            # Get each request's left_context_size from runtime_additional_information (passed via kwargs)
+            left_context_size = []
+            if runtime_additional_information is not None:
+                for info in runtime_additional_information:
+                    if "left_context_size" in info:
+                        left_context_size.append(info["left_context_size"])
+            else:
+                logger.debug("No additional_information provided to code2wav stage.")
+            audio_tensors = self.generate_audio(codes, voice_type, left_context_size, seq_token_counts)
 
             return audio_tensors
 
@@ -435,6 +445,7 @@ def generate_audio(
         self,
         code: torch.Tensor,
         voice_type: str,
+        left_context_size: list[int] | None = None,
         seq_token_counts: list[int] | None = None,
     ) -> list[torch.Tensor]:
         """
@@ -443,6 +454,7 @@
         Args:
             code: [batch, num_quantizers, T] - RVQ codec codes
             voice_type: Voice type (not used in Qwen3, kept for compatibility)
+            left_context_size: Per-request left context sizes for streaming decode
             seq_token_counts: Token count for each request in batch
 
         Returns:
@@ -466,10 +478,10 @@
             talker_codes = talker_codes.expand(1, 16, -1)
 
         if self.vllm_config.model_config.async_chunk:
+            # Only use the left_context_size carried in additional_information
             audio_tensors = self.code2wav.chunked_decode_streaming(
                 talker_codes,
-                chunk_size=25,
-                left_context_size=25,
+                left_context_size=left_context_size,
                 seq_token_counts=seq_token_counts,
             )
         else:
diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_code2wav.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_code2wav.py
index 54bc1c1c59..41eeec5e7f 100644
--- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_code2wav.py
+++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_code2wav.py
@@ -213,8 +213,7 @@ def chunked_decode(
     def chunked_decode_streaming(
         self,
         codes: torch.Tensor,
-        chunk_size: int = 25,
-        left_context_size: int = 25,
+        left_context_size: list[int] | None = None,
         seq_token_counts: list[int] | None = None,
     ) -> list[torch.Tensor]:
         """
@@ -222,9 +221,10 @@ def chunked_decode_streaming(
 
         Uses overlapping chunks with left context to avoid boundary artifacts.
 
+        Unlike ``chunked_decode``, this method does not take a chunk size.
+
         Args:
             codes: [batch, num_quantizers, seq_len] - num_quantizers-layer RVQ codes
-            chunk_size: Number of codec frames per chunk
             left_context_size: Number of overlapping frames for context
             seq_token_counts: Token count for each request in batch
 
@@ -233,6 +233,13 @@ def chunked_decode_streaming(
             codes. For ``batch_size == 1``, this is a list containing a single
             tensor with shape ``[1, waveform_len]``.
         """
+        if not (left_context_size and seq_token_counts and len(left_context_size) == len(seq_token_counts)):
+            logger.warning(
+                "chunked_decode_streaming: missing/invalid left_context_size or seq_token_counts; "
+                "defaulting to left_context_size=zeros(len=codes.shape[0])."
+            )
+            left_context_size = [0] * codes.shape[0]
+        # Decode chunk
         wavs = []
         batch_wav = self(codes)
         if seq_token_counts is not None:
@@ -241,14 +248,10 @@ def chunked_decode_streaming(
             # Fallback: assume all batch elements share the same sequence length.
             code_seq_lens = [codes.shape[-1]] * codes.shape[0]
         for idx, code_seq_len in enumerate(code_seq_lens):
-            # TODO: need to optimize algorithms, current only support
-            # chunk_size = left_context_size = 25
-            if code_seq_len <= chunk_size:
-                context_size = 0
-            else:
-                context_size = left_context_size
-            # Remove context from output (context_size * total_upsample samples)
-            wav_chunk = batch_wav[idx, :, context_size * self.total_upsample : code_seq_len * self.total_upsample]
+            # Remove context from output (left_context_size * total_upsample samples)
+            wav_chunk = batch_wav[
+                idx, :, left_context_size[idx] * self.total_upsample : code_seq_len * self.total_upsample
+            ]
             wavs.append(wav_chunk)
 
         return wavs
diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py
index a87027de18..851455bda6 100644
--- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py
+++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py
@@ -33,6 +33,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.have_multimodal_outputs = True
         self.has_preprocess = False
         self.has_postprocess = False
+        self.enable_update_additional_information = True
 
         # Generation-only stage (no logits / sampling).
         self.requires_raw_input_tokens = True
@@ -146,6 +147,7 @@ def forward(
         positions: torch.Tensor | None = None,
         intermediate_tensors: Any = None,
         inputs_embeds: torch.Tensor | None = None,
+        runtime_additional_information: list[dict[str, Any]] | None = None,
         **kwargs: Any,
     ) -> OmniOutput:
         """Decode codec codes into audio waveform.
@@ -177,18 +179,28 @@ def forward(
 
         ids = input_ids.reshape(-1).to(dtype=torch.long)
         request_ids_list = self._split_request_ids(ids, kwargs.get("seq_token_counts"))
 
-        # Parse each request: extract ctx_frames, validate, reshape codes.
-        # input_ids layout per request: [codec_context_frames, *flat_codes]
+        # Parse each request: extract left_context_size, validate, reshape codes.
+        # input_ids layout per request: [*flat_codes]
         # where flat_codes is codebook-major [q*F].
-        parsed = []  # (ctx_frames, actual_frames)
+        parsed = []  # (left_context_size, actual_frames)
         valid_codes = []
         valid_indices = []
+        # Get left_context_size from runtime_additional_information (passed via kwargs)
+        left_context_size = [0] * len(request_ids_list)
+        if runtime_additional_information is not None:
+            for i, info in enumerate(runtime_additional_information):
+                if i >= len(left_context_size):
+                    break
+                if "left_context_size" in info:
+                    left_context_size[i] = info["left_context_size"]
+        else:
+            logger.debug("No additional_information provided to code2wav stage.")
         for i, req_ids in enumerate(request_ids_list):
-            if req_ids.numel() < 2:
+            if req_ids.numel() < 1:
                 parsed.append((0, 0))
                 continue
-            ctx_frames = int(req_ids[0].item())
-            flat = req_ids[1:]
+            ctx_frames = left_context_size[i]
+            flat = req_ids
             n = flat.numel()
             # Warmup / dummy_run: not divisible by num_quantizers.
             if n == 0 or n % q != 0:
diff --git a/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_async_chunk.yaml b/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_async_chunk.yaml
index da996494ea..8d8d00d6bc 100644
--- a/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_async_chunk.yaml
+++ b/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_async_chunk.yaml
@@ -29,6 +29,9 @@ stage_args:
     final_output: true
     final_output_type: text
     is_comprehension: true
+    # Use the named connector so runtime.connectors.extra settings apply.
+    output_connectors:
+      to_stage_1: connector_of_shared_memory
     default_sampling_params:
       temperature: 0.4
       top_p: 0.9
@@ -60,6 +63,9 @@ stage_args:
     engine_input_source: [0]
     # final_output: true
    # final_output_type: text
+    # Distributed connector configuration
+    input_connectors:
+      from_stage_0: connector_of_shared_memory
     default_sampling_params:
       temperature: 0.9
       top_k: 50
@@ -99,3 +105,13 @@ stage_args:
       seed: 42
       detokenize: True
       repetition_penalty: 1.1
+
+runtime:
+
+  connectors:
+    connector_of_shared_memory:
+      name: SharedMemoryConnector
+      extra:
+        # Align with Omni: small chunks with sufficient context overlap.
+        codec_chunk_frames: 25  # code2wav decode chunk size
+        codec_left_context_frames: 25  # code2wav left context size
diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py
index 3a42159a8f..31a5ddd47f 100644
--- a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py
+++ b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py
@@ -220,6 +220,12 @@ def talker2code2wav_async_chunk(
     if "code_predictor_codes" not in pooling_output:
         return None
 
+    connector = getattr(transfer_manager, "connector", None)
+    raw_cfg = getattr(connector, "config", {}) or {}
+    cfg = raw_cfg.get("extra", raw_cfg) if isinstance(raw_cfg, dict) else {}
+    chunk_size_config = int(cfg.get("codec_chunk_frames", 25))
+    left_context_size_config = int(cfg.get("codec_left_context_frames", 25))
+
     code_predictor_codes = pooling_output["code_predictor_codes"]
 
     if code_predictor_codes is None:
@@ -244,23 +250,27 @@ def talker2code2wav_async_chunk(
         return None
 
     request_id = request.external_req_id
-    chunk_size = left_context_size = 25
 
     transfer_manager.code_prompt_token_ids[request_id].append(codec_codes)
     length = len(transfer_manager.code_prompt_token_ids[request_id])
-    chunk_length = length % chunk_size
+    chunk_length = length % chunk_size_config
     if chunk_length != 0 and not is_finished:
         return None
 
-    context_length = chunk_length if chunk_length != 0 else chunk_size
+    context_length = chunk_length if chunk_length != 0 else chunk_size_config
+    # Ensure the left context does not exceed the available length.
+    left_context_size = max(0, min(length - context_length, left_context_size_config))
     end_index = min(length, left_context_size + context_length)
+    codes = (
+        torch.tensor(transfer_manager.code_prompt_token_ids[request_id][-end_index:])
+        .transpose(0, 1)
+        .reshape(-1)
+        .tolist()
+    )
+
     info = {
-        "code_predictor_codes": (
-            torch.tensor(transfer_manager.code_prompt_token_ids[request_id][-end_index:])
-            .transpose(0, 1)
-            .reshape(-1)
-            .tolist()
-        ),
+        "code_predictor_codes": codes,
+        "left_context_size": left_context_size,
         "finished": torch.tensor(is_finished, dtype=torch.bool),
     }
     return info
diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py b/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py
index 72e17bf4f3..8c21052e9e 100644
--- a/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py
+++ b/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py
@@ -40,12 +40,12 @@ def talker2code2wav_async_chunk(
     connector = getattr(transfer_manager, "connector", None)
     raw_cfg = getattr(connector, "config", {}) or {}
     cfg = raw_cfg.get("extra", raw_cfg) if isinstance(raw_cfg, dict) else {}
-    chunk_size = int(cfg.get("codec_chunk_frames", 25))
-    left_context_size = int(cfg.get("codec_left_context_frames", 25))
-    if chunk_size <= 0 or left_context_size < 0:
+    chunk_size_config = int(cfg.get("codec_chunk_frames", 25))
+    left_context_size_config = int(cfg.get("codec_left_context_frames", 25))
+    if chunk_size_config <= 0 or left_context_size_config < 0:
         raise ValueError(
-            f"Invalid codec chunk config: codec_chunk_frames={chunk_size}, "
-            f"codec_left_context_frames={left_context_size}"
+            f"Invalid codec chunk config: codec_chunk_frames={chunk_size_config}, "
+            f"codec_left_context_frames={left_context_size_config}"
         )
 
     length = len(transfer_manager.code_prompt_token_ids[request_id])
@@ -59,22 +59,23 @@
         }
         return None
 
-    chunk_length = length % chunk_size
+    chunk_length = length % chunk_size_config
     if chunk_length != 0 and not finished:
         return None
-    context_length = chunk_length if chunk_length != 0 else chunk_size
-    end_index = min(length, left_context_size + context_length)
-    ctx_frames = max(0, int(end_index - context_length))
+    context_length = chunk_length if chunk_length != 0 else chunk_size_config
+    end_index = min(length, left_context_size_config + context_length)
+    left_context_size = max(0, int(end_index - context_length))
 
     window_frames = transfer_manager.code_prompt_token_ids[request_id][-end_index:]
     # Pack context + chunk into codebook-major flat codes for adapter.
     code_predictor_codes = torch.tensor(window_frames).transpose(0, 1).reshape(-1).tolist()
 
-    # Build final prompt_token_ids with ctx_frames header for Qwen3-TTS Code2Wav.
-    # The model expects input_ids layout: [ctx_frames, *flat_codes].
+    # Build final prompt_token_ids plus a separate left_context_size entry for Qwen3-TTS Code2Wav.
+    # The model expects input_ids layout: [*flat_codes].
     return {
-        "code_predictor_codes": [int(ctx_frames)] + code_predictor_codes,
+        "code_predictor_codes": code_predictor_codes,
+        "left_context_size": left_context_size,
         "finished": torch.tensor(finished, dtype=torch.bool),
     }
diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py
index 9dd512333a..7cb3e57dd1 100644
--- a/vllm_omni/worker/gpu_model_runner.py
+++ b/vllm_omni/worker/gpu_model_runner.py
@@ -1151,12 +1151,18 @@ def _preprocess(
             # Prefill: overlay prompt_embeds and collect additional_information
             self._collect_additional_information_for_prefill(num_scheduled_tokens_np)
 
+        # Keep per-request additional_information in sync for both new and
+        # cached requests. This is required for stages without preprocess
+        # (e.g., code2wav) so runtime_additional_information can be refreshed
+        # from scheduler cached infos on every step.
+        if hasattr(self.model, "has_preprocess") or hasattr(self.model, "enable_update_additional_information"):
+            if self.vllm_config.model_config.async_chunk:
+                self._update_additional_information(scheduler_output)
+
         if hasattr(self.model, "has_preprocess") and self.model.has_preprocess:
             # Overlay custom prompt_embeds per request for the prompt portion;
             # collect additional_information (tensor/list) for prefill portion only
             decode_req_ids = []
-            if self.vllm_config.model_config.async_chunk:
-                self._update_additional_information(scheduler_output)
             for req_index, req_id in enumerate(self.input_batch.req_ids):
                 req_state = self.requests.get(req_id)
                 req_infos = getattr(req_state, "additional_information_cpu", None) if req_state is not None else None
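The chunk/context windowing shared by talker2code2wav_async_chunk and chunked_decode_streaming can be summarized with the arithmetic below. This is a standalone sketch under assumed names and values (plan_chunk_window, trim_decoded_chunk, and the 480-samples-per-frame upsample factor are illustrative, not vllm_omni APIs): each emitted window carries up to left_context extra frames in front of the new chunk, and the decoder then drops exactly left_context_size * total_upsample samples from the front of the decoded waveform.

def plan_chunk_window(length: int, chunk_size: int = 25, left_context: int = 25) -> tuple[int, int]:
    # length: codec frames accumulated so far for one request.
    # Returns (left_context_size, window_len) for the window sent downstream.
    chunk_length = length % chunk_size
    # A window is emitted on a full chunk boundary (the finished case is not modeled here).
    context_length = chunk_length if chunk_length != 0 else chunk_size
    # Left context never exceeds what has actually been produced.
    left_context_size = max(0, min(length - context_length, left_context))
    window_len = min(length, left_context_size + context_length)
    return left_context_size, window_len


def trim_decoded_chunk(total_upsample: int, left_context_size: int, window_len: int) -> tuple[int, int]:
    # Sample range kept from the decoded window: the context frames are dropped.
    return left_context_size * total_upsample, window_len * total_upsample


# Example: 75 frames accumulated, chunk boundary reached.
ctx, window = plan_chunk_window(75)          # ctx=25, window=50
print(trim_decoded_chunk(480, ctx, window))  # (12000, 24000): only the newest 25 frames of audio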