Skip to content
14 changes: 10 additions & 4 deletions examples/offline_inference/qwen3_tts/end2end.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,16 @@ def get_custom_voice_query(use_batch_sample: bool = False) -> QueryResult:
task_type = "CustomVoice"
model_name = "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice"
if use_batch_sample:
texts = ["其实我真的有发现,我是一个特别善于观察别人情绪的人。", "She said she would be here by noon."]
instructs = ["", "Very happy."]
languages = ["Chinese", "English"]
speakers = ["Vivian", "Ryan"]
texts = [
"其实我真的有发现,我是一个特别善于观察别人情绪的人。",
"She said she would be here by noon.",
"I like you very much.",
"Really, you do?",
"Yes, absolutely.",
]
instructs = ["", "Very happy.", "Very happy.", "Very happy.", "Very happy."]
languages = ["Chinese", "English", "English", "English", "English"]
speakers = ["Vivian", "Ryan", "Ryan", "Ryan", "Ryan"]
inputs = []
for text, instruct, language, speaker in zip(texts, instructs, languages, speakers):
additional_information = {
Expand Down
93 changes: 67 additions & 26 deletions vllm_omni/model_executor/models/qwen3_tts/qwen3_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,12 @@
# Store vllm_config for potential future use
self.vllm_config = vllm_config

@staticmethod
def extract_val(d, key, default):
val = d.get(key, default)
if isinstance(val, list):
return val[0] if len(val) > 0 else default
return val
# Enable CUDA Graph for decoder
self._enable_decoder_cudagraph()

Expand Down Expand Up @@ -136,22 +142,51 @@
**kwargs: Any,
) -> OmniOutput:
"""
Forward pass for TTS generation model.

Args:
input_ids: Input token IDs (required for TTS generation)
positions: Position IDs (not used for TTS, but required by runner)
intermediate_tensors: Intermediate tensors for pipeline parallelism (not used)
inputs_embeds: Input embeddings (not used for TTS, but required by runner)
**kwargs: Additional arguments including task_type, sampling_metadata, etc.

Returns:
OmniOutput: Contains multimodal outputs with audio tensors
Forward pass for TTS generation model (Patched for batched inference).
"""
runtime_info_list = kwargs.get("runtime_additional_information", [{}])
if not isinstance(runtime_info_list, list):
runtime_info_list = [runtime_info_list]

# Initialize lists to accumulate batched inputs
texts = []
task_types = []
speakers = []
languages = []
instructs = []
merged_kwargs = {}

# Keys that the underlying model natively supports as lists for batched inference
batched_keys = {"ref_audio", "ref_text", "x_vector_only_mode", "voice_clone_prompt"}

for req_info in runtime_info_list:
texts.append(self.extract_val(req_info, "text", ""))
task_types.append(self.extract_val(req_info, "task_type", self.task_type))
speakers.append(self.extract_val(req_info, "speaker", "uncle_fu"))
languages.append(self.extract_val(req_info, "language", "Auto"))
instructs.append(self.extract_val(req_info, "instruct", ""))

for k, v in req_info.items():
if k not in ["text", "task_type", "speaker", "language", "instruct"]:
# Extract single value from list if wrapped
val = v[0] if isinstance(v, list) and len(v) > 0 else v

if k in batched_keys:
# Accumulate as list for batched generation
if k not in merged_kwargs:
merged_kwargs[k] = []
merged_kwargs[k].append(val)
else:
# For scalar params (e.g. max_new_tokens), take from the first request
if k not in merged_kwargs:
merged_kwargs[k] = val

# During profile/warmup runs, texts are empty.
if all(not t for t in texts):

# Extract additional parameters from kwargs that the generation methods expect

runtime_additional_information = kwargs.get("runtime_additional_information", [{}])

Check failure on line 189 in vllm_omni/model_executor/models/qwen3_tts/qwen3_tts.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (invalid-syntax)

