[feat]: add transcription API endpoint using OpenAI Whisper-small #469
@@ -17,9 +17,10 @@
 import os
 import time
+import uuid
 from typing import Optional

 import aiohttp
-from fastapi import BackgroundTasks, HTTPException, Request
+from fastapi import BackgroundTasks, HTTPException, Request, UploadFile
 from fastapi.responses import JSONResponse, StreamingResponse
 from requests import JSONDecodeError

@@ -304,9 +305,7 @@ async def route_general_request(
 async def send_request_to_prefiller(
     client: aiohttp.ClientSession, endpoint: str, req_data: dict, request_id: str
 ):
-    """
-    Send a request to a prefiller service.
-    """
+    """Send a request to a prefiller service."""
     req_data = req_data.copy()
     req_data["max_tokens"] = 1
     if "max_completion_tokens" in req_data:
@@ -325,9 +324,7 @@ async def send_request_to_prefiller(
 async def send_request_to_decode(
     client: aiohttp.ClientSession, endpoint: str, req_data: dict, request_id: str
 ):
-    """
-    Asynchronously stream the response from a service using a persistent client.
-    """
+    """Asynchronously stream the response from a service using a persistent client."""
     headers = {
         "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
         "X-Request-Id": request_id,
@@ -511,3 +508,182 @@ async def route_sleep_wakeup_request(
         content={"status": "success"},
         headers={"X-Request-Id": request_id},
     )
+
+
+async def route_general_transcriptions(
+    request: Request,
+    endpoint: str,  # "/v1/audio/transcriptions"
+    background_tasks: BackgroundTasks,
+):
+    """Handles audio transcription requests by parsing form data and proxying to backend."""
+
+    request_id = request.headers.get("X-Request-Id", str(uuid.uuid4()))
+
+    # --- 1. Form parsing ---
+    try:
+        form = await request.form()
+
+        # Extract parameters from the form data
+        file: UploadFile = form["file"]
+        model: str = form["model"]
+        prompt: Optional[str] = form.get("prompt", None)
+        response_format: Optional[str] = form.get("response_format", "json")
+        temperature_str: Optional[str] = form.get("temperature", None)
+        temperature: Optional[float] = (
+            float(temperature_str) if temperature_str is not None else None
+        )
+        language: Optional[str] = form.get("language", "en")
+    except KeyError as e:
+        return JSONResponse(
+            status_code=400,
+            content={"error": f"Invalid request: missing '{e.args[0]}' in form data."},
+        )
+
+    logger.debug("==== Enter audio_transcriptions ====")
+    logger.debug("Received upload: %s (%s)", file.filename, file.content_type)
+    logger.debug(
+        "Params: model=%s prompt=%r response_format=%r temperature=%r language=%s",
+        model,
+        prompt,
+        response_format,
+        temperature,
+        language,
+    )
+
+    # --- 2. Service Discovery and Routing ---
+    # Access singletons via request.app.state for consistent style
+    service_discovery = (
+        get_service_discovery()
+    )  # This one is often still accessed directly via its get function
+    router = request.app.state.router  # Access router from app.state
+    engine_stats_scraper = (
+        request.app.state.engine_stats_scraper
+    )  # Access engine_stats_scraper from app.state
+    request_stats_monitor = (
+        request.app.state.request_stats_monitor
+    )  # Access request_stats_monitor from app.state
+
+    endpoints = service_discovery.get_endpoint_info()
+
+    logger.debug("==== Total endpoints ====")
+    logger.debug(endpoints)
+    logger.debug("==== Total endpoints ====")
+
+    # filter the endpoint urls by model name and label for transcriptions
+    transcription_endpoints = [
+        ep
+        for ep in endpoints
+        if model == ep.model_name
+        and ep.model_label == "transcription"
+        and not ep.sleep  # skip endpoints that are asleep
+    ]
Review comment on lines +573 to +579: Sadly this crashes.

Review comment: @davidgao7 Can you fix this error?
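The traceback is not quoted in the thread, but if the crash comes from endpoint objects that lack a model_label or sleep attribute, a defensive variant of the filter could look like the sketch below (an assumption about the failure, not the fix that was committed):

# Hypothetical hardening of the filter above: getattr() defaults guard
# against endpoint objects that do not define model_label or sleep.
# This is a guess at the crash; the actual traceback is not shown.
transcription_endpoints = [
    ep
    for ep in endpoints
    if model == ep.model_name
    and getattr(ep, "model_label", None) == "transcription"
    and not getattr(ep, "sleep", False)
]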
+    logger.debug("====List of transcription endpoints====")
+    logger.debug(transcription_endpoints)
+    logger.debug("====List of transcription endpoints====")
+
+    if not transcription_endpoints:
+        logger.error("No transcription backend available for model %s", model)
+        return JSONResponse(
+            status_code=404,
+            content={"error": f"No transcription backend for model {model}"},
+        )
+
+    # grab the current engine and request stats
+    engine_stats = engine_stats_scraper.get_engine_stats()
+    request_stats = request_stats_monitor.get_request_stats(time.time())
+
+    # pick one using the router's configured logic (roundrobin, least-loaded, etc.)
+    chosen_url = router.route_request(
+        transcription_endpoints,
+        engine_stats,
+        request_stats,
+        request,
+    )
+
+    logger.debug("Proxying transcription request to %s", chosen_url)
+
+    # --- 3. Prepare and Proxy the Request ---
+    payload_bytes = await file.read()
+    files = {"file": (file.filename, payload_bytes, file.content_type)}
+
+    data = {"model": model, "language": language}
+
+    if prompt:
+        data["prompt"] = prompt
+
+    if response_format:
+        data["response_format"] = response_format
+
+    if temperature is not None:
+        data["temperature"] = str(temperature)
+
+    logger.info("Proxying transcription request for model %s to %s", model, chosen_url)
+
+    logger.debug("==== data payload keys ====")
+    logger.debug(list(data.keys()))
+    logger.debug("==== data payload keys ====")
+
+    try:
+        client = request.app.state.aiohttp_client_wrapper()
+
+        form_data = aiohttp.FormData()
+
+        # add file data
+        for key, (filename, content, content_type) in files.items():
+            form_data.add_field(
+                key, content, filename=filename, content_type=content_type
+            )
+
+        # add form data
+        for key, value in data.items():
+            form_data.add_field(key, value)
+
+        backend_response = await client.post(
+            f"{chosen_url}{endpoint}",
+            data=form_data,
+            timeout=aiohttp.ClientTimeout(total=300),
+        )
+
+        # --- 4. Return the response ---
+        response_content = await backend_response.json()
+        headers = {
+            k: v
+            for k, v in backend_response.headers.items()
+            if k.lower() not in ("content-encoding", "transfer-encoding", "connection")
+        }
+
+        headers["X-Request-Id"] = request_id
+
+        return JSONResponse(
+            content=response_content,
+            status_code=backend_response.status,
+            headers=headers,
+        )
+    except aiohttp.ClientResponseError as response_error:
+        if response_error.response is not None:
+            try:
+                error_content = await response_error.response.json()
+            except (
+                aiohttp.ContentTypeError,
+                json.JSONDecodeError,
+                aiohttp.ClientError,
+            ):
+                # If JSON parsing fails, get text content
+                try:
+                    text_content = await response_error.response.text()
+                    error_content = {"error": text_content}
+                except aiohttp.ClientError:
+                    error_content = {
+                        "error": f"HTTP {response_error.status}: {response_error.message}"
+                    }
+        else:
+            error_content = {
+                "error": f"HTTP {response_error.status}: {response_error.message}"
+            }
+        return JSONResponse(status_code=response_error.status, content=error_content)
+    except aiohttp.ClientError as client_error:
+        return JSONResponse(
+            status_code=503,
+            content={"error": f"Failed to connect to backend: {str(client_error)}"},
+        )
@@ -1,8 +1,10 @@
 import abc
 import enum
+import io
 import json
 import re
 import resource
+import wave
 from typing import Optional

 import requests

@@ -13,6 +15,23 @@

 logger = init_logger(__name__)

+# prepare the silent WAV bytes once to avoid regenerating them
+# Generate a 0.1 second silent audio file
+# This will be used for the /v1/audio/transcriptions endpoint
+_SILENT_WAV_BYTES = None
+with io.BytesIO() as wav_buffer:
+    with wave.open(wav_buffer, "wb") as wf:
+        wf.setnchannels(1)  # mono audio channel, standard configuration
+        wf.setsampwidth(2)  # 16-bit audio, common bit depth for wav files
+        wf.setframerate(16000)  # 16 kHz sample rate
+        wf.writeframes(b"\x00\x00" * 1600)  # 0.1 second of silence
+
+    # retrieve the generated wav bytes
+    _SILENT_WAV_BYTES = wav_buffer.getvalue()
+logger.debug(
+    "======A default silent WAV file has been stored in memory within py application process===="
+)

 class SingletonMeta(type):
     _instances = {}
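As a sanity check on the arithmetic (1600 frames at a 16 kHz sample rate is 0.1 s; 2 bytes per frame for 16-bit mono), the buffer can be read back with the same wave module; a quick verification sketch, not part of the diff:

# Read the in-memory WAV back and confirm its parameters and duration
# (verification sketch only, not part of the PR).
import io
import wave

with wave.open(io.BytesIO(_SILENT_WAV_BYTES), "rb") as wf:
    assert wf.getnchannels() == 1      # mono
    assert wf.getsampwidth() == 2      # 16-bit samples
    assert wf.getframerate() == 16000  # 16 kHz
    assert wf.getnframes() / wf.getframerate() == 0.1  # 1600 / 16000 = 0.1 s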
@@ -52,6 +71,7 @@ class ModelType(enum.Enum):
     embeddings = "/v1/embeddings"
     rerank = "/v1/rerank"
     score = "/v1/score"
+    transcription = "/v1/audio/transcriptions"

     @staticmethod
     def get_test_payload(model_type: str):
@@ -75,6 +95,12 @@ def get_test_payload(model_type: str):
                 return {"query": "Hello", "documents": ["Test"]}
             case ModelType.score:
                 return {"encoding_format": "float", "text_1": "Test", "test_2": "Test2"}
+            case ModelType.transcription:
+                if _SILENT_WAV_BYTES is not None:
+                    logger.debug("=====Silent WAV Bytes is being used=====")
+                return {
+                    "file": ("empty.wav", _SILENT_WAV_BYTES, "audio/wav"),
+                }

     @staticmethod
     def get_all_fields():
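For the JSON-based model types, is_model_healthy merges this payload into the request body with the dict-union operator. For example, with the score payload from the diff above (the model name here is a placeholder):

# Dict-union merge as used by is_model_healthy for JSON endpoints;
# "my-model" is a placeholder name, not from the PR.
payload = {"model": "my-model"} | ModelType.get_test_payload("score")
# payload == {"model": "my-model", "encoding_format": "float",
#             "text_1": "Test", "test_2": "Test2"}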
@@ -161,14 +187,37 @@ def update_content_length(request: Request, request_body: str):

 def is_model_healthy(url: str, model: str, model_type: str) -> bool:
     model_details = ModelType[model_type]

     try:
-        response = requests.post(
-            f"{url}{model_details.value}",
-            headers={"Content-Type": "application/json"},
-            json={"model": model} | model_details.get_test_payload(model_type),
-            timeout=30,
-        )
-    except Exception as e:
-        logger.error(e)
+        if model_type == "transcription":
+            # for transcription, the backend expects multipart/form-data with a file
+            # we will use the pre-generated silent wav bytes
+            files = {"file": ("empty.wav", _SILENT_WAV_BYTES, "audio/wav")}
+            data = {"model": model}
+            response = requests.post(
+                f"{url}{model_details.value}",
+                files=files,  # multipart/form-data
+                data=data,
Review comment: we should add a timeout here, otherwise the healthcheck thread will hang if the model hangs.
+                timeout=10,
+            )
+        else:
+            # for other model types (chat, completion, etc.)
+            response = requests.post(
+                f"{url}{model_details.value}",
+                headers={"Content-Type": "application/json"},
+                json={"model": model} | model_details.get_test_payload(model_type),
+                timeout=10,
+            )
+
+        response.raise_for_status()
+
+        if model_type == "transcription":
+            return True
Review comment: as far as I can see, the transcription model also returns valid JSON, no? But as mentioned below, it's probably enough to just go for the status code check.
+        else:
+            response.json()  # verify it's valid json for other model types
Review comment: We should return a boolean somewhere for the other models. Where is that done? Before we had

Review comment: shouldn't that be enough?
+
+        return True  # validation passed
+
+    except requests.exceptions.RequestException as e:
+        logger.debug(f"{model_type} Model {model} at {url} is not healthy: {e}")
+        return False
-    return response.status_code == 200
Review comment: Would it be better to align the parameter with the other endpoints? And also perhaps adding request-id would be a good one.
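A sketch of how those two suggestions, status-code-only validation plus a request id, might be folded together; it assumes the module-level ModelType, _SILENT_WAV_BYTES, and logger from the diff above and is illustrative, not code from this PR:

# Illustrative follow-up only: combines the reviewers' suggestions
# (status-code check for every model type, plus an X-Request-Id header).
# ModelType, _SILENT_WAV_BYTES, and logger come from the module above.
import uuid

import requests


def is_model_healthy(url: str, model: str, model_type: str) -> bool:
    model_details = ModelType[model_type]
    headers = {"X-Request-Id": str(uuid.uuid4())}
    try:
        if model_type == "transcription":
            # transcription backends expect multipart/form-data with a file
            response = requests.post(
                f"{url}{model_details.value}",
                files={"file": ("empty.wav", _SILENT_WAV_BYTES, "audio/wav")},
                data={"model": model},
                headers=headers,
                timeout=10,
            )
        else:
            # json= sets the Content-Type header automatically
            response = requests.post(
                f"{url}{model_details.value}",
                json={"model": model} | model_details.get_test_payload(model_type),
                headers=headers,
                timeout=10,
            )
    except requests.exceptions.RequestException as e:
        logger.debug(f"{model_type} Model {model} at {url} is not healthy: {e}")
        return False
    # a 200 status is treated as healthy, per the review discussion
    return response.status_code == 200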