Skip to content

Commit 3c53805

Browse files
[Whisper] Add segment-level timestamp support (verbose_json)
- Accept `timestamp_granularities[]` and `response_format=verbose_json` in the `/v1/audio/transcriptions` endpoint - Switch decoder prompt from `<|notimestamps|>` to `<|0.00|>` when timestamps are requested so the model emits timestamp tokens - Parse timestamp tokens from output_ids into segments with start/end times in the serving layer - Add TranscriptionSegment and TranscriptionVerboseResponse protocol models matching the OpenAI API spec - Backward compatible: default behavior (json/text) unchanged
1 parent 721733c commit 3c53805

File tree

4 files changed

+162
-21
lines changed

4 files changed

+162
-21
lines changed

python/sglang/srt/entrypoints/http_server.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1480,11 +1480,18 @@ async def openai_v1_audio_transcriptions(
14801480
response_format: str = Form(default="json"),
14811481
temperature: float = Form(default=0.0),
14821482
stream: bool = Form(default=False),
1483+
timestamp_granularities: Optional[List[str]] = Form(
1484+
default=None, alias="timestamp_granularities[]"
1485+
),
14831486
):
14841487
"""OpenAI-compatible audio transcription endpoint."""
1485-
if response_format not in ["json", "text"]:
1488+
if response_format not in ["json", "text", "verbose_json"]:
14861489
return ORJSONResponse(
1487-
content={"error": {"message": "Only 'json' and 'text' formats supported"}},
1490+
content={
1491+
"error": {
1492+
"message": "Only 'json', 'text', and 'verbose_json' formats supported"
1493+
}
1494+
},
14881495
status_code=400,
14891496
)
14901497

@@ -1498,6 +1505,7 @@ async def openai_v1_audio_transcriptions(
14981505
response_format=response_format,
14991506
temperature=temperature,
15001507
stream=stream,
1508+
timestamp_granularities=timestamp_granularities,
15011509
raw_request=raw_request,
15021510
)
15031511
)

python/sglang/srt/entrypoints/openai/protocol.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1443,6 +1443,7 @@ class TranscriptionRequest(BaseModel):
14431443
language: Optional[str] = None
14441444
response_format: str = "json"
14451445
temperature: float = 0.0
1446+
timestamp_granularities: Optional[List[str]] = None
14461447
stream: bool = False
14471448
# Internal fields (not from API)
14481449
audio_data: Optional[bytes] = None
@@ -1463,6 +1464,26 @@ class TranscriptionResponse(BaseModel):
14631464
usage: Optional[TranscriptionUsage] = None
14641465

14651466

1467+
class TranscriptionSegment(BaseModel):
    """A segment with timestamp information."""

    # 0-based index of the segment within the transcription.
    id: int
    # Segment start time in seconds from the beginning of the audio.
    start: float
    # Segment end time in seconds.
    end: float
    # Transcribed text for this segment.
    text: str
1474+
1475+
1476+
class TranscriptionVerboseResponse(BaseModel):
    """Verbose transcription response with timestamps (OpenAI-compatible)."""

    # Fixed to "transcribe" for this endpoint, matching the OpenAI spec.
    task: str = "transcribe"
    # Language of the audio; presumably an ISO-639-1 code — set by the
    # serving layer from the request (defaults to "en" there).
    language: Optional[str] = None
    # Duration of the input audio in seconds, when it could be computed.
    duration: Optional[float] = None
    # Full transcription text.
    text: str
    # Timestamped segments. NOTE: pydantic deep-copies field defaults per
    # instance, so the mutable `[]` default is not shared across responses.
    segments: List[TranscriptionSegment] = []
    # Not part of the OpenAI spec; mirrors TranscriptionResponse.usage.
    usage: Optional[TranscriptionUsage] = None
1485+
1486+
14661487
class TranscriptionStreamChoice(BaseModel):
14671488
"""Delta content for streaming transcription."""
14681489

python/sglang/srt/entrypoints/openai/serving_transcription.py

Lines changed: 113 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import math
2323
import time
2424
import uuid
25-
from typing import TYPE_CHECKING, AsyncGenerator, Optional, Union
25+
from typing import TYPE_CHECKING, AsyncGenerator, List, Optional, Union
2626

2727
from fastapi import Request
2828
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
@@ -32,9 +32,11 @@
3232
ErrorResponse,
3333
TranscriptionRequest,
3434
TranscriptionResponse,
35+
TranscriptionSegment,
3536
TranscriptionStreamChoice,
3637
TranscriptionStreamResponse,
3738
TranscriptionUsage,
39+
TranscriptionVerboseResponse,
3840
)
3941
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
4042
from sglang.srt.managers.io_struct import GenerateReqInput
@@ -44,6 +46,10 @@
4446

4547
logger = logging.getLogger(__name__)
4648

