2222import math
2323import time
2424import uuid
25- from typing import TYPE_CHECKING , AsyncGenerator , Optional , Union
25+ from typing import TYPE_CHECKING , AsyncGenerator , List , Optional , Union
2626
2727from fastapi import Request
2828from fastapi .responses import ORJSONResponse , Response , StreamingResponse
3232 ErrorResponse ,
3333 TranscriptionRequest ,
3434 TranscriptionResponse ,
35+ TranscriptionSegment ,
3536 TranscriptionStreamChoice ,
3637 TranscriptionStreamResponse ,
3738 TranscriptionUsage ,
39+ TranscriptionVerboseResponse ,
3840)
3941from sglang .srt .entrypoints .openai .serving_base import OpenAIServingBase
4042from sglang .srt .managers .io_struct import GenerateReqInput
4446
4547logger = logging .getLogger (__name__ )
4648
# Whisper timestamp token constants.
# <|0.00|> is the first timestamp token; every id >= this is a timestamp token.
TIMESTAMP_BASE_TOKEN_ID = 50365
# Seconds per timestamp-token step: token id (TIMESTAMP_BASE_TOKEN_ID + k)
# encodes the time k * 0.02 seconds.
TIMESTAMP_BASE_OFFSET = 0.02

4753
4854class OpenAIServingTranscription (OpenAIServingBase ):
4955 """Handler for /v1/audio/transcriptions requests"""
@@ -72,6 +78,9 @@ def _convert_to_internal_request(
7278 "language" : request .language , # Pass to WhisperProcessor for language-specific decoding
7379 }
7480
81+ if request .timestamp_granularities :
82+ sampling_params ["timestamp_granularities" ] = request .timestamp_granularities
83+
7584 # For Whisper, we pass audio_data and let the processor handle it
7685 adapted_request = GenerateReqInput (
7786 text = "" , # Empty text - Whisper processor will set proper decoder tokens
@@ -89,13 +98,83 @@ def _get_audio_duration(self, audio_data: bytes) -> float:
8998 try :
9099 import soundfile as sf
91100
92- audio_array , sr = sf .read (io .BytesIO (audio_data ))
93- duration = len (audio_array ) / sr
94- return duration
101+ info = sf .info (io .BytesIO (audio_data ))
102+ return info .duration
95103 except Exception as e :
96104 logger .warning (f"Could not calculate audio duration: { e } " )
97105 return 0.0
98106
107+ def _parse_segments (
108+ self , output_ids : List [int ], tokenizer
109+ ) -> tuple [str , List [TranscriptionSegment ]]:
110+ """Parse timestamp tokens from output_ids into segments.
111+
112+ The decoder prompt ends with <|0.00|>, so the first segment starts at
113+ t=0. The model then outputs:
114+ text_tokens <|end_ts|> [<|start_ts|> text_tokens <|end_ts|> ...]
115+ Each timestamp token marks the end of the current segment; its value
116+ also becomes the start of the next segment.
117+ """
118+ # Token IDs for special tokens we want to strip from segment text
119+ eos_token_id = getattr (tokenizer , "eos_token_id" , 50257 )
120+
121+ segments = []
122+ full_text_parts = []
123+ current_text_tokens = []
124+ current_start = 0.0 # First segment starts at 0.0 (from prompt <|0.00|>)
125+ seg_id = 0
126+
127+ for token_id in output_ids :
128+ if token_id >= TIMESTAMP_BASE_TOKEN_ID :
129+ # This is a timestamp token — marks the end of current segment
130+ timestamp = (token_id - TIMESTAMP_BASE_TOKEN_ID ) * TIMESTAMP_BASE_OFFSET
131+
132+ if current_text_tokens :
133+ text = tokenizer .decode (
134+ current_text_tokens , skip_special_tokens = True
135+ ).strip ()
136+ if text :
137+ segments .append (
138+ TranscriptionSegment (
139+ id = seg_id ,
140+ start = round (current_start , 2 ),
141+ end = round (timestamp , 2 ),
142+ text = text ,
143+ )
144+ )
145+ full_text_parts .append (text )
146+ seg_id += 1
147+ current_text_tokens = []
148+
149+ # Next segment starts at this timestamp
150+ current_start = timestamp
151+
152+ elif token_id == eos_token_id :
153+ # Skip end-of-text token
154+ continue
155+ else :
156+ # Regular text token
157+ current_text_tokens .append (token_id )
158+
159+ # Handle any trailing text tokens without a closing timestamp
160+ if current_text_tokens :
161+ text = tokenizer .decode (
162+ current_text_tokens , skip_special_tokens = True
163+ ).strip ()
164+ if text :
165+ segments .append (
166+ TranscriptionSegment (
167+ id = seg_id ,
168+ start = round (current_start , 2 ),
169+ end = round (current_start , 2 ),
170+ text = text ,
171+ )
172+ )
173+ full_text_parts .append (text )
174+
175+ full_text = " " .join (full_text_parts )
176+ return full_text , segments
177+
99178 async def create_transcription (
100179 self ,
101180 audio_data : bytes ,
@@ -105,7 +184,14 @@ async def create_transcription(
105184 temperature : float ,
106185 stream : bool ,
107186 raw_request : Request ,
108- ) -> Union [TranscriptionResponse , StreamingResponse , Response , ORJSONResponse ]:
187+ timestamp_granularities : Optional [List [str ]] = None ,
188+ ) -> Union [
189+ TranscriptionResponse ,
190+ TranscriptionVerboseResponse ,
191+ StreamingResponse ,
192+ Response ,
193+ ORJSONResponse ,
194+ ]:
109195 """Main entry point for transcription requests."""
110196 # Calculate audio duration for usage reporting
111197 audio_duration_s = self ._get_audio_duration (audio_data )
@@ -117,6 +203,7 @@ async def create_transcription(
117203 language = language ,
118204 response_format = response_format ,
119205 temperature = temperature ,
206+ timestamp_granularities = timestamp_granularities ,
120207 stream = stream ,
121208 audio_duration_s = audio_duration_s ,
122209 )
@@ -129,7 +216,13 @@ async def _handle_non_streaming_request(
129216 adapted_request : GenerateReqInput ,
130217 request : TranscriptionRequest ,
131218 raw_request : Request ,
132- ) -> Union [TranscriptionResponse , ErrorResponse , ORJSONResponse , Response ]:
219+ ) -> Union [
220+ TranscriptionResponse ,
221+ TranscriptionVerboseResponse ,
222+ ErrorResponse ,
223+ ORJSONResponse ,
224+ Response ,
225+ ]:
133226 """Handle non-streaming transcription request."""
134227 try :
135228 ret = await self .tokenizer_manager .generate_request (
@@ -139,14 +232,26 @@ async def _handle_non_streaming_request(
139232 return self .create_error_response (str (e ))
140233
141234 text = ret .get ("text" , "" )
235+ usage = TranscriptionUsage (seconds = int (math .ceil (request .audio_duration_s )))
142236
143237 # Build response based on format
144238 if request .response_format == "text" :
145239 return Response (content = text , media_type = "text/plain" )
146240
147- # JSON format
148- usage = TranscriptionUsage (seconds = int (math .ceil (request .audio_duration_s )))
241+ if request .response_format == "verbose_json" :
242+ output_ids = ret .get ("output_ids" , [])
243+ tokenizer = self .tokenizer_manager .tokenizer
244+ parsed_text , segments = self ._parse_segments (output_ids , tokenizer )
245+
246+ return TranscriptionVerboseResponse (
247+ language = request .language or "en" ,
248+ duration = round (request .audio_duration_s , 2 ),
249+ text = parsed_text or text ,
250+ segments = segments ,
251+ usage = usage ,
252+ )
149253
254+ # Default JSON format
150255 return TranscriptionResponse (text = text , usage = usage )
151256
152257 async def _handle_streaming_request (
0 commit comments