Skip to content

Commit 2e6895f

Browse files
JuanPZuluaga, hsliuustc0106, david6666666
authored
[Bugfix][Qwen3TTS] Load speaker_id/voices from model configuration (vllm-project#1079)
Signed-off-by: pablo <juanz9312@gmail.com>
Co-authored-by: Hongsheng Liu <liuhongsheng4@huawei.com>
Co-authored-by: WeiQing Chen <40507679+david6666666@users.noreply.github.com>
1 parent 07f4c3f commit 2e6895f

File tree

4 files changed

+89
-15
lines changed

4 files changed

+89
-15
lines changed

examples/online_serving/qwen3_tts/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,9 @@ curl -X POST http://localhost:8000/v1/audio/speech \
7575
"voice": "Vivian",
7676
"instructions": "Speak with great enthusiasm"
7777
}' --output excited.wav
78+
79+
# List available voices in CustomVoice models
80+
curl http://localhost:8000/v1/audio/voices
7881
```
7982

8083
## API Reference

tests/entrypoints/openai_api/test_serving_speech.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,13 @@ async def awaitable_patched_create_speech(*args, **kwargs):
201201
app = FastAPI()
202202
app.add_api_route("/v1/audio/speech", speech_server.create_speech, methods=["POST"], response_model=None)
203203

204+
# Add list_voices endpoint
205+
async def list_voices():
206+
speakers = sorted(speech_server.supported_speakers) if speech_server.supported_speakers else []
207+
return {"voices": speakers}
208+
209+
app.add_api_route("/v1/audio/voices", list_voices, methods=["GET"])
210+
204211
return app
205212

206213

@@ -268,6 +275,11 @@ def test_speed_parameter_is_used(self, mock_create_audio, test_app):
268275
assert isinstance(audio_obj, CreateAudio)
269276
assert audio_obj.speed == 2.5
270277

278+
def test_list_voices_endpoint(self, client):
279+
response = client.get("/v1/audio/voices")
280+
assert response.status_code == 200
281+
assert "voices" in response.json()
282+
271283

272284
class TestTTSMethods:
273285
"""Unit tests for TTS validation and parameter building."""
@@ -311,9 +323,9 @@ def test_validate_tts_request_basic(self, speech_server):
311323
req = OpenAICreateSpeechRequest(input="Hello", language="InvalidLang")
312324
assert "Invalid language" in speech_server._validate_tts_request(req)
313325

314-
# Invalid speaker
326+
# When no speakers loaded, any voice is accepted (unconstrained)
315327
req = OpenAICreateSpeechRequest(input="Hello", voice="Invalid")
316-
assert "Invalid speaker" in speech_server._validate_tts_request(req)
328+
assert speech_server._validate_tts_request(req) is None
317329

318330
# Valid request
319331
req = OpenAICreateSpeechRequest(input="Hello", voice="Vivian")
@@ -342,3 +354,26 @@ def test_build_tts_params(self, speech_server):
342354
assert params["speaker"] == ["Ryan"]
343355
assert params["language"] == ["English"]
344356
assert params["task_type"] == ["CustomVoice"]
357+
358+
def test_load_supported_speakers(self):
359+
"""Test _load_supported_speakers."""
360+
mock_engine_client = MagicMock()
361+
mock_engine_client.errored = False
362+
mock_engine_client.stage_list = None
363+
364+
# Mock talker_config with mixed-case speaker names
365+
mock_talker_config = MagicMock()
366+
mock_talker_config.spk_id = {"Ryan": 0, "Vivian": 1, "Aiden": 2}
367+
mock_engine_client.model_config.hf_config.talker_config = mock_talker_config
368+
369+
mock_models = MagicMock()
370+
mock_models.is_base_model.return_value = True
371+
372+
server = OmniOpenAIServingSpeech(
373+
engine_client=mock_engine_client,
374+
models=mock_models,
375+
request_logger=MagicMock(),
376+
)
377+
378+
# Verify speakers are normalized to lowercase
379+
assert server.supported_speakers == {"ryan", "vivian", "aiden"}

vllm_omni/entrypoints/openai/api_server.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -762,6 +762,25 @@ async def create_speech(request: OpenAICreateSpeechRequest, raw_request: Request
762762
raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)) from e
763763

764764

765+
@router.get(
766+
"/v1/audio/voices",
767+
responses={
768+
HTTPStatus.OK.value: {"model": dict},
769+
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
770+
HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
771+
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
772+
},
773+
)
774+
async def list_voices(raw_request: Request):
775+
"""List available TTS voices/speakers from the loaded model."""
776+
handler = Omnispeech(raw_request)
777+
if handler is None:
778+
return base(raw_request).create_error_response(message="The model does not support Speech API")
779+
780+
speakers = sorted(handler.supported_speakers) if handler.supported_speakers else []
781+
return JSONResponse(content={"voices": speakers})
782+
783+
765784
# Health and Model endpoints for diffusion mode
766785

767786

vllm_omni/entrypoints/openai/serving_speech.py

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,17 +19,6 @@
1919

2020
# TTS Configuration (currently supports Qwen3-TTS)
2121
_TTS_MODEL_STAGES: set[str] = {"qwen3_tts"}
22-
_TTS_SPEAKERS: set[str] = {
23-
"Vivian",
24-
"Serena",
25-
"Uncle_Fu",
26-
"Dylan",
27-
"Eric",
28-
"Ryan",
29-
"Aiden",
30-
"Ono_Anna",
31-
"Sohee",
32-
}
3322
_TTS_LANGUAGES: set[str] = {
3423
"Auto",
3524
"Chinese",
@@ -49,6 +38,30 @@
4938

5039

5140
class OmniOpenAIServingSpeech(OpenAIServing, AudioMixin):
41+
def __init__(self, *args, **kwargs):
42+
super().__init__(*args, **kwargs)
43+
# Load supported speakers
44+
self.supported_speakers = self._load_supported_speakers()
45+
logger.info(f"Loaded {len(self.supported_speakers)} supported speakers: {sorted(self.supported_speakers)}")
46+
47+
def _load_supported_speakers(self) -> set[str]:
48+
"""Load supported speakers (case-insensitive) from the model configuration."""
49+
try:
50+
talker_config = self.engine_client.model_config.hf_config.talker_config
51+
52+
# Check for speakers in either spk_id or speaker_id
53+
for attr_name in ["spk_id", "speaker_id"]:
54+
speakers_dict = getattr(talker_config, attr_name, None)
55+
if speakers_dict and isinstance(speakers_dict, dict):
56+
# Normalize to lowercase for case-insensitive matching
57+
return {speaker.lower() for speaker in speakers_dict.keys()}
58+
59+
logger.warning("No speakers found in talker_config (checked spk_id and speaker_id)")
60+
except Exception as e:
61+
logger.warning(f"Could not load speakers from model config: {e}")
62+
63+
return set()
64+
5265
def _is_tts_model(self) -> bool:
5366
"""Check if the current model is a supported TTS model."""
5467
stage_list = getattr(self.engine_client, "stage_list", None)
@@ -63,6 +76,10 @@ def _validate_tts_request(self, request: OpenAICreateSpeechRequest) -> str | Non
6376
"""Validate TTS request parameters. Returns error message or None."""
6477
task_type = request.task_type or "CustomVoice"
6578

79+
# Normalize voice to lowercase for case-insensitive matching
80+
if request.voice is not None:
81+
request.voice = request.voice.lower()
82+
6683
# Validate input is not empty
6784
if not request.input or not request.input.strip():
6885
return "Input text cannot be empty"
@@ -73,8 +90,8 @@ def _validate_tts_request(self, request: OpenAICreateSpeechRequest) -> str | Non
7390

7491
# Validate speaker for CustomVoice task
7592
if task_type == "CustomVoice" and request.voice is not None:
76-
if request.voice not in _TTS_SPEAKERS:
77-
return f"Invalid speaker '{request.voice}'. Supported: {', '.join(sorted(_TTS_SPEAKERS))}"
93+
if self.supported_speakers and request.voice not in self.supported_speakers:
94+
return f"Invalid speaker '{request.voice}'. Supported: {', '.join(sorted(self.supported_speakers))}"
7895

7996
# Validate Base task requirements
8097
if task_type == "Base":

0 commit comments

Comments (0)