2222import math
2323import time
2424import uuid
25- from typing import TYPE_CHECKING , AsyncGenerator , Optional , Union
25+ from typing import TYPE_CHECKING , AsyncGenerator , List , Optional , Union
2626
2727from fastapi import Request
2828from fastapi .responses import ORJSONResponse , Response , StreamingResponse
3232 ErrorResponse ,
3333 TranscriptionRequest ,
3434 TranscriptionResponse ,
35+ TranscriptionSegment ,
3536 TranscriptionStreamChoice ,
3637 TranscriptionStreamResponse ,
3738 TranscriptionUsage ,
39+ TranscriptionVerboseResponse ,
3840)
3941from sglang .srt .entrypoints .openai .serving_base import OpenAIServingBase
4042from sglang .srt .managers .io_struct import GenerateReqInput
4446
4547logger = logging .getLogger (__name__ )
4648
49+ # Whisper timestamp token constants
50+ TIMESTAMP_BASE_TOKEN_ID = 50365 # <|0.00|>
51+ TIMESTAMP_BASE_OFFSET = 0.02 # Each token step = 0.02 seconds
52+
4753
4854class OpenAIServingTranscription (OpenAIServingBase ):
4955 """Handler for /v1/audio/transcriptions requests"""
@@ -72,6 +78,11 @@ def _convert_to_internal_request(
7278 "language" : request .language , # Pass to WhisperProcessor for language-specific decoding
7379 }
7480
81+ # Pass timestamp_granularities to WhisperProcessor so it can
82+ # switch from <|notimestamps|> to <|0.00|> in the decoder prompt
83+ if request .timestamp_granularities :
84+ sampling_params ["timestamp_granularities" ] = request .timestamp_granularities
85+
7586 # For Whisper, we pass audio_data and let the processor handle it
7687 adapted_request = GenerateReqInput (
7788 text = "" , # Empty text - Whisper processor will set proper decoder tokens
@@ -96,6 +107,77 @@ def _get_audio_duration(self, audio_data: bytes) -> float:
96107 logger .warning (f"Could not calculate audio duration: { e } " )
97108 return 0.0
98109
110+ def _parse_segments (
111+ self , output_ids : List [int ], tokenizer
112+ ) -> tuple [str , List [TranscriptionSegment ]]:
113+ """Parse timestamp tokens from output_ids into segments.
114+
115+ The decoder prompt ends with <|0.00|>, so the first segment starts at
116+ t=0. The model then outputs:
117+ text_tokens <|end_ts|> [<|start_ts|> text_tokens <|end_ts|> ...]
118+ Each timestamp token marks the end of the current segment; its value
119+ also becomes the start of the next segment.
120+ """
121+ # Token IDs for special tokens we want to strip from segment text
122+ eos_token_id = getattr (tokenizer , "eos_token_id" , 50257 )
123+
124+ segments = []
125+ full_text_parts = []
126+ current_text_tokens = []
127+ current_start = 0.0 # First segment starts at 0.0 (from prompt <|0.00|>)
128+ seg_id = 0
129+
130+ for token_id in output_ids :
131+ if token_id >= TIMESTAMP_BASE_TOKEN_ID :
132+ # This is a timestamp token — marks the end of current segment
133+ timestamp = (token_id - TIMESTAMP_BASE_TOKEN_ID ) * TIMESTAMP_BASE_OFFSET
134+
135+ if current_text_tokens :
136+ text = tokenizer .decode (
137+ current_text_tokens , skip_special_tokens = True
138+ ).strip ()
139+ if text :
140+ segments .append (
141+ TranscriptionSegment (
142+ id = seg_id ,
143+ start = round (current_start , 2 ),
144+ end = round (timestamp , 2 ),
145+ text = text ,
146+ )
147+ )
148+ full_text_parts .append (text )
149+ seg_id += 1
150+ current_text_tokens = []
151+
152+ # Next segment starts at this timestamp
153+ current_start = timestamp
154+
155+ elif token_id == eos_token_id :
156+ # Skip end-of-text token
157+ continue
158+ else :
159+ # Regular text token
160+ current_text_tokens .append (token_id )
161+
162+ # Handle any trailing text tokens without a closing timestamp
163+ if current_text_tokens :
164+ text = tokenizer .decode (
165+ current_text_tokens , skip_special_tokens = True
166+ ).strip ()
167+ if text :
168+ segments .append (
169+ TranscriptionSegment (
170+ id = seg_id ,
171+ start = round (current_start , 2 ),
172+ end = round (current_start , 2 ),
173+ text = text ,
174+ )
175+ )
176+ full_text_parts .append (text )
177+
178+ full_text = " " .join (full_text_parts )
179+ return full_text , segments
180+
99181 async def create_transcription (
100182 self ,
101183 audio_data : bytes ,
@@ -105,7 +187,14 @@ async def create_transcription(
105187 temperature : float ,
106188 stream : bool ,
107189 raw_request : Request ,
108- ) -> Union [TranscriptionResponse , StreamingResponse , Response , ORJSONResponse ]:
190+ timestamp_granularities : Optional [List [str ]] = None ,
191+ ) -> Union [
192+ TranscriptionResponse ,
193+ TranscriptionVerboseResponse ,
194+ StreamingResponse ,
195+ Response ,
196+ ORJSONResponse ,
197+ ]:
109198 """Main entry point for transcription requests."""
110199 # Calculate audio duration for usage reporting
111200 audio_duration_s = self ._get_audio_duration (audio_data )
@@ -117,6 +206,7 @@ async def create_transcription(
117206 language = language ,
118207 response_format = response_format ,
119208 temperature = temperature ,
209+ timestamp_granularities = timestamp_granularities ,
120210 stream = stream ,
121211 audio_duration_s = audio_duration_s ,
122212 )
@@ -129,7 +219,13 @@ async def _handle_non_streaming_request(
129219 adapted_request : GenerateReqInput ,
130220 request : TranscriptionRequest ,
131221 raw_request : Request ,
132- ) -> Union [TranscriptionResponse , ErrorResponse , ORJSONResponse , Response ]:
222+ ) -> Union [
223+ TranscriptionResponse ,
224+ TranscriptionVerboseResponse ,
225+ ErrorResponse ,
226+ ORJSONResponse ,
227+ Response ,
228+ ]:
133229 """Handle non-streaming transcription request."""
134230 try :
135231 ret = await self .tokenizer_manager .generate_request (
@@ -139,14 +235,26 @@ async def _handle_non_streaming_request(
139235 return self .create_error_response (str (e ))
140236
141237 text = ret .get ("text" , "" )
238+ usage = TranscriptionUsage (seconds = int (math .ceil (request .audio_duration_s )))
142239
143240 # Build response based on format
144241 if request .response_format == "text" :
145242 return Response (content = text , media_type = "text/plain" )
146243
147- # JSON format
148- usage = TranscriptionUsage (seconds = int (math .ceil (request .audio_duration_s )))
244+ if request .response_format == "verbose_json" :
245+ output_ids = ret .get ("output_ids" , [])
246+ tokenizer = self .tokenizer_manager .tokenizer
247+ parsed_text , segments = self ._parse_segments (output_ids , tokenizer )
248+
249+ return TranscriptionVerboseResponse (
250+ language = request .language or "en" ,
251+ duration = round (request .audio_duration_s , 2 ),
252+ text = parsed_text or text ,
253+ segments = segments ,
254+ usage = usage ,
255+ )
149256
257+ # Default JSON format
150258 return TranscriptionResponse (text = text , usage = usage )
151259
152260 async def _handle_streaming_request (
0 commit comments