Commit 9aa62fd

[feat]: add transcription API endpoint using OpenAI Whisper-small
Signed-off-by: David Gao <[email protected]>
1 parent 802250a commit 9aa62fd

File tree

3 files changed: +182 additions, −17 deletions


src/vllm_router/routers/main_router.py

Lines changed: 147 additions & 4 deletions
@@ -12,16 +12,28 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import json
-
-from fastapi import APIRouter, BackgroundTasks, Request
+import time
+
+import httpx
+from fastapi import (
+    APIRouter,
+    BackgroundTasks,
+    File,
+    Form,
+    HTTPException,
+    Request,
+    UploadFile,
+)
 from fastapi.responses import JSONResponse, Response
 
 from vllm_router.dynamic_config import get_dynamic_config_watcher
 from vllm_router.log import init_logger
 from vllm_router.protocols import ModelCard, ModelList
+from vllm_router.routers.routing_logic import get_routing_logic
 from vllm_router.service_discovery import get_service_discovery
 from vllm_router.services.request_service.request import route_general_request
 from vllm_router.stats.engine_stats import get_engine_stats_scraper
+from vllm_router.stats.request_stats import RequestStatsMonitor
 from vllm_router.version import __version__
 
 try:
@@ -139,8 +151,7 @@ async def show_models():
 
 @main_router.get("/health")
 async def health() -> Response:
-    """
-    Endpoint to check the health status of various components.
+    """Endpoint to check the health status of various components.
 
     This function verifies the health of the service discovery module and
     the engine stats scraper. If either component is down, it returns a
@@ -173,3 +184,135 @@ async def health() -> Response:
         )
     else:
         return JSONResponse(content={"status": "healthy"}, status_code=200)
+
+
+@main_router.post("/v1/audio/transcriptions")
+async def audio_transcriptions(
+    file: UploadFile = File(...),
+    model: str = Form(...),
+    prompt: str | None = Form(None),
+    response_format: str | None = Form("json"),
+    temperature: float | None = Form(None),
+    language: str = Form("en"),
+):
+
+    logger.debug("==== Enter audio_transcriptions ====")
+    logger.debug("Received upload: %s (%s)", file.filename, file.content_type)
+    logger.debug(
+        "Params: model=%s prompt=%r response_format=%r temperature=%r language=%s",
+        model,
+        prompt,
+        response_format,
+        temperature,
+        language,
+    )
+
+    # read file bytes
+    payload_bytes = await file.read()
+    files = {
+        "file": (file.filename, payload_bytes, file.content_type),
+    }
+    # logger.debug("=========files=========")
+    # logger.debug(files)
+    # logger.debug("=========files=========")
+
+    data = {
+        "model": model,
+        "language": language,
+    }
+
+    if prompt:
+        data["prompt"] = prompt
+
+    if response_format:
+        data["response_format"] = response_format
+
+    if temperature is not None:
+        data["temperature"] = str(temperature)
+
+    logger.debug("==== data payload keys ====")
+    logger.debug(list(data.keys()))
+    logger.debug("==== data payload keys ====")
+
+    # get the backend url
+    endpoints = get_service_discovery().get_endpoint_info()
+
+    logger.debug("==== Total endpoints ====")
+    logger.debug(endpoints)
+    logger.debug("==== Total endpoints ====")
+
+    # TODO: right now is skipping label check in code for local testing
+    endpoints = [
+        ep
+        for ep in endpoints
+        if model in ep.model_names  # that actually serve your model
+    ]
+
+    logger.debug("==== Discovered endpoints after filtering ====")
+    logger.debug(endpoints)
+    logger.debug("==== Discovered endpoints after filtering ====")
+
+    # filter the endpoints url for transcriptions
+    transcription_endpoints = [ep for ep in endpoints if model in ep.model_names]
+
+    logger.debug("====List of transcription endpoints====")
+    logger.debug(transcription_endpoints)
+    logger.debug("====List of transcription endpoints====")
+
+    if not transcription_endpoints:
+        logger.error("No transcription backend available for model %s", model)
+        raise HTTPException(
+            status_code=503, detail=f"No transcription backend for model {model}"
+        )
+
+    # grab the current engine and request stats
+    engine_stats = get_engine_stats_scraper().get_engine_stats()
+    request_stats = RequestStatsMonitor().get_request_stats(time.time())
+    router = get_routing_logic()
+
+    # pick one using the router's configured logic (roundrobin, least-loaded, etc.)
+    chosen_url = router.route_request(
+        transcription_endpoints,
+        engine_stats,
+        request_stats,
+        # we don't need to pass the original FastAPI Request object here,
+        # but you can if your routing logic looks at headers or body
+        None,
+    )
+
+    logger.info("Proxying transcription request to %s", chosen_url)
+
+    # proxy the request
+    # by default httpx will only wait for 5 seconds, large audio transcriptions generally
+    # take longer than that
+    async with httpx.AsyncClient(
+        base_url=chosen_url,
+        timeout=httpx.Timeout(
+            connect=60.0,  # connect timeout
+            read=300.0,  # read timeout
+            write=30.0,  # if you're streaming uploads
+            pool=None,  # no pool timeout
+        ),
+    ) as client:
+        logger.debug("Sending multipart to %s/v1/audio/transcriptions …", chosen_url)
+        proxied = await client.post("/v1/audio/transcriptions", data=data, files=files)
+        logger.info("Received %d from whisper backend", proxied.status_code)
+
+    # return the whisper response unmodified
+    resp = proxied.json()
+    logger.debug("==== Whisper response payload ====")
+    logger.debug(resp)
+    logger.debug("==== Whisper response payload ====")
+
+    logger.debug("Backend response headers: %s", proxied.headers)
+    logger.debug("Backend response body (truncated): %r", proxied.content[:200])
+
+    return JSONResponse(
+        content=resp,
+        status_code=proxied.status_code,
+        headers={
+            k: v
+            for k, v in proxied.headers.items()
+            if k.lower() not in ("content-encoding", "transfer-encoding", "connection")
+        },
+    )
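
For reference, a minimal client-side sketch of how the new route could be exercised once the router is running. The router address, port, and audio file below are placeholders (assumptions, not part of this commit); the form fields mirror the File/Form parameters declared in audio_transcriptions above.

# Hypothetical smoke test for the new endpoint; URL and audio path are assumptions.
import httpx

ROUTER_URL = "http://localhost:30080"  # assumed router address, adjust to your deployment

with open("sample.wav", "rb") as f:  # placeholder audio file
    resp = httpx.post(
        f"{ROUTER_URL}/v1/audio/transcriptions",
        files={"file": ("sample.wav", f, "audio/wav")},
        data={
            "model": "openai/whisper-small",
            "language": "en",
            "response_format": "json",
        },
        timeout=300.0,  # transcription of long audio can exceed httpx's 5 s default
    )

print(resp.status_code, resp.json())

The multipart layout matches what the handler forwards to the backend, so a 200 here suggests the whole router-to-backend path is working.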

src/vllm_router/run-router.sh

Lines changed: 29 additions & 13 deletions
@@ -1,6 +1,6 @@
 #!/bin/bash
-if [[ $# -ne 1 ]]; then
-    echo "Usage $0 <router port>"
+if [[ $# -ne 2 ]]; then
+    echo "Usage $0 <router port> <backend url>"
     exit 1
 fi
 
@@ -15,17 +15,17 @@ fi
 # --log-stats
 
 # Use this command when testing with static service discovery
-python3 -m vllm_router.app --port "$1" \
-    --service-discovery static \
-    --static-backends "http://localhost:8000" \
-    --static-models "facebook/opt-125m" \
-    --static-model-types "chat" \
-    --log-stats \
-    --log-stats-interval 10 \
-    --engine-stats-interval 10 \
-    --request-stats-window 10 \
-    --request-stats-window 10 \
-    --routing-logic roundrobin
+# python3 -m vllm_router.app --port "$1" \
+# --service-discovery static \
+# --static-backends "http://localhost:8000" \
+# --static-models "facebook/opt-125m" \
+# --static-model-types "chat" \
+# --log-stats \
+# --log-stats-interval 10 \
+# --engine-stats-interval 10 \
+# --request-stats-window 10 \
+# --request-stats-window 10 \
+# --routing-logic roundrobin
 
 # Use this command when testing with roundrobin routing logic
 #python3 router.py --port "$1" \
@@ -35,3 +35,19 @@ python3 -m vllm_router.app --port "$1" \
 #    --engine-stats-interval 10 \
 #    --log-stats
 #
+
+# Use this command when testing with whisper transcription
+ROUTER_PORT=$1
+BACKEND_URL=$2
+
+python3 -m vllm_router.app \
+    --host 0.0.0.0 \
+    --port "${ROUTER_PORT}" \
+    --service-discovery static \
+    --static-backends "${BACKEND_URL}" \
+    --static-models "openai/whisper-small" \
+    --static-model-types "transcription" \
+    --routing-logic roundrobin \
+    --log-stats \
+    --engine-stats-interval 10 \
+    --request-stats-window 10
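
After launching the script with a router port and a backend URL, a quick sanity check might look like the sketch below. The port is a placeholder, and the /v1/models path is an assumption based on the show_models handler referenced in main_router.py.

# Hypothetical post-launch check; the router port below is an assumption.
import httpx

base = "http://localhost:30080"  # matches the <router port> passed to run-router.sh

print(httpx.get(f"{base}/health").json())     # expect {"status": "healthy"} once both components are up
print(httpx.get(f"{base}/v1/models").json())  # should list openai/whisper-small from the static backend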

src/vllm_router/utils.py

Lines changed: 6 additions & 0 deletions
@@ -51,6 +51,7 @@ class ModelType(enum.Enum):
     embeddings = "/v1/embeddings"
     rerank = "/v1/rerank"
     score = "/v1/score"
+    transcription = "/v1/audio/transcriptions"
 
     @staticmethod
     def get_test_payload(model_type: str):
@@ -75,6 +76,11 @@ def get_test_payload(model_type: str):
                 return {"query": "Hello", "documents": ["Test"]}
             case ModelType.score:
                 return {"encoding_format": "float", "text_1": "Test", "test_2": "Test2"}
+            case ModelType.transcription:
+                return {
+                    "file": "",
+                    "model": "openai/whisper-small",
+                }
 
     @staticmethod
     def get_all_fields():
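
The new ModelType member also gives the router a canned probe payload for transcription backends. A small sketch of how it might be looked up is below; this assumes get_test_payload accepts the enum member name, as the case arms above suggest, and is an illustration only.

# Hedged illustration; the exact call pattern inside the router may differ.
from vllm_router.utils import ModelType

endpoint_path = ModelType.transcription.value               # "/v1/audio/transcriptions"
probe_payload = ModelType.get_test_payload("transcription")
# probe_payload -> {"file": "", "model": "openai/whisper-small"}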