vllm_omni/model_executor/models/qwen3_tts/qwen3_tts.py:189:9: invalid-syntax: Expected an indented block after `if` statement
if isinstance(runtime_additional_information, list) and len(runtime_additional_information) > 0:
runtime_additional_information = runtime_additional_information[0]
text = runtime_additional_information.pop("text", [""])[0]
Expand All @@ -170,19 +205,22 @@
# cannot converge from degenerate dummy inputs.
if not text:
logger.info("Profile run detected (empty text). Capping max_new_tokens to 2.")
runtime_additional_information["max_new_tokens"] = 2
merged_kwargs["max_new_tokens"] = 2

# Call the appropriate generation method based on task_type
# Assume uniform task type across the batch
if len(set(task_types)) > 1:
raise ValueError(f"Mixed task types not supported: {set(task_types)}")
task_type = task_types[0]

# Call the appropriate generation method based on task_type, passing lists
if task_type == "CustomVoice":
result = self.model.generate_custom_voice(
text, speaker=speaker, language=language, instruct=instruct, **runtime_additional_information
texts, speaker=speakers, language=languages, instruct=instructs, **merged_kwargs
)
elif task_type == "VoiceDesign":
result = self.model.generate_voice_design(
text, instruct=instruct, language=language, **runtime_additional_information
)
result = self.model.generate_voice_design(texts, instruct=instructs, language=languages, **merged_kwargs)
elif task_type == "Base":
result = self.model.generate_voice_clone(text, language=language, **runtime_additional_information)
result = self.model.generate_voice_clone(texts, language=languages, **merged_kwargs)
else:
raise ValueError(f"Invalid task type: {task_type}")

Expand All @@ -201,17 +239,20 @@
# Handle tuple format: (audio_tensors, sample_rate)
if isinstance(model_outputs, tuple) and len(model_outputs) == 2:
audio_tensors, sr = model_outputs
# audio_tensors is a list of numpy arrays, convert first one to tensor if needed
# audio_tensors is a list of numpy arrays, convert ALL to tensors
if isinstance(audio_tensors, list) and len(audio_tensors) > 0:
# Convert numpy array to tensor if needed
audio_tensor = audio_tensors[0]
if isinstance(audio_tensor, np.ndarray):
audio_tensor = torch.from_numpy(audio_tensor).float()
elif not isinstance(audio_tensor, torch.Tensor):
audio_tensor = torch.tensor(audio_tensor, dtype=torch.float32)
audio_tensor_list = []
for audio_tensor in audio_tensors:
if isinstance(audio_tensor, np.ndarray):
audio_tensor_list.append(torch.from_numpy(audio_tensor).float())
elif not isinstance(audio_tensor, torch.Tensor):
audio_tensor_list.append(torch.tensor(audio_tensor, dtype=torch.float32))
else:
audio_tensor_list.append(audio_tensor)

return OmniOutput(
text_hidden_states=None,
multimodal_outputs={"model_outputs": audio_tensor, "sr": torch.tensor(sr, dtype=torch.int)},
multimodal_outputs={"model_outputs": audio_tensor_list, "sr": torch.tensor(sr, dtype=torch.int)},
)

# If it's already a tensor, wrap it
Expand Down
6 changes: 3 additions & 3 deletions vllm_omni/model_executor/stage_configs/qwen3_tts.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ stage_args:
stage_type: llm
runtime:
devices: "0"
max_batch_size: 1
max_batch_size: 10
engine_args:
model_stage: qwen3_tts
model_arch: Qwen3TTSTalkerForConditionalGeneration
Expand Down Expand Up @@ -52,8 +52,8 @@ stage_args:
trust_remote_code: true
async_scheduling: false
enable_prefix_caching: false
engine_output_type: audio
gpu_memory_utilization: 0.2
engine_output_type: audio # Final output: audio waveform
gpu_memory_utilization: 0.5
distributed_executor_backend: "mp"
# Must be divisible by num_code_groups and cover (left_context + chunk).
max_num_batched_tokens: 8192
Expand Down
Loading