Skip to content

Commit daceac5

Browse files
authored
[Frontend] Generalize v1/audio/transcriptions endpoint (#20179)
Signed-off-by: NickLucche <[email protected]>
1 parent 8615d97 commit daceac5

File tree

3 files changed

+154
-128
lines changed

3 files changed

+154
-128
lines changed

vllm/entrypoints/openai/speech_to_text.py

Lines changed: 14 additions & 128 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
2525
from vllm.inputs.data import PromptType
2626
from vllm.logger import init_logger
27+
from vllm.model_executor.model_loader.utils import get_model_architecture
2728
from vllm.outputs import RequestOutput
2829
from vllm.transformers_utils.processor import cached_get_processor
2930
from vllm.utils import PlaceholderModule
@@ -38,118 +39,10 @@
3839

3940
logger = init_logger(__name__)
4041

41-
# From https://platform.openai.com/docs/guides/speech-to-text/supported-languages
42-
# TODO these configs should live somewhere with the model so we can support
43-
# additional ones
44-
45-
ISO639_1_SUPPORTED_LANGS = {
46-
"af": "Afrikaans",
47-
"ar": "Arabic",
48-
"hy": "Armenian",
49-
"az": "Azerbaijani",
50-
"be": "Belarusian",
51-
"bs": "Bosnian",
52-
"bg": "Bulgarian",
53-
"ca": "Catalan",
54-
"zh": "Chinese",
55-
"hr": "Croatian",
56-
"cs": "Czech",
57-
"da": "Danish",
58-
"nl": "Dutch",
59-
"en": "English",
60-
"et": "Estonian",
61-
"fi": "Finnish",
62-
"fr": "French",
63-
"gl": "Galician",
64-
"de": "German",
65-
"el": "Greek",
66-
"he": "Hebrew",
67-
"hi": "Hindi",
68-
"hu": "Hungarian",
69-
"is": "Icelandic",
70-
"id": "Indonesian",
71-
"it": "Italian",
72-
"ja": "Japanese",
73-
"kn": "Kannada",
74-
"kk": "Kazakh",
75-
"ko": "Korean",
76-
"lv": "Latvian",
77-
"lt": "Lithuanian",
78-
"mk": "Macedonian",
79-
"ms": "Malay",
80-
"mr": "Marathi",
81-
"mi": "Maori",
82-
"ne": "Nepali",
83-
"no": "Norwegian",
84-
"fa": "Persian",
85-
"pl": "Polish",
86-
"pt": "Portuguese",
87-
"ro": "Romanian",
88-
"ru": "Russian",
89-
"sr": "Serbian",
90-
"sk": "Slovak",
91-
"sl": "Slovenian",
92-
"es": "Spanish",
93-
"sw": "Swahili",
94-
"sv": "Swedish",
95-
"tl": "Tagalog",
96-
"ta": "Tamil",
97-
"th": "Thai",
98-
"tr": "Turkish",
99-
"uk": "Ukrainian",
100-
"ur": "Urdu",
101-
"vi": "Vietnamese",
102-
"cy": "Welsh"
103-
}
104-
ISO639_1_OTHER_LANGS = {
105-
"lo": "Lao",
106-
"jw": "Javanese",
107-
"tk": "Turkmen",
108-
"yi": "Yiddish",
109-
"so": "Somali",
110-
"bn": "Bengali",
111-
"nn": "Norwegian Nynorsk",
112-
"si": "Sinhala",
113-
"yo": "Yoruba",
114-
"sa": "Sanskrit",
115-
"mi": "Māori",
116-
"fo": "Faroese", # codespell:ignore
117-
"mt": "Maltese",
118-
"tg": "Tajik",
119-
"mg": "Malagasy",
120-
"haw": "Hawaiian",
121-
"km": "Khmer",
122-
"br": "Breton",
123-
"ps": "Pashto",
124-
"ln": "Lingala",
125-
"la": "Latin",
126-
"ml": "Malayalam",
127-
"sq": "Albanian",
128-
"su": "Sundanese",
129-
"eu": "Basque",
130-
"ka": "Georgian",
131-
"uz": "Uzbek",
132-
"sn": "Shona",
133-
"ht": "Haitian",
134-
"as": "Assamese",
135-
"mn": "Mongolian",
136-
"te": "Telugu",
137-
"pa": "Panjabi",
138-
"tt": "Tatar",
139-
"gu": "Gujarati",
140-
"oc": "Occitan",
141-
"ha": "Hausa",
142-
"ba": "Bashkir",
143-
"my": "Burmese",
144-
"sd": "Sindhi",
145-
"am": "Amharic",
146-
"lb": "Luxembourgish",
147-
"bo": "Tibetan"
148-
}
149-
15042
# As per https://platform.openai.com/docs/guides/speech-to-text#overview.
15143
# TODO configurable
15244
MAX_AUDIO_CLIP_FILESIZE_MB = 25
45+
MAX_AUDIO_CLIP_SECONDS = 30
15346
OVERLAP_CHUNK_SECOND = 1
15447
MIN_ENERGY_WINDOW_SIZE = 1600 # 1600 ~ 100ms for 16000 Hz audio
15548

@@ -177,10 +70,13 @@ def __init__(
17770
self.default_sampling_params = (
17871
self.model_config.get_diff_sampling_param())
17972
processor = cached_get_processor(model_config.model)
180-
self.max_audio_clip_s = processor.feature_extractor.chunk_length
73+
self.max_audio_clip_s = processor.feature_extractor.chunk_length \
74+
if hasattr(processor.feature_extractor, 'chunk_length') \
75+
else MAX_AUDIO_CLIP_SECONDS
18176
self.model_sr = processor.feature_extractor.sampling_rate
18277
self.hop_length = processor.feature_extractor.hop_length
18378
self.task_type = task_type
79+
self.model_cls, _ = get_model_architecture(model_config)
18480

18581
if self.default_sampling_params:
18682
logger.info(
@@ -196,21 +92,8 @@ async def _preprocess_speech_to_text(
19692
# TODO language should be optional and can be guessed.
19793
# For now we default to en. See
19894
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520
199-
lang_token = f"<|{request.language}|>" if request.language else "<|en|>"
200-
if request.language:
201-
if request.language in ISO639_1_SUPPORTED_LANGS:
202-
pass
203-
elif request.language in ISO639_1_OTHER_LANGS:
204-
logger.warning(
205-
"The selected language %s has limited accuracy with"
206-
" reported WER>=0.5. Results may be less accurate "
207-
"for this choice.", request.language)
208-
else:
209-
raise ValueError(
210-
f"Unsupported language: {request.language}."
211-
"Language should be one of:" +
212-
f" {list(ISO639_1_SUPPORTED_LANGS.values())}" +
213-
f"or {list(ISO639_1_OTHER_LANGS.values())}")
95+
lang = request.language or "en"
96+
self.model_cls.validate_language(lang) # type: ignore[attr-defined]
21497

21598
if len(audio_data) / 1024**2 > MAX_AUDIO_CLIP_FILESIZE_MB:
21699
raise ValueError("Maximum file size exceeded.")
@@ -221,7 +104,9 @@ async def _preprocess_speech_to_text(
221104
y, sr = librosa.load(bytes_, sr=self.model_sr)
222105

223106
duration = librosa.get_duration(y=y, sr=sr)
224-
chunks = [y] if duration < 30 else self._split_audio(y, int(sr))
107+
chunks = [y
108+
] if duration < self.max_audio_clip_s else self._split_audio(
109+
y, int(sr))
225110
prompts = []
226111
for chunk in chunks:
227112
prompt = {
@@ -232,8 +117,9 @@ async def _preprocess_speech_to_text(
232117
},
233118
},
234119
"decoder_prompt":
235-
(f"<|startoftranscript|>{lang_token}"
236-
f"<|{self.task_type}|><|notimestamps|>{request.prompt}")
120+
self.model_cls.
121+
get_decoder_prompt( # type: ignore[attr-defined]
122+
lang, self.task_type, request.prompt)
237123
}
238124
prompts.append(cast(PromptType, prompt))
239125
return prompts, duration

vllm/model_executor/models/interfaces.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -599,6 +599,17 @@ class SupportsTranscription(Protocol):
599599

600600
supports_transcription: ClassVar[Literal[True]] = True
601601

602+
@classmethod
603+
def get_decoder_prompt(cls, language: str, task_type: str,
604+
prompt: str) -> str:
605+
"""Get the decoder prompt for the ASR model."""
606+
...
607+
608+
@classmethod
609+
def validate_language(cls, language: str) -> bool:
610+
"""Check if the model supports a specific ISO639_1 language."""
611+
...
612+
602613

603614
@overload
604615
def supports_transcription(

vllm/model_executor/models/whisper.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,113 @@
4141

4242
logger = init_logger(__name__)
4343

44+
# From https://platform.openai.com/docs/guides/speech-to-text/supported-languages
45+
46+
ISO639_1_SUPPORTED_LANGS = {
47+
"af": "Afrikaans",
48+
"ar": "Arabic",
49+
"hy": "Armenian",
50+
"az": "Azerbaijani",
51+
"be": "Belarusian",
52+
"bs": "Bosnian",
53+
"bg": "Bulgarian",
54+
"ca": "Catalan",
55+
"zh": "Chinese",
56+
"hr": "Croatian",
57+
"cs": "Czech",
58+
"da": "Danish",
59+
"nl": "Dutch",
60+
"en": "English",
61+
"et": "Estonian",
62+
"fi": "Finnish",
63+
"fr": "French",
64+
"gl": "Galician",
65+
"de": "German",
66+
"el": "Greek",
67+
"he": "Hebrew",
68+
"hi": "Hindi",
69+
"hu": "Hungarian",
70+
"is": "Icelandic",
71+
"id": "Indonesian",
72+
"it": "Italian",
73+
"ja": "Japanese",
74+
"kn": "Kannada",
75+
"kk": "Kazakh",
76+
"ko": "Korean",
77+
"lv": "Latvian",
78+
"lt": "Lithuanian",
79+
"mk": "Macedonian",
80+
"ms": "Malay",
81+
"mr": "Marathi",
82+
"mi": "Maori",
83+
"ne": "Nepali",
84+
"no": "Norwegian",
85+
"fa": "Persian",
86+
"pl": "Polish",
87+
"pt": "Portuguese",
88+
"ro": "Romanian",
89+
"ru": "Russian",
90+
"sr": "Serbian",
91+
"sk": "Slovak",
92+
"sl": "Slovenian",
93+
"es": "Spanish",
94+
"sw": "Swahili",
95+
"sv": "Swedish",
96+
"tl": "Tagalog",
97+
"ta": "Tamil",
98+
"th": "Thai",
99+
"tr": "Turkish",
100+
"uk": "Ukrainian",
101+
"ur": "Urdu",
102+
"vi": "Vietnamese",
103+
"cy": "Welsh"
104+
}
105+
ISO639_1_OTHER_LANGS = {
106+
"lo": "Lao",
107+
"jw": "Javanese",
108+
"tk": "Turkmen",
109+
"yi": "Yiddish",
110+
"so": "Somali",
111+
"bn": "Bengali",
112+
"nn": "Norwegian Nynorsk",
113+
"si": "Sinhala",
114+
"yo": "Yoruba",
115+
"sa": "Sanskrit",
116+
"mi": "Māori",
117+
"fo": "Faroese", # codespell:ignore
118+
"mt": "Maltese",
119+
"tg": "Tajik",
120+
"mg": "Malagasy",
121+
"haw": "Hawaiian",
122+
"km": "Khmer",
123+
"br": "Breton",
124+
"ps": "Pashto",
125+
"ln": "Lingala",
126+
"la": "Latin",
127+
"ml": "Malayalam",
128+
"sq": "Albanian",
129+
"su": "Sundanese",
130+
"eu": "Basque",
131+
"ka": "Georgian",
132+
"uz": "Uzbek",
133+
"sn": "Shona",
134+
"ht": "Haitian",
135+
"as": "Assamese",
136+
"mn": "Mongolian",
137+
"te": "Telugu",
138+
"pa": "Panjabi",
139+
"tt": "Tatar",
140+
"gu": "Gujarati",
141+
"oc": "Occitan",
142+
"ha": "Hausa",
143+
"ba": "Bashkir",
144+
"my": "Burmese",
145+
"sd": "Sindhi",
146+
"am": "Amharic",
147+
"lb": "Luxembourgish",
148+
"bo": "Tibetan"
149+
}
150+
44151

45152
class WhisperAudioInputs(TypedDict):
46153
input_features: NestedTensors
@@ -731,6 +838,28 @@ def load_weights(self, weights: Iterable[tuple[str,
731838
weights = _create_fake_bias_for_k_proj(weights)
732839
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
733840

841+
@classmethod
842+
def validate_language(cls, language: str) -> bool:
843+
if language in ISO639_1_SUPPORTED_LANGS:
844+
return True
845+
elif language in ISO639_1_OTHER_LANGS:
846+
logger.warning(
847+
"The selected language %s has limited accuracy with"
848+
" reported WER>=0.5. Results may be less accurate "
849+
"for this choice.", language)
850+
return True
851+
else:
852+
raise ValueError(f"Unsupported language: {language}."
853+
"Language should be one of:" +
854+
f" {list(ISO639_1_SUPPORTED_LANGS.values())}" +
855+
f"or {list(ISO639_1_OTHER_LANGS.values())}")
856+
857+
@classmethod
858+
def get_decoder_prompt(cls, language: str, task_type: str,
859+
prompt: str) -> str:
860+
return (f"<|startoftranscript|><|{language}|><|{task_type}|>"
861+
f"<|notimestamps|>{prompt}")
862+
734863

735864
def _create_fake_bias_for_k_proj(
736865
weights: Iterable[tuple[str, torch.Tensor]]

0 commit comments

Comments
 (0)