Skip to content

Commit fffa141

Browse files
gdhananiaorangetin
andauthored
Eng 38484/speaker diarization tests (#362)
* added speaker diarization tests * changed to main model * fixed top-level words * formatting --------- Co-authored-by: orangetin <[email protected]>
1 parent 76fc949 commit fffa141

File tree

3 files changed

+163
-4
lines changed

3 files changed

+163
-4
lines changed

src/together/resources/audio/transcriptions.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,12 @@ def create(
104104
)
105105

106106
# Add any additional kwargs
107-
params_data.update(kwargs)
107+
# Convert boolean values to lowercase strings for proper form encoding
108+
for key, value in kwargs.items():
109+
if isinstance(value, bool):
110+
params_data[key] = str(value).lower()
111+
else:
112+
params_data[key] = value
108113

109114
try:
110115
response, _, _ = requestor.request(
@@ -131,7 +136,8 @@ def create(
131136
response_format == "verbose_json"
132137
or response_format == AudioTranscriptionResponseFormat.VERBOSE_JSON
133138
):
134-
return AudioTranscriptionVerboseResponse(**response.data)
139+
# Create response with model validation that preserves extra fields
140+
return AudioTranscriptionVerboseResponse.model_validate(response.data)
135141
else:
136142
return AudioTranscriptionResponse(**response.data)
137143

@@ -234,7 +240,12 @@ async def create(
234240
)
235241

236242
# Add any additional kwargs
237-
params_data.update(kwargs)
243+
# Convert boolean values to lowercase strings for proper form encoding
244+
for key, value in kwargs.items():
245+
if isinstance(value, bool):
246+
params_data[key] = str(value).lower()
247+
else:
248+
params_data[key] = value
238249

239250
try:
240251
response, _, _ = await requestor.arequest(
@@ -261,6 +272,7 @@ async def create(
261272
response_format == "verbose_json"
262273
or response_format == AudioTranscriptionResponseFormat.VERBOSE_JSON
263274
):
264-
return AudioTranscriptionVerboseResponse(**response.data)
275+
# Create response with model validation that preserves extra fields
276+
return AudioTranscriptionVerboseResponse.model_validate(response.data)
265277
else:
266278
return AudioTranscriptionResponse(**response.data)

src/together/types/audio_speech.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,18 +158,31 @@ class AudioTranscriptionWord(BaseModel):
158158
word: str
159159
start: float
160160
end: float
161+
id: Optional[int] = None
162+
speaker_id: Optional[str] = None
163+
164+
165+
class AudioSpeakerSegment(BaseModel):
166+
id: int
167+
speaker_id: str
168+
start: float
169+
end: float
170+
text: str
171+
words: List[AudioTranscriptionWord]
161172

162173

163174
class AudioTranscriptionResponse(BaseModel):
164175
text: str
165176

166177

167178
class AudioTranscriptionVerboseResponse(BaseModel):
179+
id: Optional[str] = None
168180
language: Optional[str] = None
169181
duration: Optional[float] = None
170182
text: str
171183
segments: Optional[List[AudioTranscriptionSegment]] = None
172184
words: Optional[List[AudioTranscriptionWord]] = None
185+
speaker_segments: Optional[List[AudioSpeakerSegment]] = None
173186

174187

175188
class AudioTranslationResponse(BaseModel):

tests/integration/resources/test_transcriptions.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,47 @@
99
)
1010

1111

12+
def validate_diarization_response(response_dict):
13+
"""
14+
Helper function to validate diarization response structure
15+
"""
16+
# Validate top-level speaker_segments field
17+
assert "speaker_segments" in response_dict
18+
assert isinstance(response_dict["speaker_segments"], list)
19+
assert len(response_dict["speaker_segments"]) > 0
20+
21+
# Validate each speaker segment structure
22+
for segment in response_dict["speaker_segments"]:
23+
assert "text" in segment
24+
assert "id" in segment
25+
assert "speaker_id" in segment
26+
assert "start" in segment
27+
assert "end" in segment
28+
assert "words" in segment
29+
30+
# Validate nested words in speaker segments
31+
assert isinstance(segment["words"], list)
32+
for word in segment["words"]:
33+
assert "id" in word
34+
assert "word" in word
35+
assert "start" in word
36+
assert "end" in word
37+
assert "speaker_id" in word
38+
39+
# Validate top-level words field
40+
assert "words" in response_dict
41+
assert isinstance(response_dict["words"], list)
42+
assert len(response_dict["words"]) > 0
43+
44+
# Validate each word in top-level words
45+
for word in response_dict["words"]:
46+
assert "id" in word
47+
assert "word" in word
48+
assert "start" in word
49+
assert "end" in word
50+
assert "speaker_id" in word
51+
52+
1253
class TestTogetherTranscriptions:
1354
@pytest.fixture
1455
def sync_together_client(self) -> Together:
@@ -116,3 +157,96 @@ def test_language_detection_hindi(self, sync_together_client):
116157
assert len(response.text) > 0
117158
assert hasattr(response, "language")
118159
assert response.language == "hi"
160+
161+
def test_diarization_default(self, sync_together_client):
162+
"""
163+
Test diarization with default model in verbose JSON format
164+
"""
165+
audio_url = "https://together-public-test-data.s3.us-west-2.amazonaws.com/audio/2-speaker-conversation.wav"
166+
167+
response = sync_together_client.audio.transcriptions.create(
168+
file=audio_url,
169+
model="openai/whisper-large-v3",
170+
response_format="verbose_json",
171+
diarize=True,
172+
)
173+
174+
assert isinstance(response, AudioTranscriptionVerboseResponse)
175+
assert isinstance(response.text, str)
176+
assert len(response.text) > 0
177+
178+
# Validate diarization fields
179+
response_dict = response.model_dump()
180+
validate_diarization_response(response_dict)
181+
182+
def test_diarization_nvidia(self, sync_together_client):
183+
"""
184+
Test diarization with nvidia model in verbose JSON format
185+
"""
186+
audio_url = "https://together-public-test-data.s3.us-west-2.amazonaws.com/audio/2-speaker-conversation.wav"
187+
188+
response = sync_together_client.audio.transcriptions.create(
189+
file=audio_url,
190+
model="openai/whisper-large-v3",
191+
response_format="verbose_json",
192+
diarize=True,
193+
diarization_model="nvidia",
194+
)
195+
196+
assert isinstance(response, AudioTranscriptionVerboseResponse)
197+
assert isinstance(response.text, str)
198+
assert len(response.text) > 0
199+
200+
# Validate diarization fields
201+
response_dict = response.model_dump()
202+
validate_diarization_response(response_dict)
203+
204+
def test_diarization_pyannote(self, sync_together_client):
205+
"""
206+
Test diarization with pyannote model in verbose JSON format
207+
"""
208+
audio_url = "https://together-public-test-data.s3.us-west-2.amazonaws.com/audio/2-speaker-conversation.wav"
209+
210+
response = sync_together_client.audio.transcriptions.create(
211+
file=audio_url,
212+
model="openai/whisper-large-v3",
213+
response_format="verbose_json",
214+
diarize=True,
215+
diarization_model="pyannote",
216+
)
217+
218+
assert isinstance(response, AudioTranscriptionVerboseResponse)
219+
assert isinstance(response.text, str)
220+
assert len(response.text) > 0
221+
222+
# Validate diarization fields
223+
response_dict = response.model_dump()
224+
validate_diarization_response(response_dict)
225+
226+
def test_no_diarization(self, sync_together_client):
227+
"""
228+
Test with diarize=false should not have speaker segments
229+
"""
230+
audio_url = "https://together-public-test-data.s3.us-west-2.amazonaws.com/audio/2-speaker-conversation.wav"
231+
232+
response = sync_together_client.audio.transcriptions.create(
233+
file=audio_url,
234+
model="openai/whisper-large-v3",
235+
response_format="verbose_json",
236+
diarize=False,
237+
)
238+
239+
assert isinstance(response, AudioTranscriptionVerboseResponse)
240+
assert isinstance(response.text, str)
241+
assert len(response.text) > 0
242+
243+
# Verify no diarization fields
244+
response_dict = response.model_dump()
245+
assert response_dict.get("speaker_segments") is None
246+
assert response_dict.get("words") is None
247+
248+
# Should still have standard fields
249+
assert "text" in response_dict
250+
assert "language" in response_dict
251+
assert "duration" in response_dict
252+
assert "segments" in response_dict

0 commit comments

Comments
 (0)