Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/design/feature/async_chunk_design.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ The `async_chunk` feature enables asynchronous, chunked processing of data acros

For qwen3-omni:
- **Thinker → Talker**: Per decode step (typically chunk_size=1)
- **Talker → Code2Wav**: Accumulated to code2wav chunk_size(default=25, currently only support default, will support chunk_size soon) before sending
- **Talker → Code2Wav**: Accumulated to `codec_chunk_frames` (default=25) before sending. Set `initial_codec_chunk_frames` to emit smaller chunks during warmup for reduced TTFA (time to first audio).
- **Code2Wav**: Streaming decode with code2wav chunk_size

With `async_chunk`:
Expand Down
3 changes: 2 additions & 1 deletion docs/user_guide/examples/offline_inference/qwen3_tts.md
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ Add `--streaming` to stream audio chunks progressively via `AsyncOmni` (requires
python end2end.py --query-type CustomVoice --streaming --output-dir /tmp/out_stream
```

Each 25-frame Code2Wav chunk is logged as it arrives. The final WAV file is written once generation
Each Code2Wav chunk is logged as it arrives (default 25 frames; configurable via `codec_chunk_frames`
and `initial_codec_chunk_frames` in the stage config). The final WAV file is written once generation
completes. This demonstrates that audio data is available progressively rather than only at the end.

> **Note:** Streaming uses `AsyncOmni` internally. The non-streaming path (`Omni`) is unchanged.
Expand Down
2 changes: 1 addition & 1 deletion docs/user_guide/examples/online_serving/qwen3_tts.md
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ Returns binary audio data with appropriate `Content-Type` header (e.g., `audio/w
## Streaming

Set `stream=true` with `response_format="pcm"` to receive raw PCM audio chunks as they are decoded
(one chunk per 25-frame Code2Wav window):
(one chunk per Code2Wav window, default 25 frames; configurable in the stage config):

```bash
curl -X POST http://localhost:8091/v1/audio/speech \
Expand Down
3 changes: 2 additions & 1 deletion examples/offline_inference/qwen3_tts/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@ Add `--streaming` to stream audio chunks progressively via `AsyncOmni` (requires
python end2end.py --query-type CustomVoice --streaming --output-dir /tmp/out_stream
```

Each 25-frame Code2Wav chunk is logged as it arrives. The final WAV file is written once generation
Each Code2Wav chunk is logged as it arrives (default 25 frames; configurable via `codec_chunk_frames`
and `initial_codec_chunk_frames` in the stage config). The final WAV file is written once generation
completes. This demonstrates that audio data is available progressively rather than only at the end.

> **Note:** Streaming uses `AsyncOmni` internally. The non-streaming path (`Omni`) is unchanged.
Expand Down
17 changes: 16 additions & 1 deletion examples/offline_inference/qwen3_tts/end2end.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import asyncio
import logging
import os
import time
from typing import Any, NamedTuple

import soundfile as sf
Expand Down Expand Up @@ -337,13 +338,27 @@ async def main_streaming(args):

for i, prompt in enumerate(inputs):
request_id = str(i)
t_start = time.perf_counter()
t_prev = t_start
chunk_idx = 0
async for stage_output in omni.generate(prompt, request_id=request_id):
mm = stage_output.request_output.outputs[0].multimodal_output
if not stage_output.finished:
t_now = time.perf_counter()
audio = mm.get("audio")
n = len(audio) if isinstance(audio, list) else (0 if audio is None else 1)
logger.info(f"Request {request_id}: received chunk {n}")
dt_ms = (t_now - t_prev) * 1000
ttfa_ms = (t_now - t_start) * 1000
if chunk_idx == 0:
logger.info(f"Request {request_id}: chunk {chunk_idx} samples={n} TTFA={ttfa_ms:.1f}ms")
else:
logger.info(f"Request {request_id}: chunk {chunk_idx} samples={n} inter_chunk={dt_ms:.1f}ms")
t_prev = t_now
chunk_idx += 1
else:
t_end = time.perf_counter()
total_ms = (t_end - t_start) * 1000
logger.info(f"Request {request_id}: done total={total_ms:.1f}ms chunks={chunk_idx}")
_save_wav(output_dir, request_id, mm)


Expand Down
2 changes: 1 addition & 1 deletion examples/online_serving/qwen3_tts/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ Returns binary audio data with appropriate `Content-Type` header (e.g., `audio/w
## Streaming

Set `stream=true` with `response_format="pcm"` to receive raw PCM audio chunks as they are decoded
(one chunk per 25-frame Code2Wav window):
(one chunk per Code2Wav window, default 25 frames; configurable in the stage config):

```bash
curl -X POST http://localhost:8091/v1/audio/speech \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,64 +4,83 @@
from collections import defaultdict
from types import SimpleNamespace

import pytest
import torch

from vllm_omni.model_executor.stage_input_processors.qwen3_tts import talker2code2wav_async_chunk

_FRAME = [1, 2, 3, 4] # 4-codebook frame
_Q = len(_FRAME) # num quantizers

def _req(external_req_id: str, *, finished: bool):

def _req(rid: str, *, finished: bool):
return SimpleNamespace(external_req_id=rid, is_finished=lambda: finished)


def _tm(*, chunk_frames=25, left_context=25, initial_chunk=0):
return SimpleNamespace(
external_req_id=external_req_id,
is_finished=lambda: finished,
code_prompt_token_ids=defaultdict(list),
put_req_chunk=defaultdict(int),
connector=SimpleNamespace(
config={
"extra": {
"codec_chunk_frames": chunk_frames,
"codec_left_context_frames": left_context,
"initial_codec_chunk_frames": initial_chunk,
}
}
),
)


def test_talker2code2wav_async_chunk_does_not_emit_empty_chunk_when_not_finished():
transfer_manager = SimpleNamespace(
code_prompt_token_ids=defaultdict(list),
connector=SimpleNamespace(config={"extra": {"codec_chunk_frames": 25, "codec_left_context_frames": 25}}),
def _call(tm, rid, *, n_frames, put_req=0, finished=False):
    """Feed n_frames into transfer_manager and call the gate function."""
    frames = [list(_FRAME) for _ in range(n_frames)]
    tm.code_prompt_token_ids[rid] = frames
    tm.put_req_chunk[rid] = put_req
    # Empty audio_codes tensor: the gate reads frames from the transfer
    # manager, not from the pooling output, in these tests.
    return talker2code2wav_async_chunk(
        transfer_manager=tm,
        pooling_output={"audio_codes": torch.zeros((0,))},
        request=_req(rid, finished=finished),
        is_finished=finished,
    )


def test_does_not_emit_empty_chunk_when_not_finished():
tm = _tm()
request = _req("rid-empty", finished=False)

payload = talker2code2wav_async_chunk(
transfer_manager=transfer_manager,
transfer_manager=tm,
pooling_output={"audio_codes": torch.zeros((0,))},
request=request,
)

assert payload is None


def test_talker2code2wav_async_chunk_flushes_tail_when_finished_without_pooler_output():
transfer_manager = SimpleNamespace(
code_prompt_token_ids=defaultdict(list),
connector=SimpleNamespace(config={"extra": {"codec_chunk_frames": 25, "codec_left_context_frames": 25}}),
)
request_id = "rid-tail"
transfer_manager.code_prompt_token_ids[request_id] = [[1, 2, 3, 4] for _ in range(24)]
request = _req(request_id, finished=True)
def test_flushes_tail_when_finished_without_pooler_output():
tm = _tm()
rid = "rid-tail"
tm.code_prompt_token_ids[rid] = [_FRAME[:] for _ in range(24)]
request = _req(rid, finished=True)

payload = talker2code2wav_async_chunk(
transfer_manager=transfer_manager,
pooling_output=None, # e.g. EOS step with no audio_codes
transfer_manager=tm,
pooling_output=None,
request=request,
)

assert payload is not None
assert payload["finished"].item() is True
# ctx_frames header + flat codes
assert len(payload["code_predictor_codes"]) == 1 + 4 * 24
assert len(payload["code_predictor_codes"]) == 1 + _Q * 24


def test_talker2code2wav_async_chunk_emits_eof_marker_when_finished_with_no_frames():
transfer_manager = SimpleNamespace(
code_prompt_token_ids=defaultdict(list),
connector=SimpleNamespace(config={"extra": {"codec_chunk_frames": 25, "codec_left_context_frames": 25}}),
)
def test_emits_eof_marker_when_finished_with_no_frames():
tm = _tm()
request = _req("rid-eof", finished=True)

payload = talker2code2wav_async_chunk(
transfer_manager=transfer_manager,
transfer_manager=tm,
pooling_output=None,
request=request,
)
Expand All @@ -70,3 +89,45 @@ def test_talker2code2wav_async_chunk_emits_eof_marker_when_finished_with_no_fram
"code_predictor_codes": [],
"finished": torch.tensor(True, dtype=torch.bool),
}


# Each case is ((chunk_frames, left_context, initial_chunk),
#               (n_frames, put_req, finished),
#               expected)
# where expected is None (no chunk emitted) or (ctx_frames, window_frames):
# the context-frame header value and the number of frames in the emitted
# window, checked against payload["code_predictor_codes"].
_CASES = [
    # Normal path (initial=0): emit at chunk_size boundaries
    ((25, 25, 0), (24, 0, False), None),
    ((25, 25, 0), (25, 0, False), (0, 25)),
    # Warmup: hold, first emit, second emit
    ((25, 25, 10), (9, 0, False), None),
    ((25, 25, 10), (10, 0, False), (0, 10)),
    ((25, 25, 10), (20, 1, False), (10, 20)),
    # Non-divisible: holds at chunk boundary
    ((25, 25, 12), (25, 2, False), None),
    # Normal phase: offset by warmup coverage (chunk//initial * initial)
    ((25, 25, 10), (45, 2, False), (20, 45)),
    # Second normal emit (put_req includes normal emissions, offset must stay stable)
    ((25, 25, 10), (70, 3, False), (25, 50)),
    # initial >= chunk clamps to chunk_size (behaves as normal)
    ((25, 25, 30), (25, 0, False), (0, 25)),
    # finished=True flushes warmup tail
    ((25, 25, 10), (5, 0, True), (0, 5)),
    # finished=True flushes non-divisible warmup residual
    ((25, 25, 12), (25, 2, True), (24, 25)),
    # finished=True flushes normal phase tail
    ((25, 25, 10), (30, 2, True), (20, 30)),
]


@pytest.mark.parametrize("config, state, expected", _CASES)
def test_streaming_decoding_with_variable_initial(config, state, expected):
    """Table-driven check of chunk emission across warmup and normal phases."""
    chunk, ctx, initial = config
    frames, sent_chunks, done = state

    manager = _tm(chunk_frames=chunk, left_context=ctx, initial_chunk=initial)
    payload = _call(manager, "r", n_frames=frames, put_req=sent_chunks, finished=done)

    if expected is None:
        assert payload is None
        return

    want_ctx, want_window = expected
    assert payload is not None
    # First element is the context-frame header; the rest is the flat window.
    assert payload["code_predictor_codes"][0] == want_ctx
    assert len(payload["code_predictor_codes"]) == 1 + _Q * want_window
4 changes: 4 additions & 0 deletions vllm_omni/model_executor/stage_configs/qwen3_tts.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@ runtime:
# Align with Omni: small chunks with sufficient context overlap.
codec_chunk_frames: 25
codec_left_context_frames: 25
# First chunk size for reduced TTFA (0 = disabled).
# When > 0, emits small chunks every N frames during warmup,
# then switches to codec_chunk_frames cadence.
initial_codec_chunk_frames: 0

edges:
- from: 0
Expand Down
4 changes: 4 additions & 0 deletions vllm_omni/model_executor/stage_configs/qwen3_tts_batch.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ runtime:
# Align with Omni: small chunks with sufficient context overlap.
codec_chunk_frames: 25
codec_left_context_frames: 25
# First chunk size for reduced TTFA (0 = disabled).
# When > 0, emits small chunks every N frames during warmup,
# then switches to codec_chunk_frames cadence.
initial_codec_chunk_frames: 0

edges:
- from: 0
Expand Down
51 changes: 41 additions & 10 deletions vllm_omni/model_executor/stage_input_processors/qwen3_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
from typing import Any

import torch
from vllm.logger import init_logger

logger = init_logger(__name__)


def _extract_last_frame(pooling_output: dict[str, Any]) -> torch.Tensor | None:
Expand Down Expand Up @@ -42,11 +45,20 @@ def talker2code2wav_async_chunk(
cfg = raw_cfg.get("extra", raw_cfg) if isinstance(raw_cfg, dict) else {}
chunk_size = int(cfg.get("codec_chunk_frames", 25))
left_context_size = int(cfg.get("codec_left_context_frames", 25))
if chunk_size <= 0 or left_context_size < 0:
initial_chunk_size = int(cfg.get("initial_codec_chunk_frames", 0))
if chunk_size <= 0 or left_context_size < 0 or initial_chunk_size < 0:
raise ValueError(
f"Invalid codec chunk config: codec_chunk_frames={chunk_size}, "
f"codec_left_context_frames={left_context_size}"
f"codec_left_context_frames={left_context_size}, "
f"initial_codec_chunk_frames={initial_chunk_size}"
)
if initial_chunk_size > chunk_size:
logger.warning(
"initial_codec_chunk_frames=%d > codec_chunk_frames=%d, clamping to codec_chunk_frames.",
initial_chunk_size,
chunk_size,
)
initial_chunk_size = chunk_size
length = len(transfer_manager.code_prompt_token_ids[request_id])

# Avoid emitting empty chunks during normal streaming. If the request is
Expand All @@ -59,15 +71,34 @@ def talker2code2wav_async_chunk(
}
return None

chunk_length = length % chunk_size

if chunk_length != 0 and not finished:
return None
in_warmup = initial_chunk_size > 0 and length <= chunk_size

context_length = chunk_length if chunk_length != 0 else chunk_size
end_index = min(length, left_context_size + context_length)
ctx_frames = max(0, int(end_index - context_length))
window_frames = transfer_manager.code_prompt_token_ids[request_id][-end_index:]
if in_warmup:
# Warmup phase: emit every initial_chunk_size frames with full context.
# Track frames already delivered using put_req_chunk counter.
already_sent = transfer_manager.put_req_chunk[request_id] * initial_chunk_size
pending = length - already_sent
if pending <= 0:
return None
if pending < initial_chunk_size and not finished:
return None
context_length = min(pending, initial_chunk_size)
end_index = length
ctx_frames = max(0, length - context_length)
window_frames = transfer_manager.code_prompt_token_ids[request_id][:length]
else:
# Normal phase: standard chunk_size cadence with left_context sliding window.
# Offset by warmup coverage (static from config) so normal starts
# from where warmup left off.
warmup_coverage = (chunk_size // initial_chunk_size) * initial_chunk_size if initial_chunk_size > 0 else 0
adjusted = length - warmup_coverage
chunk_length = adjusted % chunk_size
if chunk_length != 0 and not finished:
return None
context_length = chunk_length if chunk_length != 0 else chunk_size
end_index = min(length, left_context_size + context_length)
ctx_frames = max(0, int(end_index - context_length))
window_frames = transfer_manager.code_prompt_token_ids[request_id][-end_index:]

# Pack context + chunk into codebook-major flat codes for adapter.
code_predictor_codes = torch.tensor(window_frames).transpose(0, 1).reshape(-1).tolist()
Expand Down
4 changes: 4 additions & 0 deletions vllm_omni/platforms/npu/stage_configs/qwen3_tts.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@ runtime:
# Align with Omni: small chunks with sufficient context overlap.
codec_chunk_frames: 25
codec_left_context_frames: 25
# First chunk size for reduced TTFA (0 = disabled).
# When > 0, emits small chunks every N frames during warmup,
# then switches to codec_chunk_frames cadence.
initial_codec_chunk_frames: 0

edges:
- from: 0
Expand Down