-
Notifications
You must be signed in to change notification settings - Fork 277
[feat]: add transcription API endpoint using OpenAI Whisper-small #469
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
280b8b1
63cbd15
7954fc9
7ef270b
292fc62
4413d83
a34a431
b156254
518e453
d7cc2a3
cf11af6
89a403a
769e4f4
dc0f2d2
a57de3d
adcc64f
2d74fc2
7b058c8
6ebcee6
1e1ef45
f1522fc
73d5817
4e6edfa
ba769ee
cd4ebf7
996c653
0bd0129
f7ef2eb
3c66f96
6ac4661
37e33ca
1198490
10f1caa
4aa1a8b
114895b
12fcb2a
6c72f81
1a1c985
1d1b828
ae1dd95
856913c
cf2957a
96a038e
095d5d0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
@@ -1,8 +1,10 @@ | ||||||||
import abc | ||||||||
import enum | ||||||||
import io | ||||||||
import json | ||||||||
import re | ||||||||
import resource | ||||||||
import wave | ||||||||
from typing import Optional | ||||||||
|
||||||||
import requests | ||||||||
|
@@ -13,6 +15,23 @@ | |||||||
|
||||||||
logger = init_logger(__name__) | ||||||||
|
||||||||
# prepare a WAV byte to prevent repeatedly generating it | ||||||||
# Generate a 0.1 second silent audio file | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
# This will be used for the /v1/audio/transcriptions endpoint | ||||||||
_SILENT_WAV_BYTES = None | ||||||||
with io.BytesIO() as wav_buffer: | ||||||||
with wave.open(wav_buffer, "wb") as wf: | ||||||||
wf.setnchannels(1) # mono audio channel, standard configuration | ||||||||
wf.setsampwidth(2) # 16 bit audio, common bit depth for wav file | ||||||||
wf.setframerate(16000) # 16 kHz sample rate | ||||||||
wf.writeframes(b"\x00\x00" * 1600) # 0.1 second of silence | ||||||||
|
||||||||
# retrieves the generated wav bytes, return | ||||||||
_SILENT_WAV_BYTES = wav_buffer.getvalue() | ||||||||
logger.debug( | ||||||||
"======A default silent WAV file has been stored in memory within py application process====" | ||||||||
) | ||||||||
|
||||||||
|
||||||||
class SingletonMeta(type): | ||||||||
_instances = {} | ||||||||
|
@@ -52,6 +71,7 @@ class ModelType(enum.Enum): | |||||||
embeddings = "/v1/embeddings" | ||||||||
rerank = "/v1/rerank" | ||||||||
score = "/v1/score" | ||||||||
transcription = "/v1/audio/transcriptions" | ||||||||
|
||||||||
@staticmethod | ||||||||
def get_test_payload(model_type: str): | ||||||||
|
@@ -75,6 +95,12 @@ def get_test_payload(model_type: str): | |||||||
return {"query": "Hello", "documents": ["Test"]} | ||||||||
case ModelType.score: | ||||||||
return {"encoding_format": "float", "text_1": "Test", "test_2": "Test2"} | ||||||||
case ModelType.transcription: | ||||||||
if _SILENT_WAV_BYTES is not None: | ||||||||
logger.debug("=====Silent WAV Bytes is being used=====") | ||||||||
return { | ||||||||
"file": ("empty.wav", _SILENT_WAV_BYTES, "audio/wav"), | ||||||||
} | ||||||||
|
||||||||
@staticmethod | ||||||||
def get_all_fields(): | ||||||||
|
@@ -161,14 +187,37 @@ def update_content_length(request: Request, request_body: str): | |||||||
|
||||||||
def is_model_healthy(url: str, model: str, model_type: str) -> bool: | ||||||||
model_details = ModelType[model_type] | ||||||||
|
||||||||
try: | ||||||||
response = requests.post( | ||||||||
f"{url}{model_details.value}", | ||||||||
headers={"Content-Type": "application/json"}, | ||||||||
json={"model": model} | model_details.get_test_payload(model_type), | ||||||||
timeout=30, | ||||||||
) | ||||||||
except Exception as e: | ||||||||
logger.error(e) | ||||||||
if model_type == "transcription": | ||||||||
|
||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
# for transcription, the backend expects multipart/form-data with a file | ||||||||
# we will use pre-generated silent wav bytes | ||||||||
files = {"file": ("empty.wav", _SILENT_WAV_BYTES, "audio/wav")} | ||||||||
data = {"model": model} | ||||||||
response = requests.post( | ||||||||
f"{url}{model_details.value}", | ||||||||
files=files, # multipart/form-data | ||||||||
data=data, | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we should add a timeout here, otherwise the healthcheck thread will hang if the model hangs |
||||||||
timeout=10, | ||||||||
) | ||||||||
else: | ||||||||
# for other model types (chat, completion, etc.) | ||||||||
response = requests.post( | ||||||||
f"{url}{model_details.value}", | ||||||||
headers={"Content-Type": "application/json"}, | ||||||||
json={"model": model} | model_details.get_test_payload(model_type), | ||||||||
timeout=10, | ||||||||
) | ||||||||
|
||||||||
response.raise_for_status() | ||||||||
|
||||||||
if model_type == "transcription": | ||||||||
return True | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. as far as I can see, the transcription model also returns valid json, no? But as mentioned below, its probably enough to just go for the status code check |
||||||||
else: | ||||||||
response.json() # verify it's valid json for other model types | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should return a boolean somewhere for the other models. Where is that done? Before we had There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. shouldn't that be enough? |
||||||||
return True # validation passed | ||||||||
|
||||||||
except requests.exceptions.RequestException as e: | ||||||||
logger.debug(f"{model_type} Model {model} at {url} is not healthy: {e}") | ||||||||
return False | ||||||||
return response.status_code == 200 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would it be better to align the parameter with the other endpoints? And also perhaps adding request-id would be a good one.