Skip to content

Commit 2c7297d

Browse files
[Whisper] Add segment-level timestamp support (verbose_json)
- Accept `timestamp_granularities[]` and `response_format=verbose_json` in the `/v1/audio/transcriptions` endpoint
- Switch decoder prompt from `<|notimestamps|>` to `<|0.00|>` when timestamps are requested so the model emits timestamp tokens
- Parse timestamp tokens from output_ids into segments with start/end times in the serving layer
- Add TranscriptionSegment and TranscriptionVerboseResponse protocol models matching the OpenAI API spec
- Backward compatible: default behavior (json/text) unchanged
1 parent 721733c commit 2c7297d

File tree

4 files changed

+160
-14
lines changed

4 files changed

+160
-14
lines changed

python/sglang/srt/entrypoints/http_server.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1480,11 +1480,18 @@ async def openai_v1_audio_transcriptions(
14801480
response_format: str = Form(default="json"),
14811481
temperature: float = Form(default=0.0),
14821482
stream: bool = Form(default=False),
1483+
timestamp_granularities: Optional[List[str]] = Form(
1484+
default=None, alias="timestamp_granularities[]"
1485+
),
14831486
):
14841487
"""OpenAI-compatible audio transcription endpoint."""
1485-
if response_format not in ["json", "text"]:
1488+
if response_format not in ["json", "text", "verbose_json"]:
14861489
return ORJSONResponse(
1487-
content={"error": {"message": "Only 'json' and 'text' formats supported"}},
1490+
content={
1491+
"error": {
1492+
"message": "Only 'json', 'text', and 'verbose_json' formats supported"
1493+
}
1494+
},
14881495
status_code=400,
14891496
)
14901497

@@ -1498,6 +1505,7 @@ async def openai_v1_audio_transcriptions(
14981505
response_format=response_format,
14991506
temperature=temperature,
15001507
stream=stream,
1508+
timestamp_granularities=timestamp_granularities,
15011509
raw_request=raw_request,
15021510
)
15031511
)

python/sglang/srt/entrypoints/openai/protocol.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1443,6 +1443,7 @@ class TranscriptionRequest(BaseModel):
14431443
language: Optional[str] = None
14441444
response_format: str = "json"
14451445
temperature: float = 0.0
1446+
timestamp_granularities: Optional[List[str]] = None
14461447
stream: bool = False
14471448
# Internal fields (not from API)
14481449
audio_data: Optional[bytes] = None
@@ -1463,6 +1464,26 @@ class TranscriptionResponse(BaseModel):
14631464
usage: Optional[TranscriptionUsage] = None
14641465

14651466

1467+
class TranscriptionSegment(BaseModel):
1468+
"""A segment with timestamp information."""
1469+
1470+
id: int
1471+
start: float
1472+
end: float
1473+
text: str
1474+
1475+
1476+
class TranscriptionVerboseResponse(BaseModel):
1477+
"""Verbose transcription response with timestamps (OpenAI-compatible)."""
1478+
1479+
task: str = "transcribe"
1480+
language: Optional[str] = None
1481+
duration: Optional[float] = None
1482+
text: str
1483+
segments: List[TranscriptionSegment] = []
1484+
usage: Optional[TranscriptionUsage] = None
1485+
1486+
14661487
class TranscriptionStreamChoice(BaseModel):
14671488
"""Delta content for streaming transcription."""
14681489

python/sglang/srt/entrypoints/openai/serving_transcription.py

Lines changed: 113 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import math
2323
import time
2424
import uuid
25-
from typing import TYPE_CHECKING, AsyncGenerator, Optional, Union
25+
from typing import TYPE_CHECKING, AsyncGenerator, List, Optional, Union
2626

2727
from fastapi import Request
2828
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
@@ -32,9 +32,11 @@
3232
ErrorResponse,
3333
TranscriptionRequest,
3434
TranscriptionResponse,
35+
TranscriptionSegment,
3536
TranscriptionStreamChoice,
3637
TranscriptionStreamResponse,
3738
TranscriptionUsage,
39+
TranscriptionVerboseResponse,
3840
)
3941
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
4042
from sglang.srt.managers.io_struct import GenerateReqInput
@@ -44,6 +46,10 @@
4446

4547
logger = logging.getLogger(__name__)
4648

49+
# Whisper timestamp token constants
50+
TIMESTAMP_BASE_TOKEN_ID = 50365 # <|0.00|>
51+
TIMESTAMP_BASE_OFFSET = 0.02 # Each token step = 0.02 seconds
52+
4753

