Skip to content

Commit 9cf6cab

Browse files
committed
Online batched inference also works
1 parent 6dd5ad2 commit 9cf6cab

File tree

6 files changed

+1252
-4
lines changed

6 files changed

+1252
-4
lines changed

.gitignore

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -245,5 +245,3 @@ tmp_test
245245

246246
# output files
247247
*.wav
248-
examples/offline_inference/qwen3_tts/test.py
249-
examples/online_serving/qwen3_tts/Untitled.ipynb

READMEmy.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,15 @@ cd ../vllm-omni
1616
uv pip install -e .
1717
```
1818

19-
19+
# Examples
2020

2121
```bash
2222
# edit /lustre/users/rkoshkin/vllm-omni/vllm_omni/model_executor/stage_configs/qwen3_tts.yaml AS NECESSARY
2323
cd examples/online_serving/qwen3_tts
2424
./run_server.sh Base
2525
```
26+
27+
More online and offline inference examples are in
28+
29+
`/lustre/users/rkoshkin/vllm-omni/examples/online_serving/qwen3_tts/Examples.ipynb`
30+
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
import os
2+
from typing import NamedTuple
3+
import soundfile as sf
4+
from typing import List
5+
from datasets import Dataset, DatasetDict
6+
7+
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
8+
9+
from vllm import SamplingParams
10+
from vllm.utils.argparse_utils import FlexibleArgumentParser
11+
from vllm_omni import Omni
12+
13+
class QueryResult(NamedTuple):
    """Container for a prepared Omni request.

    Attributes:
        inputs: One request dict per utterance to synthesize, each holding a
            ``prompt`` string and an ``additional_information`` mapping.
        model_name: Hugging Face model identifier the inputs were built for.
    """

    # Fixed annotation: originally declared as `dict`, but get_base_query()
    # builds and stores a *list* of request dicts.
    inputs: list
    model_name: str
18+
19+
# new
20+
def get_base_query(
    ref_audios: list[str],
    ref_texts: list[str],
    target_texts: list[str],
    target_langs: list[str],
) -> QueryResult:
    """Build a batched "Base"-task voice-cloning request for Omni.

    The four lists are iterated in lockstep: position ``i`` combines the
    reference clip/transcript (voice to clone) with the target text and
    language to synthesize.

    Args:
        ref_audios: Paths or URLs of reference audio clips.
        ref_texts: Transcripts of the corresponding reference clips.
        target_texts: Texts to synthesize.
        target_langs: Language name per target text (e.g. ``"English"``).

    Returns:
        A ``QueryResult`` with one input dict per target utterance.

    Raises:
        ValueError: If the four lists do not all have the same length
            (the original ``zip`` silently truncated to the shortest).
    """
    n = len(target_texts)
    if not (len(ref_audios) == len(ref_texts) == len(target_langs) == n):
        raise ValueError(
            "ref_audios, ref_texts, target_texts and target_langs must have equal lengths"
        )

    inputs = []
    for target_text, target_lang, ref_audio, ref_text in zip(
        target_texts,
        target_langs,
        ref_audios,
        ref_texts,
    ):
        # Chat-style prompt shape expected by the Qwen3-TTS Base task.
        prompt = f"<|im_start|>assistant\n{target_text}<|im_end|>\n<|im_start|>assistant\n"
        inputs.append(
            {
                "prompt": prompt,
                # Values are wrapped in single-element lists as required by
                # the Omni additional_information schema.
                "additional_information": {
                    "task_type": ["Base"],
                    "ref_audio": [ref_audio],
                    "ref_text": [ref_text],
                    "text": [target_text],
                    "language": [target_lang],
                    "x_vector_only_mode": [False],
                    "max_new_tokens": [8192],
                },
            }
        )

    return QueryResult(
        inputs=inputs,
        model_name="Qwen/Qwen3-TTS-12Hz-1.7B-Base",
    )
55+
56+
def main():
    """Run offline batched Base-task TTS inference and save one WAV per request."""
    # NOTE(review): `stage_ibnit_timeout` looks like a typo of
    # `stage_init_timeout` — confirm against the Omni constructor signature
    # before renaming; kept as-is here.
    omni = Omni(
        model="Qwen/Qwen3-TTS-12Hz-1.7B-Base",
        stage_configs_path="/lustre/users/rkoshkin/vllm-omni/vllm_omni/model_executor/stage_configs/qwen3_tts.yaml",
        log_stats=True,
        stage_ibnit_timeout=300,
    )

    # The first 80 source sentences of the s2s dataset are the synthesis targets.
    ds = DatasetDict.load_from_disk("/lustre/users/rkoshkin/s2st/data/s2s/podcast_crawl-enru-dd-full-s2s+sid/")['train']
    target_texts = [ds[0]['src_sent'][i] for i in range(80)]
    num_targets = len(target_texts)

    # A single reference clip and its transcript are reused for every request.
    ref_audios = ["/lustre/users/rkoshkin/s2st/assets/ru_ref.sample.wav"] * num_targets
    ref_texts = [
        "Привет! С вами Программный Комитет - шоу подкаст-студии Термин-Вокс и IT-конфереции Стачка. Меня зовут Сергей Пихин.В этом подкасте мы обсуждаем главные тренды в IT-индустрии и в смежных областях. Помогают нам в этом топовые эксперты, которые делятся своими знаниями и экспертизой.",
    ] * num_targets
    target_langs = ["English"] * num_targets

    query_result = get_base_query(ref_audios, ref_texts, target_texts, target_langs)

    # One SamplingParams entry shared by the whole batch.
    sampling_params_list = [
        SamplingParams(
            temperature=0.9,
            top_p=1.0,
            top_k=50,
            max_tokens=8192,
            seed=42,
            detokenize=False,
            repetition_penalty=1.05,
        ),
    ]

    output_dir = "/lustre/users/rkoshkin/vllm-omni/examples/offline_inference/qwen3_tts/output"
    os.makedirs(output_dir, exist_ok=True)

    for stage_outputs in omni.generate(query_result.inputs, sampling_params_list):
        for output in stage_outputs.request_output:
            request_id = output.request_id
            audio_tensor = output.outputs[0].multimodal_output["audio"].clone()
            print(f"audio_tensor: {audio_tensor.shape}")
            output_wav = os.path.join(output_dir, f"output_{request_id}.wav")
            audio_samplerate = output.outputs[0].multimodal_output["sr"].item()
            # Convert to numpy and flatten to mono 1-D before writing.
            waveform = audio_tensor.float().detach().cpu().numpy()
            if waveform.ndim > 1:
                waveform = waveform.flatten()
            # Explicit WAV container so soundfile never guesses from the suffix.
            sf.write(output_wav, waveform, samplerate=audio_samplerate, format="WAV")
            print(f"Request ID: {request_id}, Saved audio to {output_wav}")


if __name__ == "__main__":
    main()
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
import os
2+
from typing import NamedTuple # noqa: UP035
3+
4+
import soundfile as sf
5+
6+
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
7+
8+
from vllm import SamplingParams
9+
10+
from vllm_omni import Omni
11+
12+
13+
class QueryResult(NamedTuple):
    """Container for a prepared Omni request.

    Attributes:
        inputs: One request dict per utterance to synthesize, each holding a
            ``prompt`` string and an ``additional_information`` mapping.
        model_name: Hugging Face model identifier the inputs were built for.
    """

    # Fixed annotation: originally declared as `dict`, but get_base_query()
    # builds and stores a *list* of request dicts.
    inputs: list
    model_name: str
18+
19+
# new
20+
def get_base_query(
    ref_audios: list[str],
    ref_texts: list[str],
    target_texts: list[str],
    target_langs: list[str],
) -> QueryResult:
    """Build a batched "Base"-task voice-cloning request for Omni.

    The four lists are iterated in lockstep: position ``i`` combines the
    reference clip/transcript (voice to clone) with the target text and
    language to synthesize.

    Args:
        ref_audios: Paths or URLs of reference audio clips.
        ref_texts: Transcripts of the corresponding reference clips.
        target_texts: Texts to synthesize.
        target_langs: Language name per target text (e.g. ``"English"``).

    Returns:
        A ``QueryResult`` with one input dict per target utterance.

    Raises:
        ValueError: If the four lists do not all have the same length
            (the original ``zip`` silently truncated to the shortest).
    """
    n = len(target_texts)
    if not (len(ref_audios) == len(ref_texts) == len(target_langs) == n):
        raise ValueError(
            "ref_audios, ref_texts, target_texts and target_langs must have equal lengths"
        )

    inputs = []
    for target_text, target_lang, ref_audio, ref_text in zip(
        target_texts,
        target_langs,
        ref_audios,
        ref_texts,
    ):
        # Chat-style prompt shape expected by the Qwen3-TTS Base task.
        prompt = f"<|im_start|>assistant\n{target_text}<|im_end|>\n<|im_start|>assistant\n"
        inputs.append(
            {
                "prompt": prompt,
                # Values are wrapped in single-element lists as required by
                # the Omni additional_information schema.
                "additional_information": {
                    "task_type": ["Base"],
                    "ref_audio": [ref_audio],
                    "ref_text": [ref_text],
                    "text": [target_text],
                    "language": [target_lang],
                    "x_vector_only_mode": [False],
                    "max_new_tokens": [8192],
                },
            }
        )

    return QueryResult(
        inputs=inputs,
        model_name="Qwen/Qwen3-TTS-12Hz-1.7B-Base",
    )
55+
56+
def main():
    """Run offline batched Base-task TTS inference on a fixed script and save WAVs."""
    # NOTE(review): `stage_ibnit_timeout` looks like a typo of
    # `stage_init_timeout` — confirm against the Omni constructor signature
    # before renaming; kept as-is here.
    omni = Omni(
        model="Qwen/Qwen3-TTS-12Hz-1.7B-Base",
        stage_configs_path="/lustre/users/rkoshkin/vllm-omni/vllm_omni/model_executor/stage_configs/qwen3_tts.yaml",
        log_stats=True,
        stage_ibnit_timeout=300,
    )

    # Hard-coded podcast transcript lines used as the synthesis targets.
    target_texts = [
        'Welcome to another episode of Out of the Pods.',
        "I'm Deep T. And I'm Natalie.",
        'And happy Wednesday.',
        'You know, we said last week that this episode is going to be about our recap of Perfect Match Season 2, Episodes 1 through 6, which we will get into.',
        'Lots of thoughts.',
        'Actually, almost no thoughts because...',
        'This is not a great season.',
        "It's just not off to a good start.",
        'I feel like I lost some brain cells watching it.',
        'Oh, 100%.'
    ]
    num_targets = len(target_texts)

    # A single reference clip and its transcript are reused for every request.
    ref_audios = ["/lustre/users/rkoshkin/s2st/bak/trump.mp3"] * num_targets
    ref_texts = ["because of it. Look, we were ripped off by almost every country in the world. If you look at the surpluses, almost every country in the world that did business with us, our people were stupid. And I blame presidents for it because they're ultimately the leader. Uh we were being ripped off by almost every single country in the world had massive some massive surpluses. China had hundreds of billions of dollars in surpluses with the United States. They rebuilt China. They rebuilt the army. We built China's army by allowing that to happen. I have a great relationship with President Xi, but he respects our country now. Now, what we've done, I charged China a 20% tariff as a penalty for sending fentinol in. And that was 20 times more than they could make by selling fentanol."] * num_targets
    target_langs = ["English"] * num_targets

    query_result = get_base_query(ref_audios, ref_texts, target_texts, target_langs)

    # One SamplingParams entry shared by the whole batch.
    sampling_params_list = [
        SamplingParams(
            temperature=0.9,
            top_p=1.0,
            top_k=50,
            max_tokens=8192,
            seed=42,
            detokenize=False,
            repetition_penalty=1.05,
        ),
    ]

    output_dir = "/lustre/users/rkoshkin/vllm-omni/examples/offline_inference/qwen3_tts/output"
    os.makedirs(output_dir, exist_ok=True)

    for stage_outputs in omni.generate(query_result.inputs, sampling_params_list):
        for output in stage_outputs.request_output:
            request_id = output.request_id
            audio_tensor = output.outputs[0].multimodal_output["audio"].clone()
            print(f"audio_tensor: {audio_tensor.shape}")
            output_wav = os.path.join(output_dir, f"output_{request_id}.wav")
            audio_samplerate = output.outputs[0].multimodal_output["sr"].item()
            # Convert to numpy and flatten to mono 1-D before writing.
            waveform = audio_tensor.float().detach().cpu().numpy()
            if waveform.ndim > 1:
                waveform = waveform.flatten()
            # Explicit WAV container so soundfile never guesses from the suffix.
            sf.write(output_wav, waveform, samplerate=audio_samplerate, format="WAV")
            print(f"Request ID: {request_id}, Saved audio to {output_wav}")


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)