Make chunk_size and left_context_size configurable via YAML for async chunking #1423
```diff
@@ -213,18 +213,18 @@ def chunked_decode(
     def chunked_decode_streaming(
         self,
         codes: torch.Tensor,
-        chunk_size: int = 25,
-        left_context_size: int = 25,
+        left_context_size: list[int] | None = None,
         seq_token_counts: list[int] | None = None,
     ) -> list[torch.Tensor]:
         """
         Decode long sequences in chunks to avoid OOM.

         Uses overlapping chunks with left context to avoid boundary artifacts.

+        No longer needs a chunk size here, unlike chunked_decode.
+
         Args:
             codes: [batch, num_quantizers, seq_len] - num_quantizers-layer RVQ codes
-            chunk_size: Number of codec frames per chunk
             left_context_size: Number of overlapping frames for context
             seq_token_counts: Token count for each request in batch
@@ -233,6 +233,13 @@ def chunked_decode_streaming(
             codes. For ``batch_size == 1``, this is a list containing a
             single tensor with shape ``[1, waveform_len]``.
         """
+        if not (left_context_size and seq_token_counts and len(left_context_size) == len(seq_token_counts)):
+            logger.warning(
+                "chunked_decode_streaming: missing/invalid left_context_size or seq_token_counts; "
+                "defaulting to left_context_size=zeros(len=codes.shape[0])."
+            )
+            left_context_size = [0] * codes.shape[0]
         # Decode chunk
         wavs = []
         batch_wav = self(codes)
         if seq_token_counts is not None:
@@ -241,14 +248,10 @@ def chunked_decode_streaming(
         else:
             # Fallback: assume all batch elements share the same sequence length.
             code_seq_lens = [codes.shape[-1]] * codes.shape[0]
         for idx, code_seq_len in enumerate(code_seq_lens):
-            # TODO: need to optimize algorithms, current only support
-            # chunk_size = left_context_size = 25
-            if code_seq_len <= chunk_size:
-                context_size = 0
-            else:
-                context_size = left_context_size
-            # Remove context from output (context_size * total_upsample samples)
-            wav_chunk = batch_wav[idx, :, context_size * self.total_upsample : code_seq_len * self.total_upsample]
+            # Remove context from output (left_context_size * total_upsample samples)
+            wav_chunk = batch_wav[
+                idx, :, left_context_size[idx] * self.total_upsample : code_seq_len * self.total_upsample
+            ]
             wavs.append(wav_chunk)
         return wavs
```

Contributor (Author): OK, I will remove `chunk_size` entirely.
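The per-request trimming in `chunked_decode_streaming` is simple slicing arithmetic. A minimal sketch, without torch: plain lists stand in for the decoded waveform tensors, and the frame counts, per-request left contexts, and upsample factor (480 samples per codec frame) are made-up illustration values.

```python
total_upsample = 480            # assumed samples per codec frame
code_seq_lens = [30, 50]        # codec frames per request (illustrative)
left_context_size = [0, 25]     # overlap frames to drop per request

# Stand-in for the decoded batch waveform: one float per output sample.
batch_wav = [[0.0] * (n * total_upsample) for n in code_seq_lens]

wavs = []
for idx, code_seq_len in enumerate(code_seq_lens):
    # Drop the left-context samples; keep the rest of this request's audio.
    start = left_context_size[idx] * total_upsample
    end = code_seq_len * total_upsample
    wavs.append(batch_wav[idx][start:end])

print([len(w) for w in wavs])  # → [14400, 12000]
```

Request 0 keeps all 30 frames (no context to trim); request 1 drops its 25 overlap frames and keeps the remaining 25.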
```diff
@@ -29,6 +29,9 @@ stage_args:
       final_output: true
       final_output_type: text
       is_comprehension: true
+      # Use named connector to apply runtime.connectors.extra.
+      output_connectors:
+        to_stage_1: connector_of_shared_memory
       default_sampling_params:
         temperature: 0.4
         top_p: 0.9
@@ -60,6 +63,9 @@ stage_args:
       engine_input_source: [0]
       # final_output: true
       # final_output_type: text
+      # Distributed connector configuration
+      input_connectors:
+        from_stage_0: connector_of_shared_memory
       default_sampling_params:
         temperature: 0.9
         top_k: 50
@@ -99,3 +105,13 @@ stage_args:
       seed: 42
       detokenize: True
       repetition_penalty: 1.1
+
+runtime:
+
+  connectors:
+    connector_of_shared_memory:
+      name: SharedMemoryConnector
+      extra:
+        # Align with Omni: small chunks with sufficient context overlap.
+        codec_chunk_frames: 25  # code2wav decode chunk size
+        codec_left_context_frames: 25  # code2wav left context size
```

Collaborator: Do we have to make changes to the YAML? @LJH-LBJ @amy-why-3459 @linyueqian

Collaborator: From the Qwen3-TTS side: existing YAMLs won't break, since the stage input processor falls back to `codec_chunk_frames=25` and `codec_left_context_frames=25` by default. That said, the YAML changes are good to have. They wire up named connectors so these values are explicit and configurable rather than relying on hardcoded defaults, which is the whole point of this PR.
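The backward-compatibility point can be sketched in Python. The helper name `read_codec_config` and the dict literals below are hypothetical, standing in for the connector's parsed YAML config; the fallback logic mirrors the `cfg.get(..., 25)` pattern used in the PR.

```python
# Hypothetical parsed YAML (a plain dict stands in for the loaded file).
raw_cfg = {
    "name": "SharedMemoryConnector",
    "extra": {
        "codec_chunk_frames": 25,
        "codec_left_context_frames": 25,
    },
}

def read_codec_config(raw_cfg):
    # Prefer the nested "extra" section; fall back to the top level,
    # then to the hardcoded defaults (25/25) so legacy YAMLs keep working.
    cfg = raw_cfg.get("extra", raw_cfg) if isinstance(raw_cfg, dict) else {}
    chunk = int(cfg.get("codec_chunk_frames", 25))
    left = int(cfg.get("codec_left_context_frames", 25))
    return chunk, left

print(read_codec_config(raw_cfg))  # → (25, 25)
print(read_codec_config({}))       # legacy config with no connector extra → (25, 25)
print(read_codec_config({"extra": {"codec_chunk_frames": 50}}))  # → (50, 25)
```

An empty or missing config degrades to the old hardcoded 25/25 behavior, which is why existing YAMLs are unaffected.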
```diff
@@ -220,6 +220,12 @@ def talker2code2wav_async_chunk(
     if "code_predictor_codes" not in pooling_output:
         return None

+    connector = getattr(transfer_manager, "connector", None)
+    raw_cfg = getattr(connector, "config", {}) or {}
+    cfg = raw_cfg.get("extra", raw_cfg) if isinstance(raw_cfg, dict) else {}
+    chunk_size_config = int(cfg.get("codec_chunk_frames", 25))
+    left_context_size_config = int(cfg.get("codec_left_context_frames", 25))
+
     code_predictor_codes = pooling_output["code_predictor_codes"]

     if code_predictor_codes is None:
@@ -244,23 +250,27 @@ def talker2code2wav_async_chunk(
         return None

     request_id = request.external_req_id
-    chunk_size = left_context_size = 25
     transfer_manager.code_prompt_token_ids[request_id].append(codec_codes)
     length = len(transfer_manager.code_prompt_token_ids[request_id])
-    chunk_length = length % chunk_size
+    chunk_length = length % chunk_size_config

     if chunk_length != 0 and not is_finished:
         return None

-    context_length = chunk_length if chunk_length != 0 else chunk_size
+    context_length = chunk_length if chunk_length != 0 else chunk_size_config
+    # ensure left context does not exceed available length
+    left_context_size = max(0, min(length - context_length, left_context_size_config))
     end_index = min(length, left_context_size + context_length)

+    codes = (
+        torch.tensor(transfer_manager.code_prompt_token_ids[request_id][-end_index:])
+        .transpose(0, 1)
+        .reshape(-1)
+        .tolist()
+    )
+
     info = {
-        "code_predictor_codes": (
-            torch.tensor(transfer_manager.code_prompt_token_ids[request_id][-end_index:])
-            .transpose(0, 1)
-            .reshape(-1)
-            .tolist()
-        ),
+        "code_predictor_codes": codes,
         "left_context_size": left_context_size,
         "finished": torch.tensor(is_finished, dtype=torch.bool),
     }
     return info
```

Contributor: What happens when …

Contributor (Author): When the first chunk gets `length == context_length`, `end_index = min(context_length, left_context_size + context_length) = min(4, 0 + 4)`, i.e. `context_length = 4`. When the last chunk is smaller than `chunk_size_config`, `left_context_size = min(length - chunk_length, left_context_size_config) = min(37 - 1, 25) = 25` in the table in the PR. I think it is correct now.

Contributor: Ah got it, that tracks. Thanks for walking through it.
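The boundary cases discussed in this thread can be checked with a small standalone sketch. `chunk_bounds` is a made-up helper that mirrors the chunking arithmetic from the diff, with the config fixed at the default 25/25; the lengths passed in are illustrative.

```python
CHUNK = 25      # assumed codec_chunk_frames
LEFT_CTX = 25   # assumed codec_left_context_frames

def chunk_bounds(length, is_finished):
    """Return (emit, left_context_size, end_index) for an accumulated frame count."""
    chunk_length = length % CHUNK
    if chunk_length != 0 and not is_finished:
        return (False, None, None)  # wait until a full chunk has accumulated
    # Partial final chunk uses whatever remains; otherwise a full chunk.
    context_length = chunk_length if chunk_length != 0 else CHUNK
    # Clamp left context so it never exceeds the frames actually available.
    left = max(0, min(length - context_length, LEFT_CTX))
    end_index = min(length, left + context_length)
    return (True, left, end_index)

print(chunk_bounds(4, True))    # short finished first chunk → (True, 0, 4)
print(chunk_bounds(25, False))  # exactly one full chunk → (True, 0, 25)
print(chunk_bounds(50, False))  # second chunk, full overlap → (True, 25, 50)
print(chunk_bounds(37, True))   # final partial chunk → (True, 25, 37)
```

The clamp is what makes the first-chunk case safe: with only 4 frames accumulated there is no history to overlap, so the left context collapses to 0.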
```diff
@@ -1151,12 +1151,18 @@ def _preprocess(
         # Prefill: overlay prompt_embeds and collect additional_information
         self._collect_additional_information_for_prefill(num_scheduled_tokens_np)

+        # Keep per-request additional_information in sync for both new and
+        # cached requests. This is required for stages without preprocess
+        # (e.g., code2wav) so runtime_additional_information can be refreshed
+        # from scheduler cached infos on every step.
+        if hasattr(self.model, "has_preprocess") or hasattr(self.model, "enable_update_additional_information"):
+            if self.vllm_config.model_config.async_chunk:
+                self._update_additional_information(scheduler_output)
+
         if hasattr(self.model, "has_preprocess") and self.model.has_preprocess:
             # Overlay custom prompt_embeds per request for the prompt portion;
             # collect additional_information (tensor/list) for prefill portion only
             decode_req_ids = []
-            if self.vllm_config.model_config.async_chunk:
-                self._update_additional_information(scheduler_output)
             for req_index, req_id in enumerate(self.input_batch.req_ids):
                 req_state = self.requests.get(req_id)
                 req_infos = getattr(req_state, "additional_information_cpu", None) if req_state is not None else None
```

Contributor: Why check …

Contributor (Author): In the original logic, `update_additional_information` can be called when `has_preprocess=True`. I don't want to change the original logic here; I just want the code2wav stage to also be able to update `additional_information`.

Contributor: Fair enough, makes sense.
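The opt-in gating discussed in this thread can be illustrated with toy classes. All names below (`TalkerModel`, `Code2WavModel`, `should_update`) are invented for illustration and only mirror the `hasattr` check from the diff: either attribute opts a stage in, so code2wav can participate without declaring a preprocess.

```python
class TalkerModel:
    has_preprocess = True  # original path: stage with a preprocess step

class Code2WavModel:
    enable_update_additional_information = True  # opts in without preprocess

class PlainModel:
    pass  # neither attribute: never updates additional_information

def should_update(model, async_chunk=True):
    # Mirrors the check in _preprocess: either attribute opts the stage in,
    # but the update only runs when async chunking is enabled.
    gated = hasattr(model, "has_preprocess") or hasattr(
        model, "enable_update_additional_information"
    )
    return gated and async_chunk

print(should_update(TalkerModel()))    # → True
print(should_update(Code2WavModel()))  # → True
print(should_update(PlainModel()))     # → False
```

Note that `hasattr` is True even when the attribute is set to a falsy value; the diff relies on this only as an opt-in marker, with the separate `has_preprocess` truthiness check preserved for the original path.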
Contributor: This warning fires for every non-async-chunk call (or whenever `additional_information` is None), which could be very noisy in production. Is this intentional for debugging only, or should it be `logger.debug` instead?

Contributor (Author): Sure, I will use `logger.debug` instead.

Contributor: Thanks!