-
Notifications
You must be signed in to change notification settings - Fork 473
Make chunk_size and left_context_size configurable via YAML for async chunking #1423
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 13 commits
ede288c
94766e1
b5596a5
bc21ac1
e330708
9fe559c
d2a1c28
3c73e35
3d23a72
ecd6e53
9303949
395c286
2be2eff
e11bc9a
38d74b6
c6bd3b0
af3d28d
e654c58
74ed57c
9ce0554
ea607bd
4eb97f0
602975d
f5889a0
4985891
c8b0e76
54f31ae
7e77fed
2af512e
fb04697
b1b441c
3692b96
1359e59
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -229,6 +229,12 @@ class AsyncOmniEngineArgs(AsyncEngineArgs): | |
| custom_process_next_stage_input_func: str | None = None | ||
| stage_connector_spec: dict[str, Any] = field(default_factory=dict) | ||
| async_chunk: bool = False | ||
| async_chunk_config: dict[str, Any] = field( | ||
| default_factory=lambda: { | ||
|
||
| "chunk_size": 25, | ||
| "left_context_size": 25, | ||
| } | ||
|
||
| ) | ||
| omni_kv_config: dict | None = None | ||
| worker_type: str | None = None | ||
|
|
||
|
|
@@ -342,6 +348,7 @@ def create_model_config(self) -> OmniModelConfig: | |
| logits_processors=self.logits_processors, | ||
| video_pruning_rate=video_pruning_rate, | ||
| io_processor_plugin=self.io_processor_plugin, | ||
| async_chunk_config=self.async_chunk_config, | ||
| # Omni-specific fields | ||
| stage_id=self.stage_id, | ||
| async_chunk=self.async_chunk, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -386,7 +386,14 @@ def forward( | |
| codes = input_ids_flatten.reshape(1, 16, -1) | ||
|
|
||
| # Generate audio from codec codes | ||
| audio_tensors = self.generate_audio(codes, voice_type) | ||
| # Get left_context_size from runtime_additional_information (passed via kwargs) | ||
| # or additional_information parameter | ||
| left_context_size = None | ||
| if additional_information is not None: | ||
| left_context_size = additional_information.get("left_context_size") | ||
| else: | ||
| logger.debug("No additional_information provided to code2wav stage.") | ||
| audio_tensors = self.generate_audio(codes, voice_type, left_context_size=left_context_size) | ||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This warning fires for every non-async-chunk call (or whenever
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Sure, I will use logger.debug instead.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks! |
||
| return audio_tensors | ||
|
|
||
|
|
@@ -458,13 +465,19 @@ def make_omni_output(self, model_outputs: torch.Tensor | OmniOutput, **kwargs) - | |
|
|
||
| # ==================== Audio Generation ==================== | ||
|
|
||
| def generate_audio(self, code: torch.Tensor, voice_type: str) -> list[torch.Tensor]: | ||
| def generate_audio( | ||
| self, | ||
| code: torch.Tensor, | ||
| voice_type: str, | ||
| left_context_size: int | None = None, | ||
| ) -> list[torch.Tensor]: | ||
| """ | ||
| Generate audio waveform from codec codes. | ||
|
|
||
| Args: | ||
| code: [8, T] - 8-layer RVQ codec codes | ||
| voice_type: Voice type (not used in Qwen3, kept for compatibility) | ||
| left_context_size: Context size for streaming decode (from async_chunk_config) | ||
|
|
||
| Returns: | ||
| audio_tensor: [1, waveform_len] - Audio waveform | ||
|
|
@@ -487,10 +500,10 @@ def generate_audio(self, code: torch.Tensor, voice_type: str) -> list[torch.Tens | |
| talker_codes = talker_codes.expand(1, 16, -1) | ||
|
|
||
| if self.vllm_config.model_config.async_chunk: | ||
| # Only use left_context_size from additional information | ||
| audio_tensors = self.code2wav.chunked_decode_streaming( | ||
| talker_codes, | ||
| chunk_size=25, | ||
| left_context_size=25, | ||
| left_context_size=left_context_size, | ||
| ) | ||
| else: | ||
| # Use chunked decode for memory efficiency | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -214,24 +214,28 @@ def chunked_decode( | |
| def chunked_decode_streaming( | ||
| self, | ||
| codes: torch.Tensor, | ||
| chunk_size: int = 25, | ||
| left_context_size: int = 25, | ||
| left_context_size: int, | ||
| ) -> list[torch.Tensor]: | ||
| """ | ||
| Decode long sequences in chunks to avoid OOM. | ||
|
|
||
| Uses overlapping chunks with left context to avoid boundary artifacts. | ||
|
|
||
| No longer need chunk size here, which is different from chunked_decode | ||
|
|
||
| Args: | ||
| codes: [batch, num_quantizers, seq_len] - num_quantizers-layer RVQ codes | ||
| chunk_size: Number of codec frames per chunk | ||
| left_context_size: Number of overlapping frames for context | ||
|
|
||
| Returns: | ||
| list[torch.Tensor]: Complete waveform decoded from the input | ||
| codes. For ``batch_size == 1``, this is a list containing a | ||
| single tensor with shape ``[1, waveform_len]``. | ||
| """ | ||
| if left_context_size is None: | ||
| logger.warning( | ||
| "left_context_size is None in chunked_decode_streaming; this may cause incorrect output shape." | ||
| ) | ||
| # Decode chunk | ||
| wavs = [] | ||
| batch_wav = self(codes) | ||
|
|
@@ -243,14 +247,8 @@ def chunked_decode_streaming( | |
| # Create one entry per batch so that each element is processed. | ||
| code_seq_lens = [codes.shape[-1]] * codes.shape[0] | ||
| for idx, code_seq_len in enumerate(code_seq_lens): | ||
| # TODO: need to optimize algorithms, current only support | ||
| # chunk_size = left_context_size = 25 | ||
| if code_seq_len <= chunk_size: | ||
| context_size = 0 | ||
| else: | ||
| context_size = left_context_size | ||
| # Remove context from output (context_size * total_upsample samples) | ||
| wav_chunk = batch_wav[idx, :, context_size * self.total_upsample : code_seq_len * self.total_upsample] | ||
| # Remove context from output (left_context_size * total_upsample samples) | ||
| wav_chunk = batch_wav[idx, :, left_context_size * self.total_upsample : code_seq_len * self.total_upsample] | ||
|
||
| wavs.append(wav_chunk) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. OK, I will remove chunk_size entirely. |
||
| return wavs | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -57,6 +57,9 @@ stage_args: | |
| distributed_executor_backend: "mp" | ||
| hf_config_name: talker_config | ||
| custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav_async_chunk | ||
| async_chunk_config: | ||
|
||
| chunk_size: 25 # code2wav decode chunk size | ||
| left_context_size: 25 # code2wav left context size | ||
| engine_input_source: [0] | ||
| # final_output: true | ||
| # final_output_type: text | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,6 +9,7 @@ | |
| from vllm.inputs import TextPrompt | ||
| from vllm.platforms import current_platform | ||
|
|
||
| from vllm_omni.distributed.omni_connectors.transfer_adapter.chunk_transfer_adapter import OmniChunkTransferAdapter | ||
| from vllm_omni.engine import OmniEngineCoreRequest | ||
| from vllm_omni.inputs.data import OmniTokensPrompt | ||
|
|
||
|
|
@@ -87,6 +88,7 @@ def thinker2talker_async_chunk( | |
| transfer_manager: Any, | ||
| pooling_output: dict[str, Any], | ||
| request: OmniEngineCoreRequest, | ||
| **kwargs, | ||
| ) -> list[dict[str, Any]]: | ||
| """ | ||
| Process thinker outputs to create talker inputs. | ||
|
|
@@ -206,16 +208,21 @@ def thinker2talker( | |
|
|
||
|
|
||
| def talker2code2wav_async_chunk( | ||
| transfer_manager: Any, | ||
| transfer_manager: OmniChunkTransferAdapter, | ||
| pooling_output: dict[str, Any], | ||
| request: OmniEngineCoreRequest, | ||
| **kwargs, | ||
|
||
| ): | ||
| """ | ||
| Pooling version. | ||
| """ | ||
| if "code_predictor_codes" not in pooling_output: | ||
| return None | ||
|
|
||
| async_chunk_config = kwargs.get("async_chunk_config", {}) | ||
| chunk_size_config = async_chunk_config.get("chunk_size", 25) | ||
| left_context_size_config = async_chunk_config.get("left_context_size", 25) | ||
|
|
||
| code_predictor_codes = pooling_output["code_predictor_codes"] | ||
|
|
||
| if code_predictor_codes is None: | ||
|
|
@@ -240,23 +247,27 @@ def talker2code2wav_async_chunk( | |
| return None | ||
|
|
||
| request_id = request.external_req_id | ||
| chunk_size = left_context_size = 25 | ||
| transfer_manager.code_prompt_token_ids[request_id].append(codec_codes) | ||
| length = len(transfer_manager.code_prompt_token_ids[request_id]) | ||
| chunk_length = length % chunk_size | ||
| chunk_length = length % chunk_size_config | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. What happens when
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. When the first chunk gets length == context_length, end_index = min(context_length, left_context_size + context_length) = min(4, 0 + 4), i.e. context_length = 4. When the last chunk is smaller than chunk_size_config, left_context_size = min(length - chunk_length, left_context_size_config) = min(37 - 1, 25) = 25, as shown in the table in the PR. I think it is correct now.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Ah, got it — that tracks. Thanks for walking through it. |
||
| if chunk_length != 0 and not request.is_finished(): | ||
| return None | ||
|
|
||
| context_length = chunk_length if chunk_length != 0 else chunk_size | ||
| context_length = chunk_length if chunk_length != 0 else chunk_size_config | ||
| # ensure left context does not exceed available length | ||
| left_context_size = min(length - context_length, left_context_size_config) | ||
| end_index = min(length, left_context_size + context_length) | ||
|
|
||
| codes = ( | ||
| torch.tensor(transfer_manager.code_prompt_token_ids[request_id][-end_index:]) | ||
| .transpose(0, 1) | ||
| .reshape(-1) | ||
| .tolist() | ||
| ) | ||
|
|
||
| info = { | ||
| "code_predictor_codes": ( | ||
| torch.tensor(transfer_manager.code_prompt_token_ids[request_id][-end_index:]) | ||
| .transpose(0, 1) | ||
| .reshape(-1) | ||
| .tolist() | ||
| ), | ||
| "code_predictor_codes": codes, | ||
| "left_context_size": left_context_size, | ||
| "finished": torch.tensor(request.is_finished(), dtype=torch.bool), | ||
| } | ||
| return info | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we remove
async_chunk field here? It seems that it's not consistent with model.py. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I will keep async_chunk as before and take it out of async_chunk_config.