Merged

44 commits
280b8b1  [feat]: add transcription API endpoint using OpenAI Whisper-small (davidgao7, May 30, 2025)
63cbd15  remove the whisper payload response log (davidgao7, Jun 3, 2025)
7954fc9  [docs]: add tutorial for transcription v1 api (davidgao7, Jun 3, 2025)
7ef270b  [chore] align example router running script with main (davidgao7, Jun 3, 2025)
292fc62  omit model field since backend already knows which model to run (davidgao7, Jun 12, 2025)
4413d83  generate a silent audio file if no audio file appears (davidgao7, Jun 18, 2025)
a34a431  put wav creation at the module level to prevent being recreated every… (davidgao7, Jun 18, 2025)
b156254  [Test] test frequency of silent audio creation (davidgao7, Jun 19, 2025)
518e453  send multipart/form-data for transcription model's health check (davidgao7, Jun 21, 2025)
d7cc2a3  fix pre-commit issue (davidgao7, Jun 23, 2025)
cf11af6  Moves the implementation for the `/v1/audio/transcriptions` endpoint … (davidgao7, Jun 28, 2025)
89a403a  Merge branch 'main' into pr/transcription-whisper (davidgao7, Jul 12, 2025)
769e4f4  add timeout to ensure health check will not hang indefinitely if a ba… (davidgao7, Jul 24, 2025)
dc0f2d2  add boolean model health check return for non-transcription model (davidgao7, Jul 24, 2025)
a57de3d  remove redundant warning log since handled in outer 'StaticServiceDis… (davidgao7, Jul 24, 2025)
adcc64f  remove redundant JSONDecodeError catch and downgrade RequestException… (davidgao7, Jul 24, 2025)
2d74fc2  Merge branch 'main' into pr/transcription-whisper (davidgao7, Jul 25, 2025)
7b058c8  Merge branch 'main' into pr/transcription-whisper (davidgao7, Jul 26, 2025)
6ebcee6  Merge branch 'main' into pr/transcription-whisper (davidgao7, Jul 29, 2025)
1e1ef45  Merge branch 'main' into pr/transcription-whisper (davidgao7, Jul 30, 2025)
f1522fc  Chore: Apply auto-formatting and linting fixes via pre-commit (davidgao7, Jul 30, 2025)
73d5817  refactor: update more meaningful comments for silent wav bytes genera… (davidgao7, Jul 30, 2025)
4e6edfa  Merge branch 'main' into pr/transcription-whisper (davidgao7, Jul 30, 2025)
ba769ee  refactor: keep the comment to explain purpose for generating a silent (davidgao7, Jul 30, 2025)
cd4ebf7  Merge remote-tracking branch 'origin/pr/transcription-whisper' into p… (davidgao7, Jul 30, 2025)
996c653  fix(tests): Improve mock in model health check test (davidgao7, Aug 11, 2025)
0bd0129  Merge branch 'main' into pr/transcription-whisper (davidgao7, Aug 11, 2025)
f7ef2eb  Chore: Apply auto-formatting and linting fixes via pre-commit (davidgao7, Aug 12, 2025)
3c66f96  chore: remove unused var `in_router_time` (davidgao7, Aug 13, 2025)
6ac4661  fix: (deps) add httpx as an explicit dependency (davidgao7, Aug 14, 2025)
37e33ca  chore: dependencies order changes after running pre-commit (davidgao7, Aug 14, 2025)
1198490  Merge branch 'main' into pr/transcription-whisper (davidgao7, Aug 18, 2025)
10f1caa  Merge branch 'main' into pr/transcription-whisper (davidgao7, Aug 20, 2025)
4aa1a8b  Merge branch 'main' into pr/transcription-whisper (davidgao7, Aug 21, 2025)
114895b  refactor: Migration from httpx to aiohttp for improved concurrency (davidgao7, Aug 22, 2025)
12fcb2a  chore: remove wrong tutorial file (davidgao7, Aug 22, 2025)
6c72f81  chore: apply pre-commit (davidgao7, Aug 22, 2025)
1a1c985  chore: use debug log print (davidgao7, Aug 22, 2025)
1d1b828  chore: change to more specific exception handling for aiohttp (davidgao7, Aug 22, 2025)
ae1dd95  Merge branch 'main' into pr/transcription-whisper (davidgao7, Aug 22, 2025)
856913c  Merge branch 'main' into pr/transcription-whisper (YuhanLiu11, Aug 25, 2025)
cf2957a  Merge branch 'main' into pr/transcription-whisper (davidgao7, Aug 26, 2025)
96a038e  Merge branch 'main' into pr/transcription-whisper (davidgao7, Aug 28, 2025)
095d5d0  Merge branch 'main' into pr/transcription-whisper (davidgao7, Sep 1, 2025)
src/tests/test_utils.py  (10 changes: 9 additions & 1 deletion)
@@ -104,6 +104,14 @@ def test_is_model_healthy_when_requests_raises_exception_returns_false(
 def test_is_model_healthy_when_requests_status_with_status_code_not_200_returns_false(
     monkeypatch: pytest.MonkeyPatch,
 ) -> None:
-    request_mock = MagicMock(return_value=MagicMock(status_code=500))
+    # Mock an internal server error response
+    mock_response = MagicMock(status_code=500)
+
+    # Tell the mock to raise an HTTPError when raise_for_status() is called
+    mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError
+
+    request_mock = MagicMock(return_value=mock_response)
     monkeypatch.setattr("requests.post", request_mock)

     assert utils.is_model_healthy("http://localhost", "test", "chat") is False
src/vllm_router/routers/main_router.py  (19 changes: 17 additions & 2 deletions)
@@ -13,7 +13,11 @@
 # limitations under the License.
 import json

-from fastapi import APIRouter, BackgroundTasks, Request
+from fastapi import (
+    APIRouter,
+    BackgroundTasks,
+    Request,
+)
 from fastapi.responses import JSONResponse, Response

 from vllm_router.dynamic_config import get_dynamic_config_watcher
@@ -22,6 +26,7 @@
 from vllm_router.service_discovery import get_service_discovery
 from vllm_router.services.request_service.request import (
     route_general_request,
+    route_general_transcriptions,
     route_sleep_wakeup_request,
 )
 from vllm_router.stats.engine_stats import get_engine_stats_scraper
@@ -123,7 +128,7 @@ async def show_version():
 @main_router.get("/v1/models")
 async def show_models():
     """
-    Returns a list of all models available in the stack
+    Returns a list of all models available in the stack.

     Args:
         None
@@ -229,3 +234,13 @@ async def health() -> Response:
         )
     else:
         return JSONResponse(content={"status": "healthy"}, status_code=200)
+
+
+@main_router.post("/v1/audio/transcriptions")
+async def route_v1_audio_transcriptions(
+    request: Request, background_tasks: BackgroundTasks
+):
Contributor: Would it be better to align the parameter with the other endpoints? And perhaps adding a request-id would be a good one.
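A hypothetical sketch of the suggested alignment (not the merged code; the request-id derivation mirrors the other endpoints in this file, the request_id keyword on route_general_transcriptions is invented for illustration, and uuid is assumed to be imported here):

    @main_router.post("/v1/audio/transcriptions")
    async def route_v1_audio_transcriptions(
        request: Request, background_tasks: BackgroundTasks
    ):
        """Handles audio transcription requests."""
        # Reuse an upstream X-Request-Id when present, otherwise mint one
        request_id = request.headers.get("X-Request-Id", str(uuid.uuid4()))
        # Hypothetical keyword argument; the merged implementation derives
        # the id inside route_general_transcriptions instead
        return await route_general_transcriptions(
            request, "/v1/audio/transcriptions", background_tasks, request_id=request_id
        )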

"""Handles audio transcription requests."""
return await route_general_transcriptions(
request, "/v1/audio/transcriptions", background_tasks
)
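For orientation, a minimal client-side sketch of exercising the new endpoint; the host, port, and model name are assumptions for illustration, not values taken from this PR:

    import requests

    # Assumed deployment: router on localhost:8000 serving "openai/whisper-small"
    with open("sample.wav", "rb") as f:
        resp = requests.post(
            "http://localhost:8000/v1/audio/transcriptions",
            files={"file": ("sample.wav", f, "audio/wav")},  # multipart/form-data
            data={"model": "openai/whisper-small", "language": "en"},
            timeout=300,
        )
    print(resp.json())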
src/vllm_router/services/request_service/request.py  (190 changes: 183 additions & 7 deletions)
@@ -17,9 +17,10 @@
 import os
 import time
 import uuid
+from typing import Optional

 import aiohttp
-from fastapi import BackgroundTasks, HTTPException, Request
+from fastapi import BackgroundTasks, HTTPException, Request, UploadFile
 from fastapi.responses import JSONResponse, StreamingResponse
 from requests import JSONDecodeError
@@ -304,9 +305,7 @@ async def route_general_request(
 async def send_request_to_prefiller(
     client: aiohttp.ClientSession, endpoint: str, req_data: dict, request_id: str
 ):
-    """
-    Send a request to a prefiller service.
-    """
+    """Send a request to a prefiller service."""
     req_data = req_data.copy()
     req_data["max_tokens"] = 1
     if "max_completion_tokens" in req_data:
@@ -325,9 +324,7 @@
 async def send_request_to_decode(
     client: aiohttp.ClientSession, endpoint: str, req_data: dict, request_id: str
 ):
-    """
-    Asynchronously stream the response from a service using a persistent client.
-    """
+    """Asynchronously stream the response from a service using a persistent client."""
     headers = {
         "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
         "X-Request-Id": request_id,
@@ -511,3 +508,182 @@ async def route_sleep_wakeup_request(
         content={"status": "success"},
         headers={"X-Request-Id": request_id},
     )
+
+
+async def route_general_transcriptions(
+    request: Request,
+    endpoint: str,  # "/v1/audio/transcriptions"
+    background_tasks: BackgroundTasks,
+):
+    """Handles audio transcription requests by parsing form data and proxying them to a backend."""
+
+    request_id = request.headers.get("X-Request-Id", str(uuid.uuid4()))
+
+    # --- 1. Form parsing ---
+    try:
+        form = await request.form()
+
+        # Extract parameters from the form data
+        file: UploadFile = form["file"]
+        model: str = form["model"]
+        prompt: Optional[str] = form.get("prompt", None)
+        response_format: Optional[str] = form.get("response_format", "json")
+        temperature_str: Optional[str] = form.get("temperature", None)
+        temperature: Optional[float] = (
+            float(temperature_str) if temperature_str is not None else None
+        )
+        language: Optional[str] = form.get("language", "en")
+    except KeyError as e:
+        return JSONResponse(
+            status_code=400,
+            content={"error": f"Invalid request: missing '{e.args[0]}' in form data."},
+        )

+    logger.debug("==== Enter audio_transcriptions ====")
+    logger.debug("Received upload: %s (%s)", file.filename, file.content_type)
+    logger.debug(
+        "Params: model=%s prompt=%r response_format=%r temperature=%r language=%s",
+        model,
+        prompt,
+        response_format,
+        temperature,
+        language,
+    )
+
+    # --- 2. Service discovery and routing ---
+    # Access singletons via request.app.state for a consistent style;
+    # service discovery is still reached through its module-level getter.
+    service_discovery = get_service_discovery()
+    router = request.app.state.router
+    engine_stats_scraper = request.app.state.engine_stats_scraper
+    request_stats_monitor = request.app.state.request_stats_monitor
+
+    endpoints = service_discovery.get_endpoint_info()
+
+    logger.debug("==== Total endpoints ====")
+    logger.debug(endpoints)
+    logger.debug("==== Total endpoints ====")
+
+    # Filter the endpoints by model name and the "transcription" label,
+    # skipping endpoints that are currently asleep
+    transcription_endpoints = [
+        ep
+        for ep in endpoints
+        if model == ep.model_name
+        and ep.model_label == "transcription"
+        and not ep.sleep
+    ]
+
+    logger.debug("==== List of transcription endpoints ====")
+    logger.debug(transcription_endpoints)
+    logger.debug("==== List of transcription endpoints ====")
+
+    if not transcription_endpoints:
+        logger.error("No transcription backend available for model %s", model)
+        return JSONResponse(
+            status_code=404,
+            content={"error": f"No transcription backend for model {model}"},
+        )
+
+    # Grab the current engine and request stats
+    engine_stats = engine_stats_scraper.get_engine_stats()
+    request_stats = request_stats_monitor.get_request_stats(time.time())
+
+    # Pick one endpoint using the router's configured logic
+    # (round-robin, least-loaded, etc.)
+    chosen_url = router.route_request(
+        transcription_endpoints,
+        engine_stats,
+        request_stats,
+        request,
+    )
+
+    logger.debug("Proxying transcription request to %s", chosen_url)

+    # --- 3. Prepare and proxy the request ---
+    payload_bytes = await file.read()
+    files = {"file": (file.filename, payload_bytes, file.content_type)}
+
+    data = {"model": model, "language": language}
+
+    if prompt:
+        data["prompt"] = prompt
+
+    if response_format:
+        data["response_format"] = response_format
+
+    if temperature is not None:
+        data["temperature"] = str(temperature)
+
+    logger.info("Proxying transcription request for model %s to %s", model, chosen_url)
+
+    logger.debug("==== data payload keys ====")
+    logger.debug(list(data.keys()))
+    logger.debug("==== data payload keys ====")
+
+    try:
+        client = request.app.state.aiohttp_client_wrapper()
+
+        form_data = aiohttp.FormData()
+
+        # Add the file field
+        for key, (filename, content, content_type) in files.items():
+            form_data.add_field(
+                key, content, filename=filename, content_type=content_type
+            )
+
+        # Add the remaining form fields
+        for key, value in data.items():
+            form_data.add_field(key, value)
+
+        backend_response = await client.post(
+            f"{chosen_url}{endpoint}",
+            data=form_data,
+            timeout=aiohttp.ClientTimeout(total=300),
+        )

+        # --- 4. Return the response ---
+        response_content = await backend_response.json()
+        headers = {
+            k: v
+            for k, v in backend_response.headers.items()
+            if k.lower() not in ("content-encoding", "transfer-encoding", "connection")
+        }
+
+        headers["X-Request-Id"] = request_id
+
+        return JSONResponse(
+            content=response_content,
+            status_code=backend_response.status,
+            headers=headers,
+        )
+    except aiohttp.ClientResponseError as response_error:
+        if response_error.response is not None:
+            try:
+                error_content = await response_error.response.json()
+            except (
+                aiohttp.ContentTypeError,
+                json.JSONDecodeError,
+                aiohttp.ClientError,
+            ):
+                # If JSON parsing fails, fall back to the text content
+                try:
+                    text_content = await response_error.response.text()
+                    error_content = {"error": text_content}
+                except aiohttp.ClientError:
+                    error_content = {
+                        "error": f"HTTP {response_error.status}: {response_error.message}"
+                    }
+        else:
+            error_content = {
+                "error": f"HTTP {response_error.status}: {response_error.message}"
+            }
+        return JSONResponse(status_code=response_error.status, content=error_content)
+    except aiohttp.ClientError as client_error:
+        return JSONResponse(
+            status_code=503,
+            content={"error": f"Failed to connect to backend: {str(client_error)}"},
+        )
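One subtlety worth noting about the error handling above (an observation, not part of the PR): aiohttp raises ClientResponseError only when raise_for_status is enabled on the session or called on the response. A sketch of the session configuration that would make the first except branch fire on 4xx/5xx answers:

    import aiohttp

    async def post_with_status_errors(url: str, form_data: aiohttp.FormData):
        # raise_for_status=True turns 4xx/5xx responses into
        # aiohttp.ClientResponseError, which the handler above catches
        async with aiohttp.ClientSession(raise_for_status=True) as session:
            async with session.post(url, data=form_data) as resp:
                return await resp.json()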
src/vllm_router/utils.py  (67 changes: 58 additions & 9 deletions)
@@ -1,8 +1,10 @@
 import abc
 import enum
+import io
 import json
 import re
 import resource
+import wave
 from typing import Optional

 import requests
@@ -13,6 +15,23 @@

 logger = init_logger(__name__)

+# Prepare the WAV bytes once at import time to avoid regenerating them repeatedly
+# Generate a 0.1 second silent audio file

Contributor (suggested change):

    # Generate a 0.1 second silent audio file
    # This will be used for the /v1/audio/transcriptions endpoint

+# This will be used for the /v1/audio/transcriptions endpoint
+_SILENT_WAV_BYTES = None
+with io.BytesIO() as wav_buffer:
+    with wave.open(wav_buffer, "wb") as wf:
+        wf.setnchannels(1)  # mono audio channel, standard configuration
+        wf.setsampwidth(2)  # 16-bit audio, a common bit depth for WAV files
+        wf.setframerate(16000)  # 16 kHz sample rate
+        wf.writeframes(b"\x00\x00" * 1600)  # 1600 frames of silence
+
+    # Retrieve the generated WAV bytes from the buffer
+    _SILENT_WAV_BYTES = wav_buffer.getvalue()
+logger.debug(
+    "==== A default silent WAV file has been stored in memory within the application process ===="
+)
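As a quick sanity check on the arithmetic in those comments, a short sketch (not part of the PR) that decodes the in-memory bytes and confirms the clip's shape:

    import io
    import wave

    # Verify the silent clip is mono, 16-bit, 16 kHz, and 0.1 s long
    with wave.open(io.BytesIO(_SILENT_WAV_BYTES), "rb") as wf:
        assert wf.getnchannels() == 1
        assert wf.getsampwidth() == 2
        assert wf.getframerate() == 16000
        assert abs(wf.getnframes() / wf.getframerate() - 0.1) < 1e-9  # 1600 / 16000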


 class SingletonMeta(type):
     _instances = {}

@@ -52,6 +71,7 @@ class ModelType(enum.Enum):
     embeddings = "/v1/embeddings"
     rerank = "/v1/rerank"
     score = "/v1/score"
+    transcription = "/v1/audio/transcriptions"

     @staticmethod
     def get_test_payload(model_type: str):
@@ -75,6 +95,12 @@ def get_test_payload(model_type: str):
                 return {"query": "Hello", "documents": ["Test"]}
             case ModelType.score:
                 return {"encoding_format": "float", "text_1": "Test", "test_2": "Test2"}
+            case ModelType.transcription:
+                if _SILENT_WAV_BYTES is not None:
+                    logger.debug("===== Silent WAV bytes are being used =====")
+                return {
+                    "file": ("empty.wav", _SILENT_WAV_BYTES, "audio/wav"),
+                }

     @staticmethod
     def get_all_fields():
@@ -161,14 +187,37 @@ def update_content_length(request: Request, request_body: str):

 def is_model_healthy(url: str, model: str, model_type: str) -> bool:
     model_details = ModelType[model_type]
+
     try:
-        response = requests.post(
-            f"{url}{model_details.value}",
-            headers={"Content-Type": "application/json"},
-            json={"model": model} | model_details.get_test_payload(model_type),
-            timeout=30,
-        )
-    except Exception as e:
-        logger.error(e)
-    return response.status_code == 200
+        if model_type == "transcription":
+            # For transcription, the backend expects multipart/form-data with a
+            # file, so the pre-generated silent WAV bytes are used
+            files = {"file": ("empty.wav", _SILENT_WAV_BYTES, "audio/wav")}
+            data = {"model": model}
+            response = requests.post(
+                f"{url}{model_details.value}",
+                files=files,  # multipart/form-data
+                data=data,
+                timeout=10,
+            )

Contributor: we should add a timeout here, otherwise the healthcheck thread will hang if the model hangs.

+        else:
+            # For other model types (chat, completion, etc.)
+            response = requests.post(
+                f"{url}{model_details.value}",
+                headers={"Content-Type": "application/json"},
+                json={"model": model} | model_details.get_test_payload(model_type),
+                timeout=10,
+            )
+
+        response.raise_for_status()
+
+        if model_type == "transcription":
+            return True

Contributor: as far as I can see, the transcription model also returns valid JSON, no? But as mentioned below, it's probably enough to just go for the status code check.

+        else:
+            response.json()  # verify it's valid JSON for other model types

Contributor: We should return a boolean somewhere for the other models. Where is that done? Before we had return response.status_code == 200.

Contributor: shouldn't that be enough?

+        return True  # validation passed
+
+    except requests.exceptions.RequestException as e:
+        logger.debug(f"{model_type} Model {model} at {url} is not healthy: {e}")
+        return False
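A companion test for the new transcription branch could follow the same monkeypatch pattern as the existing tests in src/tests/test_utils.py; a hypothetical sketch (not in this PR), assuming that module's existing imports (pytest, MagicMock, utils):

    def test_is_model_healthy_transcription_returns_true_on_200(
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        # A 200 response whose raise_for_status() does not raise
        mock_response = MagicMock(status_code=200)
        mock_response.raise_for_status.return_value = None
        monkeypatch.setattr("requests.post", MagicMock(return_value=mock_response))

        # The transcription branch skips JSON validation and returns True
        assert utils.is_model_healthy("http://localhost", "test", "transcription") is True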