49+
# Whisper timestamp token constants.
# NOTE(review): 50365 is the ID of <|0.00|> for whisper-large-v3 (100
# languages); earlier multilingual checkpoints use 50364 — confirm against
# the served model's tokenizer.
TIMESTAMP_BASE_TOKEN_ID = 50365  # <|0.00|>
TIMESTAMP_BASE_OFFSET = 0.02  # Each token step = 0.02 seconds
52+
4753

4854
class OpenAIServingTranscription(OpenAIServingBase):
4955
"""Handler for /v1/audio/transcriptions requests"""
@@ -72,6 +78,9 @@ def _convert_to_internal_request(
7278
"language": request.language, # Pass to WhisperProcessor for language-specific decoding
7379
}
7480

81+
if request.timestamp_granularities:
82+
sampling_params["timestamp_granularities"] = request.timestamp_granularities
83+
7584
# For Whisper, we pass audio_data and let the processor handle it
7685
adapted_request = GenerateReqInput(
7786
text="", # Empty text - Whisper processor will set proper decoder tokens
def _get_audio_duration(self, audio_data: bytes) -> float:
    """Return the duration of the encoded audio in seconds.

    Uses soundfile's header-only ``info`` call (no full decode), so this is
    cheap even for long inputs. Returns 0.0 when the duration cannot be
    determined (unreadable/unsupported data, or soundfile not installed).
    """
    try:
        import soundfile as sf

        # sf.info parses only the file header — cheaper than sf.read,
        # which would decode the entire waveform just to measure it.
        info = sf.info(io.BytesIO(audio_data))
        return info.duration
    except Exception as e:
        # Best-effort: the duration feeds usage reporting only, so a
        # failure here must not fail the transcription request itself.
        logger.warning(f"Could not calculate audio duration: {e}")
        return 0.0
98106

107+
def _parse_segments(
108+
self, output_ids: List[int], tokenizer
109+
) -> tuple[str, List[TranscriptionSegment]]:
110+
"""Parse timestamp tokens from output_ids into segments.
111+
112+
The decoder prompt ends with <|0.00|>, so the first segment starts at
113+
t=0. The model then outputs:
114+
text_tokens <|end_ts|> [<|start_ts|> text_tokens <|end_ts|> ...]
115+
Each timestamp token marks the end of the current segment; its value
116+
also becomes the start of the next segment.
117+
"""
118+
# Token IDs for special tokens we want to strip from segment text
119+
eos_token_id = getattr(tokenizer, "eos_token_id", 50257)
120+
121+
segments = []
122+
full_text_parts = []
123+
current_text_tokens = []
124+
current_start = 0.0 # First segment starts at 0.0 (from prompt <|0.00|>)
125+
seg_id = 0
126+
127+
for token_id in output_ids:
128+
if token_id >= TIMESTAMP_BASE_TOKEN_ID:
129+
# This is a timestamp token — marks the end of current segment
130+
timestamp = (token_id - TIMESTAMP_BASE_TOKEN_ID) * TIMESTAMP_BASE_OFFSET
131+
132+
if current_text_tokens:
133+
text = tokenizer.decode(
134+
current_text_tokens, skip_special_tokens=True
135+
).strip()
136+
if text:
137+
segments.append(
138+
TranscriptionSegment(
139+
id=seg_id,
140+
start=round(current_start, 2),
141+
end=round(timestamp, 2),
142+
text=text,
143+
)
144+
)
145+
full_text_parts.append(text)
146+
seg_id += 1
147+
current_text_tokens = []
148+
149+
# Next segment starts at this timestamp
150+
current_start = timestamp
151+
152+
elif token_id == eos_token_id:
153+
# Skip end-of-text token
154+
continue
155+
else:
156+
# Regular text token
157+
current_text_tokens.append(token_id)
158+
159+
# Handle any trailing text tokens without a closing timestamp
160+
if current_text_tokens:
161+
text = tokenizer.decode(
162+
current_text_tokens, skip_special_tokens=True
163+
).strip()
164+
if text:
165+
segments.append(
166+
TranscriptionSegment(
167+
id=seg_id,
168+
start=round(current_start, 2),
169+
end=round(current_start, 2),
170+
text=text,
171+
)
172+
)
173+
full_text_parts.append(text)
174+
175+
full_text = " ".join(full_text_parts)
176+
return full_text, segments
177+
99178
async def create_transcription(
100179
self,
101180
audio_data: bytes,
@@ -105,7 +184,14 @@ async def create_transcription(
105184
temperature: float,
106185
stream: bool,
107186
raw_request: Request,
108-
) -> Union[TranscriptionResponse, StreamingResponse, Response, ORJSONResponse]:
187+
timestamp_granularities: Optional[List[str]] = None,
188+
) -> Union[
189+
TranscriptionResponse,
190+
TranscriptionVerboseResponse,
191+
StreamingResponse,
192+
Response,
193+
ORJSONResponse,
194+
]:
109195
"""Main entry point for transcription requests."""
110196
# Calculate audio duration for usage reporting
111197
audio_duration_s = self._get_audio_duration(audio_data)
@@ -117,6 +203,7 @@ async def create_transcription(
117203
language=language,
118204
response_format=response_format,
119205
temperature=temperature,
206+
timestamp_granularities=timestamp_granularities,
120207
stream=stream,
121208
audio_duration_s=audio_duration_s,
122209
)
@@ -129,7 +216,13 @@ async def _handle_non_streaming_request(
129216
adapted_request: GenerateReqInput,
130217
request: TranscriptionRequest,
131218
raw_request: Request,
132-
) -> Union[TranscriptionResponse, ErrorResponse, ORJSONResponse, Response]:
219+
) -> Union[
220+
TranscriptionResponse,
221+
TranscriptionVerboseResponse,
222+
ErrorResponse,
223+
ORJSONResponse,
224+
Response,
225+
]:
133226
"""Handle non-streaming transcription request."""
134227
try:
135228
ret = await self.tokenizer_manager.generate_request(
@@ -139,14 +232,26 @@ async def _handle_non_streaming_request(
139232
return self.create_error_response(str(e))
140233

141234
text = ret.get("text", "")
235+
usage = TranscriptionUsage(seconds=int(math.ceil(request.audio_duration_s)))
142236

143237
# Build response based on format
144238
if request.response_format == "text":
145239
return Response(content=text, media_type="text/plain")
146240

147-
# JSON format
148-
usage = TranscriptionUsage(seconds=int(math.ceil(request.audio_duration_s)))
241+
if request.response_format == "verbose_json":
242+
output_ids = ret.get("output_ids", [])
243+
tokenizer = self.tokenizer_manager.tokenizer
244+
parsed_text, segments = self._parse_segments(output_ids, tokenizer)
245+
246+
return TranscriptionVerboseResponse(
247+
language=request.language or "en",
248+
duration=round(request.audio_duration_s, 2),
249+
text=parsed_text or text,
250+
segments=segments,
251+
usage=usage,
252+
)
149253

254+
# Default JSON format
150255
return TranscriptionResponse(text=text, usage=usage)
151256

152257
async def _handle_streaming_request(

python/sglang/srt/multimodal/processors/whisper.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -115,10 +115,9 @@ def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
115115
# Cache tokenizer for language token lookup
116116
self._tokenizer = getattr(self._processor, "tokenizer", None)
117117

118-
def _extract_language_from_request(self, request_obj) -> Optional[str]:
118+
def _pop_sampling_param(self, request_obj, key: str):
119119
sampling_params = getattr(request_obj, "sampling_params", None) or {}
120-
language = sampling_params.pop("language", None)
121-
return normalize_language_to_code(language)
120+
return sampling_params.pop(key, None)
122121

123122
def _get_language_token_id(self, language: Optional[str]) -> int:
124123
# Default to English if not specified
@@ -148,27 +147,35 @@ async def process_mm_data_async(
148147
# For Whisper, ALWAYS use the proper transcription token sequence
149148
# and IGNORE any text prompt - Whisper is a pure speech-to-text model
150149
# The decoder_start_token_id and forced_decoder_ids from generation config
151-
# set up: <|startoftranscript|> <|lang|> <|task|> [<|notimestamps|>]
150+
# set up: <|startoftranscript|> <|lang|> <|task|> [<|notimestamps|> or <|0.00|>]
152151

153-
# Extract language from request and get token ID
154-
language = self._extract_language_from_request(request_obj)
152+
language = normalize_language_to_code(
153+
self._pop_sampling_param(request_obj, "language")
154+
)
155155
language_token_id = self._get_language_token_id(language)
156+
timestamp_granularities = self._pop_sampling_param(
157+
request_obj, "timestamp_granularities"
158+
)
156159

157160
# Build decoder input tokens
158-
# <|startoftranscript|> + <|lang|> + <|transcribe|> + <|notimestamps|>
159161
decoder_start_token_id = getattr(
160162
self.hf_config, "decoder_start_token_id", 50258
161163
)
162164
transcribe_token_id = self._tokenizer.convert_tokens_to_ids("<|transcribe|>")
163-
notimestamps_token_id = self._tokenizer.convert_tokens_to_ids(
164-
"<|notimestamps|>"
165-
)
165+
166+
# Use <|0.00|> to enable timestamp generation, or <|notimestamps|> to disable
167+
if timestamp_granularities:
168+
timestamp_token_id = self._tokenizer.convert_tokens_to_ids("<|0.00|>")
169+
else:
170+
timestamp_token_id = self._tokenizer.convert_tokens_to_ids(
171+
"<|notimestamps|>"
172+
)
166173

167174
input_ids = [
168175
decoder_start_token_id,
169176
language_token_id,
170177
transcribe_token_id,
171-
notimestamps_token_id,
178+
timestamp_token_id,
172179
]
173180

174181
# Whisper expects input features padded to max_length (3000 frames = 30 seconds)

0 commit comments

Comments
 (0)