diff --git a/src/tests/test_utils.py b/src/tests/test_utils.py
index dc352b930..fb5e0afa2 100644
--- a/src/tests/test_utils.py
+++ b/src/tests/test_utils.py
@@ -104,6 +104,14 @@ def test_is_model_healthy_when_requests_raises_exception_returns_false(
 def test_is_model_healthy_when_requests_status_with_status_code_not_200_returns_false(
     monkeypatch: pytest.MonkeyPatch,
 ) -> None:
-    request_mock = MagicMock(return_value=MagicMock(status_code=500))
+
+    # Mock an internal server error response
+    mock_response = MagicMock(status_code=500)
+
+    # Tell the mock to raise an HTTP Error when raise_for_status() is called
+    mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError
+
+    request_mock = MagicMock(return_value=mock_response)
     monkeypatch.setattr("requests.post", request_mock)
+
     assert utils.is_model_healthy("http://localhost", "test", "chat") is False
diff --git a/src/vllm_router/routers/main_router.py b/src/vllm_router/routers/main_router.py
index c13590440..5d77124dd 100644
--- a/src/vllm_router/routers/main_router.py
+++ b/src/vllm_router/routers/main_router.py
@@ -13,7 +13,11 @@
 # limitations under the License.
 import json

-from fastapi import APIRouter, BackgroundTasks, Request
+from fastapi import (
+    APIRouter,
+    BackgroundTasks,
+    Request,
+)
 from fastapi.responses import JSONResponse, Response

 from vllm_router.dynamic_config import get_dynamic_config_watcher
@@ -22,6 +26,7 @@ from vllm_router.service_discovery import get_service_discovery
 from vllm_router.services.request_service.request import (
     route_general_request,
+    route_general_transcriptions,
     route_sleep_wakeup_request,
 )
 from vllm_router.stats.engine_stats import get_engine_stats_scraper
@@ -123,7 +128,7 @@ async def show_version():
 @main_router.get("/v1/models")
 async def show_models():
     """
-    Returns a list of all models available in the stack
+    Returns a list of all models available in the stack.

     Args:
         None
@@ -229,3 +234,13 @@ async def health() -> Response:
         )
     else:
         return JSONResponse(content={"status": "healthy"}, status_code=200)
+
+
+@main_router.post("/v1/audio/transcriptions")
+async def route_v1_audio_transcriptions(
+    request: Request, background_tasks: BackgroundTasks
+):
+    """Handles audio transcription requests."""
+    return await route_general_transcriptions(
+        request, "/v1/audio/transcriptions", background_tasks
+    )
diff --git a/src/vllm_router/services/request_service/request.py b/src/vllm_router/services/request_service/request.py
index 83e647927..0c5005715 100644
--- a/src/vllm_router/services/request_service/request.py
+++ b/src/vllm_router/services/request_service/request.py
@@ -17,9 +17,10 @@
 import os
 import time
 import uuid
+from typing import Optional

 import aiohttp
-from fastapi import BackgroundTasks, HTTPException, Request
+from fastapi import BackgroundTasks, HTTPException, Request, UploadFile
 from fastapi.responses import JSONResponse, StreamingResponse
 from requests import JSONDecodeError
@@ -304,9 +305,7 @@ async def route_general_request(
 async def send_request_to_prefiller(
     client: aiohttp.ClientSession, endpoint: str, req_data: dict, request_id: str
 ):
-    """
-    Send a request to a prefiller service.
-    """
+    """Send a request to a prefiller service."""
     req_data = req_data.copy()
     req_data["max_tokens"] = 1
     if "max_completion_tokens" in req_data:
@@ -325,9 +324,7 @@ async def send_request_to_prefiller(
 async def send_request_to_decode(
     client: aiohttp.ClientSession, endpoint: str, req_data: dict, request_id: str
 ):
-    """
-    Asynchronously stream the response from a service using a persistent client.
-    """
+    """Asynchronously stream the response from a service using a persistent client."""
     headers = {
         "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
         "X-Request-Id": request_id,
@@ -511,3 +508,182 @@ async def route_sleep_wakeup_request(
         content={"status": "success"},
         headers={"X-Request-Id": request_id},
     )
+
+
+async def route_general_transcriptions(
+    request: Request,
+    endpoint: str,  # "/v1/audio/transcriptions"
+    background_tasks: BackgroundTasks,
+):
+    """Handles audio transcription requests by parsing form data and proxying them to a backend."""
+
+    request_id = request.headers.get("X-Request-Id", str(uuid.uuid4()))
+
+    # --- 1. Form parsing ---
+    try:
+        form = await request.form()
+
+        # Extract parameters from the form data
+        file: UploadFile = form["file"]
+        model: str = form["model"]
+        prompt: Optional[str] = form.get("prompt", None)
+        response_format: Optional[str] = form.get("response_format", "json")
+        temperature_str: Optional[str] = form.get("temperature", None)
+        temperature: Optional[float] = (
+            float(temperature_str) if temperature_str is not None else None
+        )
+        language: Optional[str] = form.get("language", "en")
+    except KeyError as e:
+        return JSONResponse(
+            status_code=400,
+            content={"error": f"Invalid request: missing '{e.args[0]}' in form data."},
+        )
+
+    logger.debug("==== Enter audio_transcriptions ====")
+    logger.debug("Received upload: %s (%s)", file.filename, file.content_type)
+    logger.debug(
+        "Params: model=%s prompt=%r response_format=%r temperature=%r language=%s",
+        model,
+        prompt,
+        response_format,
+        temperature,
+        language,
+    )
+
+    # --- 2. Service Discovery and Routing ---
+    # Access singletons via request.app.state for consistent style
+    service_discovery = (
+        get_service_discovery()
+    )  # service discovery is still accessed via its getter function
+    router = request.app.state.router  # Access router from app.state
+    engine_stats_scraper = (
+        request.app.state.engine_stats_scraper
+    )  # Access engine_stats_scraper from app.state
+    request_stats_monitor = (
+        request.app.state.request_stats_monitor
+    )  # Access request_stats_monitor from app.state
+
+    endpoints = service_discovery.get_endpoint_info()
+
+    logger.debug("==== Total endpoints ====")
+    logger.debug(endpoints)
+    logger.debug("==== Total endpoints ====")
+
+    # filter endpoints by model name and the transcription model label
+    transcription_endpoints = [
+        ep
+        for ep in endpoints
+        if model == ep.model_name
+        and ep.model_label == "transcription"
+        and not ep.sleep  # skip endpoints that are currently sleeping
+    ]
+
+    logger.debug("====List of transcription endpoints====")
+    logger.debug(transcription_endpoints)
+    logger.debug("====List of transcription endpoints====")
+
+    if not transcription_endpoints:
+        logger.error("No transcription backend available for model %s", model)
+        return JSONResponse(
+            status_code=404,
+            content={"error": f"No transcription backend for model {model}"},
+        )
+
+    # grab the current engine and request stats
+    engine_stats = engine_stats_scraper.get_engine_stats()
+    request_stats = request_stats_monitor.get_request_stats(time.time())
+
+    # pick one using the router's configured logic (roundrobin, least-loaded, etc.)
+    chosen_url = router.route_request(
+        transcription_endpoints,
+        engine_stats,
+        request_stats,
+        request,
+    )
+
+    logger.debug("Proxying transcription request to %s", chosen_url)
+
+    # --- 3. Prepare and Proxy the Request ---
+    payload_bytes = await file.read()
+    files = {"file": (file.filename, payload_bytes, file.content_type)}
+
+    data = {"model": model, "language": language}
+
+    if prompt:
+        data["prompt"] = prompt
+
+    if response_format:
+        data["response_format"] = response_format
+
+    if temperature is not None:
+        data["temperature"] = str(temperature)
+
+    logger.info("Proxying transcription request for model %s to %s", model, chosen_url)
+
+    logger.debug("==== data payload keys ====")
+    logger.debug(list(data.keys()))
+    logger.debug("==== data payload keys ====")
+
+    try:
+        client = request.app.state.aiohttp_client_wrapper()
+
+        form_data = aiohttp.FormData()
+
+        # add file data
+        for key, (filename, content, content_type) in files.items():
+            form_data.add_field(
+                key, content, filename=filename, content_type=content_type
+            )
+
+        # add form data
+        for key, value in data.items():
+            form_data.add_field(key, value)
+
+        backend_response = await client.post(
+            f"{chosen_url}{endpoint}",
+            data=form_data,
+            timeout=aiohttp.ClientTimeout(total=300),
+        )
+
+        # --- 4. Return the response ---
+        response_content = await backend_response.json()
+        headers = {
+            k: v
+            for k, v in backend_response.headers.items()
+            if k.lower() not in ("content-encoding", "transfer-encoding", "connection")
+        }
+
+        headers["X-Request-Id"] = request_id
+
+        return JSONResponse(
+            content=response_content,
+            status_code=backend_response.status,
+            headers=headers,
+        )
+    except aiohttp.ClientResponseError as response_error:
+        # aiohttp's ClientResponseError does not carry the response object or body,
+        # so build the error payload from the status and message it exposes
+        error_content = {
+            "error": f"HTTP {response_error.status}: {response_error.message}"
+        }
+        return JSONResponse(status_code=response_error.status, content=error_content)
+    except aiohttp.ClientError as client_error:
+        return JSONResponse(
+            status_code=503,
+            content={"error": f"Failed to connect to backend: {str(client_error)}"},
+        )
diff --git a/src/vllm_router/utils.py b/src/vllm_router/utils.py
index 7e3f1c698..61a8aa557 100644
--- a/src/vllm_router/utils.py
+++ b/src/vllm_router/utils.py
@@ -1,8 +1,10 @@
 import abc
 import enum
+import io
 import json
 import re
 import resource
+import wave
 from typing import Optional

 import requests
@@ -13,6 +15,23 @@
 logger = init_logger(__name__)

+# Prepare the silent WAV bytes once so they are not regenerated on every health check
+# Generate a 0.1 second silent audio file
+# This will be used for the /v1/audio/transcriptions endpoint
+_SILENT_WAV_BYTES = None
+with io.BytesIO() as wav_buffer:
+    with wave.open(wav_buffer, "wb") as wf:
+        wf.setnchannels(1)  # mono audio channel, standard configuration
+        wf.setsampwidth(2)  # 16 bit audio, common bit depth for wav files
+        wf.setframerate(16000)  # 16 kHz sample rate
+        wf.writeframes(b"\x00\x00" * 1600)  # 0.1 second of silence
+
+    # retrieve the generated WAV bytes
+    _SILENT_WAV_BYTES = wav_buffer.getvalue()
+    logger.debug(
+        "A default silent WAV file has been generated and cached in memory"
+    )
+

 class SingletonMeta(type):
     _instances = {}
@@ -52,6 +71,7 @@ class ModelType(enum.Enum):
     embeddings = "/v1/embeddings"
     rerank = "/v1/rerank"
     score = "/v1/score"
+    transcription = "/v1/audio/transcriptions"

     @staticmethod
     def get_test_payload(model_type: str):
@@ -75,6 +95,12 @@ def get_test_payload(model_type: str):
                 return {"query": "Hello", "documents": ["Test"]}
             case ModelType.score:
                 return {"encoding_format": "float", "text_1": "Test", "test_2": "Test2"}
+            case ModelType.transcription:
+                if _SILENT_WAV_BYTES is not None:
+                    logger.debug("=====Silent WAV Bytes is being used=====")
+                return {
+                    "file": ("empty.wav", _SILENT_WAV_BYTES, "audio/wav"),
+                }

     @staticmethod
     def get_all_fields():
@@ -161,14 +187,37 @@ def update_content_length(request: Request, request_body: str):

 def is_model_healthy(url: str, model: str, model_type: str) -> bool:
     model_details = ModelType[model_type]
+
     try:
-        response = requests.post(
-            f"{url}{model_details.value}",
-            headers={"Content-Type": "application/json"},
-            json={"model": model} | model_details.get_test_payload(model_type),
-            timeout=30,
-        )
-    except Exception as e:
-        logger.error(e)
+        if model_type == "transcription":
+
+            # for transcription, the backend expects multipart/form-data with a file
+            # we will use the pre-generated silent wav bytes
+            files = {"file": ("empty.wav", _SILENT_WAV_BYTES, "audio/wav")}
+            data = {"model": model}
+            response = requests.post(
+                f"{url}{model_details.value}",
+                files=files,  # multipart/form-data
+                data=data,
+                timeout=10,
+            )
+        else:
+            # for other model types (chat, completion, etc.)
+            response = requests.post(
+                f"{url}{model_details.value}",
+                headers={"Content-Type": "application/json"},
+                json={"model": model} | model_details.get_test_payload(model_type),
+                timeout=10,
+            )
+
+        response.raise_for_status()
+
+        if model_type == "transcription":
+            return True
+        else:
+            response.json()  # verify it's valid json for other model types
+            return True  # validation passed
+
+    except requests.exceptions.RequestException as e:
+        logger.debug(f"{model_type} Model {model} at {url} is not healthy: {e}")
         return False
-    return response.status_code == 200
diff --git a/tutorials/23-whisper-api-transcription.md b/tutorials/23-whisper-api-transcription.md
new file mode 100644
index 000000000..a09281916
--- /dev/null
+++ b/tutorials/23-whisper-api-transcription.md
@@ -0,0 +1,99 @@
+# Tutorial: Whisper Transcription API in vLLM Production Stack
+
+## Overview
+
+This tutorial introduces the newly added `/v1/audio/transcriptions` endpoint in the `vllm-router`, enabling users to transcribe `.wav` audio files using OpenAI’s `whisper-small` model.
+
+## Prerequisites
+
+* Access to a machine with a GPU (e.g. via [RunPod](https://runpod.io/))
+* Python 3.12 environment (recommended with `uv`)
+* `vllm` and `production-stack` cloned and installed
+* `vllm` installed with audio support:
+
+  ```bash
+  pip install "vllm[audio]"
+  ```
+
+## 1. Serving the Whisper Model
+
+Start a vLLM backend with the `whisper-small` model:
+
+```bash
+vllm serve \
+    --task transcription openai/whisper-small \
+    --host 0.0.0.0 --port 8002
+```
+
+## 2. Running the Router
+
+Create a script (for example `run-router.sh`) that starts a router connected to the Whisper backend:
+
+```bash
+#!/bin/bash
+if [[ $# -ne 2 ]]; then
+    echo "Usage: $0 <router-port> <backend-url>"
+    exit 1
+fi
+
+uv run python3 -m vllm_router.app \
+    --host 0.0.0.0 --port "$1" \
+    --service-discovery static \
+    --static-backends "$2" \
+    --static-models "openai/whisper-small" \
+    --static-model-types "transcription" \
+    --routing-logic roundrobin \
+    --log-stats \
+    --engine-stats-interval 10 \
+    --request-stats-window 10
+```
+
+Example usage:
+
+```bash
+./run-router.sh 8000 http://localhost:8002
+```
+
+## 3. Sending a Transcription Request
+
+Use `curl` to send a `.wav` file to the transcription endpoint; you can test with any `.wav` audio file of your choice.
+
+```bash
+curl -v http://localhost:8000/v1/audio/transcriptions \
+  -F 'file=@/path/to/audio.wav;type=audio/wav' \
+  -F 'model=openai/whisper-small' \
+  -F 'response_format=json' \
+  -F 'language=en'
+```
+
+### Supported Parameters
+
+| Parameter         | Description                                             |
+| ----------------- | ------------------------------------------------------- |
+| `file`            | Path to a `.wav` audio file                              |
+| `model`           | Whisper model to use (e.g., `openai/whisper-small`)      |
+| `prompt`          | *(Optional)* Text prompt to guide the transcription      |
+| `response_format` | One of `json`, `text`, `srt`, `verbose_json`, or `vtt`   |
+| `temperature`     | *(Optional)* Sampling temperature as a float             |
+| `language`        | ISO 639-1 code (e.g., `en`, `fr`, `zh`)                  |
+
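+### Calling the endpoint from Python
+
+If you prefer Python over `curl`, the snippet below is a minimal sketch using the `requests` library. It assumes the router from step 2 is listening on `http://localhost:8000` and that an `audio.wav` file exists in the current directory; adjust both to match your setup.
+
+```python
+import requests
+
+# Send the request to the router (not the backend) so the routing logic is exercised
+url = "http://localhost:8000/v1/audio/transcriptions"
+
+with open("audio.wav", "rb") as f:
+    files = {"file": ("audio.wav", f, "audio/wav")}
+    data = {
+        "model": "openai/whisper-small",
+        "response_format": "json",
+        "language": "en",
+    }
+    # Generous timeout, since long audio files can take a while to transcribe
+    response = requests.post(url, files=files, data=data, timeout=300)
+
+response.raise_for_status()
+print(response.json()["text"])
+```
+
+The form fields mirror the `curl` flags above, so the optional `prompt` and `temperature` parameters can be added to `data` in the same way.
+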
+## 4. Sample Output
+
+```json
+{
+  "text": "Testing testing testing the whisper small model testing testing testing the audio transcription function testing testing testing the whisper small model"
+}
+```
+
+## 5. Notes
+
+* The router uses an extended aiohttp timeout (300 seconds) to support long transcription jobs.
+* This implementation dynamically discovers valid transcription backends and routes requests accordingly.
+
+## 6. Resources
+
+* [PR #469 – Add Whisper Transcription API](https://github.com/vllm-project/production-stack/pull/469)
+* [OpenAI Whisper GitHub](https://github.com/openai/whisper)
+* [Blog: vLLM Whisper Transcription Walkthrough](https://davidgao7.github.io/posts/vllm-v1-whisper-transcription/)
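+
+## 7. Optional: Using the OpenAI Python Client
+
+Because the endpoint follows the shape of OpenAI's transcription API, the official `openai` Python client (installed with `pip install openai`) should also be able to talk to the router. The sketch below is untested against this stack and assumes the router from step 2 is running on `http://localhost:8000`; the API key is a placeholder, since the router does not validate it.
+
+```python
+from openai import OpenAI
+
+# Point the client at the router's OpenAI-compatible base URL
+client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
+
+with open("audio.wav", "rb") as audio_file:
+    transcript = client.audio.transcriptions.create(
+        model="openai/whisper-small",
+        file=audio_file,
+        response_format="json",
+        language="en",
+    )
+
+print(transcript.text)
+```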