 # See the License for the specific language governing permissions and
 # limitations under the License.
 import json
-
-from fastapi import APIRouter, BackgroundTasks, Request
+import time
+
+import httpx
+from fastapi import (
+    APIRouter,
+    BackgroundTasks,
+    File,
+    Form,
+    HTTPException,
+    Request,
+    UploadFile,
+)
 from fastapi.responses import JSONResponse, Response
 
 from vllm_router.dynamic_config import get_dynamic_config_watcher
 from vllm_router.log import init_logger
 from vllm_router.protocols import ModelCard, ModelList
+from vllm_router.routers.routing_logic import get_routing_logic
 from vllm_router.service_discovery import get_service_discovery
 from vllm_router.services.request_service.request import route_general_request
 from vllm_router.stats.engine_stats import get_engine_stats_scraper
+from vllm_router.stats.request_stats import RequestStatsMonitor
 from vllm_router.version import __version__
 
 try:
@@ -139,8 +151,7 @@ async def show_models():
 
 @main_router.get("/health")
 async def health() -> Response:
-    """
-    Endpoint to check the health status of various components.
+    """Endpoint to check the health status of various components.
 
     This function verifies the health of the service discovery module and
     the engine stats scraper. If either component is down, it returns a
@@ -173,3 +184,135 @@ async def health() -> Response:
         )
     else:
         return JSONResponse(content={"status": "healthy"}, status_code=200)
+
+
+@main_router.post("/v1/audio/transcriptions")
+async def audio_transcriptions(
+    file: UploadFile = File(...),
+    model: str = Form(...),
+    prompt: str | None = Form(None),
+    response_format: str | None = Form("json"),
+    temperature: float | None = Form(None),
+    language: str = Form("en"),
+):
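+    """Proxy an OpenAI-style transcription request to a Whisper-capable backend.
+
+    The upload is forwarded as multipart form data to a backend selected by the
+    router's configured routing logic.
+    """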
+
+    logger.debug("==== Enter audio_transcriptions ====")
+    logger.debug("Received upload: %s (%s)", file.filename, file.content_type)
+    logger.debug(
+        "Params: model=%s prompt=%r response_format=%r temperature=%r language=%s",
+        model,
+        prompt,
+        response_format,
+        temperature,
+        language,
+    )
+
+    # Read the uploaded audio fully into memory (fine for typical clips) and
+    # rebuild the multipart payload for the backend.
+    payload_bytes = await file.read()
+    files = {
+        "file": (file.filename, payload_bytes, file.content_type),
+    }
+
+    data = {
+        "model": model,
+        "language": language,
+    }
+
+    if prompt:
+        data["prompt"] = prompt
+
+    if response_format:
+        data["response_format"] = response_format
+
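+    # Form fields travel as text in the multipart body, so stringify the float.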
+    if temperature is not None:
+        data["temperature"] = str(temperature)
+
+    logger.debug("==== data payload keys ====")
+    logger.debug(list(data.keys()))
+    logger.debug("==== data payload keys ====")
+
+    # discover all registered backend endpoints
+    endpoints = get_service_discovery().get_endpoint_info()
+
+    logger.debug("==== Total endpoints ====")
+    logger.debug(endpoints)
+    logger.debug("==== Total endpoints ====")
+
+    # TODO: the endpoint label check is skipped here for local testing; filter
+    # on the model name only.
+    transcription_endpoints = [
+        ep
+        for ep in endpoints
+        if model in ep.model_names  # keep only endpoints that serve the requested model
+    ]
+
+    logger.debug("==== Transcription endpoints after filtering ====")
+    logger.debug(transcription_endpoints)
+    logger.debug("==== Transcription endpoints after filtering ====")
+
+    if not transcription_endpoints:
+        logger.error("No transcription backend available for model %s", model)
+        raise HTTPException(
+            status_code=503, detail=f"No transcription backend for model {model}"
+        )
+
+    # grab the current engine and request stats
+    engine_stats = get_engine_stats_scraper().get_engine_stats()
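+    # RequestStatsMonitor() is assumed to hand back the process-wide singleton
+    # initialized at router startup, so these are live counters.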
+    request_stats = RequestStatsMonitor().get_request_stats(time.time())
+    router = get_routing_logic()
+
+    # pick one endpoint using the router's configured logic (roundrobin,
+    # least-loaded, etc.)
+    chosen_url = router.route_request(
+        transcription_endpoints,
+        engine_stats,
+        request_stats,
+        # the original FastAPI Request object is not needed here, but it can be
+        # passed if the routing logic inspects headers or the body
+        None,
+    )
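+    # route_request returns the base URL of the chosen backend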
+
+    logger.info("Proxying transcription request to %s", chosen_url)
+
+    # Proxy the request. httpx's default timeout is only 5 seconds, and large
+    # audio transcriptions generally take longer than that.
+    async with httpx.AsyncClient(
+        base_url=chosen_url,
+        timeout=httpx.Timeout(
+            connect=60.0,  # connect timeout
+            read=300.0,  # read timeout
+            write=30.0,  # for streaming uploads
+            pool=None,  # no pool timeout
+        ),
+    ) as client:
+        logger.debug("Sending multipart to %s/v1/audio/transcriptions …", chosen_url)
+        proxied = await client.post("/v1/audio/transcriptions", data=data, files=files)
+        logger.info("Received %d from whisper backend", proxied.status_code)
+
+    # Pass the Whisper response through. proxied.json() assumes a JSON body, so
+    # fall back to the raw bytes for non-JSON response formats (e.g. "text").
+    content_type = proxied.headers.get("content-type", "")
+    if "application/json" not in content_type:
+        return Response(
+            content=proxied.content,
+            status_code=proxied.status_code,
+            media_type=content_type or "application/octet-stream",
+        )
+
+    resp = proxied.json()
+    logger.debug("==== Whisper response payload ====")
+    logger.debug(resp)
+    logger.debug("==== Whisper response payload ====")
+
+    logger.debug("Backend response headers: %s", proxied.headers)
+    logger.debug("Backend response body (truncated): %r", proxied.content[:200])
+
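+    # Drop encoding/connection headers: the body has been re-serialized by
+    # JSONResponse, so the upstream values no longer apply.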
+    return JSONResponse(
+        content=resp,
+        status_code=proxied.status_code,
+        headers={
+            k: v
+            for k, v in proxied.headers.items()
+            if k.lower() not in ("content-encoding", "transfer-encoding", "connection")
+        },
+    )
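
For reference, a minimal client-side sketch of the new route, assuming the router listens on http://localhost:8000 and serves a Whisper model named "openai/whisper-small" (both values are illustrative, not part of this change):

import httpx

# Hypothetical smoke test for the new /v1/audio/transcriptions route.
with open("sample.wav", "rb") as f:
    resp = httpx.post(
        "http://localhost:8000/v1/audio/transcriptions",
        files={"file": ("sample.wav", f, "audio/wav")},
        data={"model": "openai/whisper-small", "language": "en"},
        timeout=300.0,  # transcriptions can take a while
    )
print(resp.status_code, resp.json())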