4854
class OpenAIServingTranscription(OpenAIServingBase):
4955
"""Handler for /v1/audio/transcriptions requests"""
@@ -72,6 +78,11 @@ def _convert_to_internal_request(
7278
"language": request.language, # Pass to WhisperProcessor for language-specific decoding
7379
}
7480

81+
# Pass timestamp_granularities to WhisperProcessor so it can
82+
# switch from <|notimestamps|> to <|0.00|> in the decoder prompt
83+
if request.timestamp_granularities:
84+
sampling_params["timestamp_granularities"] = request.timestamp_granularities
85+
7586
# For Whisper, we pass audio_data and let the processor handle it
7687
adapted_request = GenerateReqInput(
7788
text="", # Empty text - Whisper processor will set proper decoder tokens
@@ -96,6 +107,77 @@ def _get_audio_duration(self, audio_data: bytes) -> float:
96107
logger.warning(f"Could not calculate audio duration: {e}")
97108
return 0.0
98109

110+
def _parse_segments(
111+
self, output_ids: List[int], tokenizer
112+
) -> tuple[str, List[TranscriptionSegment]]:
113+
"""Parse timestamp tokens from output_ids into segments.
114+
115+
The decoder prompt ends with <|0.00|>, so the first segment starts at
116+
t=0. The model then outputs:
117+
text_tokens <|end_ts|> [<|start_ts|> text_tokens <|end_ts|> ...]
118+
Each timestamp token marks the end of the current segment; its value
119+
also becomes the start of the next segment.
120+
"""
121+
# Token IDs for special tokens we want to strip from segment text
122+
eos_token_id = getattr(tokenizer, "eos_token_id", 50257)
123+
124+
segments = []
125+
full_text_parts = []
126+
current_text_tokens = []
127+
current_start = 0.0 # First segment starts at 0.0 (from prompt <|0.00|>)
128+
seg_id = 0
129+
130+
for token_id in output_ids:
131+
if token_id >= TIMESTAMP_BASE_TOKEN_ID:
132+
# This is a timestamp token — marks the end of current segment
133+
timestamp = (token_id - TIMESTAMP_BASE_TOKEN_ID) * TIMESTAMP_BASE_OFFSET
134+
135+
if current_text_tokens:
136+
text = tokenizer.decode(
137+
current_text_tokens, skip_special_tokens=True
138+
).strip()
139+
if text:
140+
segments.append(
141+
TranscriptionSegment(
142+
id=seg_id,
143+
start=round(current_start, 2),
144+
end=round(timestamp, 2),
145+
text=text,
146+
)
147+
)
148+
full_text_parts.append(text)
149+
seg_id += 1
150+
current_text_tokens = []
151+
152+
# Next segment starts at this timestamp
153+
current_start = timestamp
154+
155+
elif token_id == eos_token_id:
156+
# Skip end-of-text token
157+
continue
158+
else:
159+
# Regular text token
160+
current_text_tokens.append(token_id)
161+
162+
# Handle any trailing text tokens without a closing timestamp
163+
if current_text_tokens:
164+
text = tokenizer.decode(
165+
current_text_tokens, skip_special_tokens=True
166+
).strip()
167+
if text:
168+
segments.append(
169+
TranscriptionSegment(
170+
id=seg_id,
171+
start=round(current_start, 2),
172+
end=round(current_start, 2),
173+
text=text,
174+
)
175+
)
176+
full_text_parts.append(text)
177+
178+
full_text = " ".join(full_text_parts)
179+
return full_text, segments
180+
99181
async def create_transcription(
100182
self,
101183
audio_data: bytes,
@@ -105,7 +187,14 @@ async def create_transcription(
105187
temperature: float,
106188
stream: bool,
107189
raw_request: Request,
108-
) -> Union[TranscriptionResponse, StreamingResponse, Response, ORJSONResponse]:
190+
timestamp_granularities: Optional[List[str]] = None,
191+
) -> Union[
192+
TranscriptionResponse,
193+
TranscriptionVerboseResponse,
194+
StreamingResponse,
195+
Response,
196+
ORJSONResponse,
197+
]:
109198
"""Main entry point for transcription requests."""
110199
# Calculate audio duration for usage reporting
111200
audio_duration_s = self._get_audio_duration(audio_data)
@@ -117,6 +206,7 @@ async def create_transcription(
117206
language=language,
118207
response_format=response_format,
119208
temperature=temperature,
209+
timestamp_granularities=timestamp_granularities,
120210
stream=stream,
121211
audio_duration_s=audio_duration_s,
122212
)
@@ -129,7 +219,13 @@ async def _handle_non_streaming_request(
129219
adapted_request: GenerateReqInput,
130220
request: TranscriptionRequest,
131221
raw_request: Request,
132-
) -> Union[TranscriptionResponse, ErrorResponse, ORJSONResponse, Response]:
222+
) -> Union[
223+
TranscriptionResponse,
224+
TranscriptionVerboseResponse,
225+
ErrorResponse,
226+
ORJSONResponse,
227+
Response,
228+
]:
133229
"""Handle non-streaming transcription request."""
134230
try:
135231
ret = await self.tokenizer_manager.generate_request(
@@ -139,14 +235,26 @@ async def _handle_non_streaming_request(
139235
return self.create_error_response(str(e))
140236

141237
text = ret.get("text", "")
238+
usage = TranscriptionUsage(seconds=int(math.ceil(request.audio_duration_s)))
142239

143240
# Build response based on format
144241
if request.response_format == "text":
145242
return Response(content=text, media_type="text/plain")
146243

147-
# JSON format
148-
usage = TranscriptionUsage(seconds=int(math.ceil(request.audio_duration_s)))
244+
if request.response_format == "verbose_json":
245+
output_ids = ret.get("output_ids", [])
246+
tokenizer = self.tokenizer_manager.tokenizer
247+
parsed_text, segments = self._parse_segments(output_ids, tokenizer)
248+
249+
return TranscriptionVerboseResponse(
250+
language=request.language or "en",
251+
duration=round(request.audio_duration_s, 2),
252+
text=parsed_text or text,
253+
segments=segments,
254+
usage=usage,
255+
)
149256

257+
# Default JSON format
150258
return TranscriptionResponse(text=text, usage=usage)
151259

152260
async def _handle_streaming_request(

python/sglang/srt/multimodal/processors/whisper.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,10 @@ def _extract_language_from_request(self, request_obj) -> Optional[str]:
120120
language = sampling_params.pop("language", None)
121121
return normalize_language_to_code(language)
122122

123+
def _extract_timestamp_granularities(self, request_obj) -> Optional[list]:
124+
sampling_params = getattr(request_obj, "sampling_params", None) or {}
125+
return sampling_params.pop("timestamp_granularities", None)
126+
123127
def _get_language_token_id(self, language: Optional[str]) -> int:
124128
# Default to English if not specified
125129
if language is None:
@@ -148,27 +152,32 @@ async def process_mm_data_async(
148152
# For Whisper, ALWAYS use the proper transcription token sequence
149153
# and IGNORE any text prompt - Whisper is a pure speech-to-text model
150154
# The decoder_start_token_id and forced_decoder_ids from generation config
151-
# set up: <|startoftranscript|> <|lang|> <|task|> [<|notimestamps|>]
155+
# set up: <|startoftranscript|> <|lang|> <|task|> [<|notimestamps|> or <|0.00|>]
152156

153-
# Extract language from request and get token ID
157+
# Extract language and timestamp settings from request
154158
language = self._extract_language_from_request(request_obj)
155159
language_token_id = self._get_language_token_id(language)
160+
timestamp_granularities = self._extract_timestamp_granularities(request_obj)
156161

157162
# Build decoder input tokens
158-
# <|startoftranscript|> + <|lang|> + <|transcribe|> + <|notimestamps|>
159163
decoder_start_token_id = getattr(
160164
self.hf_config, "decoder_start_token_id", 50258
161165
)
162166
transcribe_token_id = self._tokenizer.convert_tokens_to_ids("<|transcribe|>")
163-
notimestamps_token_id = self._tokenizer.convert_tokens_to_ids(
164-
"<|notimestamps|>"
165-
)
167+
168+
# Use <|0.00|> to enable timestamp generation, or <|notimestamps|> to disable
169+
if timestamp_granularities:
170+
timestamp_token_id = self._tokenizer.convert_tokens_to_ids("<|0.00|>")
171+
else:
172+
timestamp_token_id = self._tokenizer.convert_tokens_to_ids(
173+
"<|notimestamps|>"
174+
)
166175

167176
input_ids = [
168177
decoder_start_token_id,
169178
language_token_id,
170179
transcribe_token_id,
171-
notimestamps_token_id,
180+
timestamp_token_id,
172181
]
173182

174183
# Whisper expects input features padded to max_length (3000 frames = 30 seconds)

0 commit comments

Comments (0)