
Commit 8a87cd2

[CI] Speed up Whisper tests by reusing server (#22859)

mgoin authored
Signed-off-by: mgoin <[email protected]>

1 parent a344a1a commit 8a87cd2
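
The change: instead of launching a fresh RemoteOpenAIServer (and re-loading Whisper) inside every test, the suite now starts one server per module via a module-scoped pytest fixture and hands each test its own async client. Below is a minimal, self-contained sketch of that reuse pattern; DummyServer and DummyClient are stand-ins for vLLM's RemoteOpenAIServer and the OpenAI async client, not part of the real suite.

```python
# Sketch of the server-reuse pattern this commit introduces, outside vLLM.
# DummyServer/DummyClient are hypothetical stand-ins; only the fixture
# wiring mirrors the actual diff.
import pytest
import pytest_asyncio


class DummyServer:
    """Stands in for an expensive-to-start server process."""

    started = 0

    def __enter__(self):
        DummyServer.started += 1  # count real startups
        return self

    def __exit__(self, *exc):
        pass

    def get_async_client(self):
        return DummyClient()


class DummyClient:
    async def __aenter__(self):
        return self

    async def __aexit__(self, *exc):
        pass


@pytest.fixture(scope="module")
def server():
    # Started once per module; every test below reuses it.
    with DummyServer() as srv:
        yield srv


@pytest_asyncio.fixture
async def client(server):
    # Cheap per-test client against the shared server.
    async with server.get_async_client() as c:
        yield c


@pytest.mark.asyncio
async def test_a(client):
    assert DummyServer.started == 1


@pytest.mark.asyncio
async def test_b(client):
    # Still only one startup, despite a second test.
    assert DummyServer.started == 1
```

Tests that genuinely need a different model (test_non_asr_model loads JackFram/llama-68m) still start their own server inline, as the diff below keeps doing.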

File tree

2 files changed (+263, -291 lines)


tests/entrypoints/openai/test_transcription_validation.py

Lines changed: 141 additions & 179 deletions
@@ -4,19 +4,20 @@
 # imports for guided decoding tests
 import io
 import json
-from unittest.mock import patch
 
 import librosa
 import numpy as np
 import openai
 import pytest
+import pytest_asyncio
 import soundfile as sf
-from openai._base_client import AsyncAPIClient
 
 from vllm.assets.audio import AudioAsset
 
 from ...utils import RemoteOpenAIServer
 
+MODEL_NAME = "openai/whisper-large-v3-turbo"
+SERVER_ARGS = ["--enforce-eager"]
 MISTRAL_FORMAT_ARGS = [
     "--tokenizer_mode", "mistral", "--config_format", "mistral",
     "--load_format", "mistral"
@@ -37,6 +38,18 @@ def winning_call():
         yield f
 
 
+@pytest.fixture(scope="module")
+def server():
+    with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as remote_server:
+        yield remote_server
+
+
+@pytest_asyncio.fixture
+async def client(server):
+    async with server.get_async_client() as async_client:
+        yield async_client
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
     "model_name",
@@ -61,25 +74,33 @@ async def test_basic_audio(mary_had_lamb, model_name):
 
 
 @pytest.mark.asyncio
-async def test_bad_requests(mary_had_lamb):
-    model_name = "openai/whisper-small"
-    server_args = ["--enforce-eager"]
-    with RemoteOpenAIServer(model_name, server_args) as remote_server:
+async def test_non_asr_model(winning_call):
+    # text to text model
+    model_name = "JackFram/llama-68m"
+    with RemoteOpenAIServer(model_name, SERVER_ARGS) as remote_server:
         client = remote_server.get_async_client()
-
-        # invalid language
-        with pytest.raises(openai.BadRequestError):
-            await client.audio.transcriptions.create(model=model_name,
-                                                     file=mary_had_lamb,
-                                                     language="hh",
-                                                     temperature=0.0)
+        res = await client.audio.transcriptions.create(model=model_name,
+                                                       file=winning_call,
+                                                       language="en",
+                                                       temperature=0.0)
+        err = res.error
+        assert err["code"] == 400 and not res.text
+        assert err[
+            "message"] == "The model does not support Transcriptions API"
 
 
 @pytest.mark.asyncio
-@pytest.mark.parametrize("model_name", ["openai/whisper-large-v3-turbo"])
-async def test_long_audio_request(mary_had_lamb, model_name):
-    server_args = ["--enforce-eager"]
+async def test_bad_requests(mary_had_lamb, client):
+    # invalid language
+    with pytest.raises(openai.BadRequestError):
+        await client.audio.transcriptions.create(model=MODEL_NAME,
                                                 file=mary_had_lamb,
+                                                 language="hh",
+                                                 temperature=0.0)
+
 
+@pytest.mark.asyncio
+async def test_long_audio_request(mary_had_lamb, client):
     mary_had_lamb.seek(0)
     audio, sr = librosa.load(mary_had_lamb)
     # Add small silence after each audio for repeatability in the split process
@@ -89,188 +110,129 @@ async def test_long_audio_request(mary_had_lamb, model_name):
     buffer = io.BytesIO()
     sf.write(buffer, repeated_audio, sr, format='WAV')
     buffer.seek(0)
-    with RemoteOpenAIServer(model_name, server_args) as remote_server:
-        client = remote_server.get_async_client()
-        transcription = await client.audio.transcriptions.create(
-            model=model_name,
-            file=buffer,
-            language="en",
-            response_format="text",
-            temperature=0.0)
-        out = json.loads(transcription)['text']
-        counts = out.count("Mary had a little lamb")
-        assert counts == 10, counts
-
-
-@pytest.mark.asyncio
-async def test_non_asr_model(winning_call):
-    # text to text model
-    model_name = "JackFram/llama-68m"
-    server_args = ["--enforce-eager"]
-    with RemoteOpenAIServer(model_name, server_args) as remote_server:
-        client = remote_server.get_async_client()
-        res = await client.audio.transcriptions.create(model=model_name,
-                                                       file=winning_call,
-                                                       language="en",
-                                                       temperature=0.0)
-        err = res.error
-        assert err["code"] == 400 and not res.text
-        assert err[
-            "message"] == "The model does not support Transcriptions API"
+    transcription = await client.audio.transcriptions.create(
+        model=MODEL_NAME,
+        file=buffer,
+        language="en",
+        response_format="text",
+        temperature=0.0)
+    out = json.loads(transcription)['text']
+    counts = out.count("Mary had a little lamb")
+    assert counts == 10, counts
 
 
 @pytest.mark.asyncio
-async def test_completion_endpoints():
+async def test_completion_endpoints(client):
     # text to text model
-    model_name = "openai/whisper-small"
-    server_args = ["--enforce-eager"]
-    with RemoteOpenAIServer(model_name, server_args) as remote_server:
-        client = remote_server.get_async_client()
-        res = await client.chat.completions.create(
-            model=model_name,
-            messages=[{
-                "role": "system",
-                "content": "You are a helpful assistant."
-            }])
-        err = res.error
-        assert err["code"] == 400
-        assert err[
-            "message"] == "The model does not support Chat Completions API"
-
-        res = await client.completions.create(model=model_name, prompt="Hello")
-        err = res.error
-        assert err["code"] == 400
-        assert err["message"] == "The model does not support Completions API"
+    res = await client.chat.completions.create(
+        model=MODEL_NAME,
+        messages=[{
+            "role": "system",
+            "content": "You are a helpful assistant."
+        }])
+    err = res.error
+    assert err["code"] == 400
+    assert err["message"] == "The model does not support Chat Completions API"
+
+    res = await client.completions.create(model=MODEL_NAME, prompt="Hello")
+    err = res.error
+    assert err["code"] == 400
+    assert err["message"] == "The model does not support Completions API"
 
 
 @pytest.mark.asyncio
-async def test_streaming_response(winning_call):
-    model_name = "openai/whisper-small"
-    server_args = ["--enforce-eager"]
+async def test_streaming_response(winning_call, client):
     transcription = ""
-    with RemoteOpenAIServer(model_name, server_args) as remote_server:
-        client = remote_server.get_async_client()
-        res_no_stream = await client.audio.transcriptions.create(
-            model=model_name,
-            file=winning_call,
-            response_format="json",
-            language="en",
-            temperature=0.0)
-        # Unfortunately this only works when the openai client is patched
-        # to use streaming mode, not exposed in the transcription api.
-        original_post = AsyncAPIClient.post
-
-        async def post_with_stream(*args, **kwargs):
-            kwargs['stream'] = True
-            return await original_post(*args, **kwargs)
-
-        with patch.object(AsyncAPIClient, "post", new=post_with_stream):
-            client = remote_server.get_async_client()
-            res = await client.audio.transcriptions.create(
-                model=model_name,
-                file=winning_call,
-                language="en",
-                temperature=0.0,
-                extra_body=dict(stream=True),
-                timeout=30)
-            # Reconstruct from chunks and validate
-            async for chunk in res:
-                # just a chunk
-                text = chunk.choices[0]['delta']['content']
-                transcription += text
-
-        assert transcription == res_no_stream.text
+    res_no_stream = await client.audio.transcriptions.create(
+        model=MODEL_NAME,
+        file=winning_call,
+        response_format="json",
+        language="en",
+        temperature=0.0)
+    res = await client.audio.transcriptions.create(model=MODEL_NAME,
+                                                   file=winning_call,
+                                                   language="en",
+                                                   temperature=0.0,
+                                                   stream=True,
+                                                   timeout=30)
+    # Reconstruct from chunks and validate
+    async for chunk in res:
+        text = chunk.choices[0]['delta']['content']
+        transcription += text
+
+    assert transcription == res_no_stream.text
 
 
 @pytest.mark.asyncio
-async def test_stream_options(winning_call):
-    model_name = "openai/whisper-small"
-    server_args = ["--enforce-eager"]
-    with RemoteOpenAIServer(model_name, server_args) as remote_server:
-        original_post = AsyncAPIClient.post
-
-        async def post_with_stream(*args, **kwargs):
-            kwargs['stream'] = True
-            return await original_post(*args, **kwargs)
-
-        with patch.object(AsyncAPIClient, "post", new=post_with_stream):
-            client = remote_server.get_async_client()
-            res = await client.audio.transcriptions.create(
-                model=model_name,
-                file=winning_call,
-                language="en",
-                temperature=0.0,
-                extra_body=dict(stream=True,
-                                stream_include_usage=True,
-                                stream_continuous_usage_stats=True),
-                timeout=30)
-            final = False
-            continuous = True
-            async for chunk in res:
-                if not len(chunk.choices):
-                    # final usage sent
-                    final = True
-                else:
-                    continuous = continuous and hasattr(chunk, 'usage')
-            assert final and continuous
+async def test_stream_options(winning_call, client):
+    res = await client.audio.transcriptions.create(
+        model=MODEL_NAME,
+        file=winning_call,
+        language="en",
+        temperature=0.0,
+        stream=True,
+        extra_body=dict(stream_include_usage=True,
+                        stream_continuous_usage_stats=True),
+        timeout=30)
+    final = False
+    continuous = True
+    async for chunk in res:
+        if not len(chunk.choices):
+            # final usage sent
+            final = True
+        else:
+            continuous = continuous and hasattr(chunk, 'usage')
+    assert final and continuous
 
 
 @pytest.mark.asyncio
-async def test_sampling_params(mary_had_lamb):
+async def test_sampling_params(mary_had_lamb, client):
     """
     Compare sampling with params and greedy sampling to assert results
     are different when extreme sampling parameters values are picked.
     """
-    model_name = "openai/whisper-small"
-    server_args = ["--enforce-eager"]
-    with RemoteOpenAIServer(model_name, server_args) as remote_server:
-        client = remote_server.get_async_client()
-        transcription = await client.audio.transcriptions.create(
-            model=model_name,
-            file=mary_had_lamb,
-            language="en",
-            temperature=0.8,
-            extra_body=dict(seed=42,
-                            repetition_penalty=1.9,
-                            top_k=12,
-                            top_p=0.4,
-                            min_p=0.5,
-                            frequency_penalty=1.8,
-                            presence_penalty=2.0))
-
-        greedy_transcription = await client.audio.transcriptions.create(
-            model=model_name,
-            file=mary_had_lamb,
-            language="en",
-            temperature=0.0,
-            extra_body=dict(seed=42))
-
-        assert greedy_transcription.text != transcription.text
+    transcription = await client.audio.transcriptions.create(
+        model=MODEL_NAME,
+        file=mary_had_lamb,
+        language="en",
+        temperature=0.8,
+        extra_body=dict(seed=42,
+                        repetition_penalty=1.9,
+                        top_k=12,
+                        top_p=0.4,
+                        min_p=0.5,
+                        frequency_penalty=1.8,
+                        presence_penalty=2.0))
+
+    greedy_transcription = await client.audio.transcriptions.create(
+        model=MODEL_NAME,
+        file=mary_had_lamb,
+        language="en",
+        temperature=0.0,
+        extra_body=dict(seed=42))
+
+    assert greedy_transcription.text != transcription.text
 
 
 @pytest.mark.asyncio
-async def test_audio_prompt(mary_had_lamb):
-    model_name = "openai/whisper-large-v3-turbo"
-    server_args = ["--enforce-eager"]
+async def test_audio_prompt(mary_had_lamb, client):
     prompt = "This is a speech, recorded in a phonograph."
-    with RemoteOpenAIServer(model_name, server_args) as remote_server:
-        #Prompts should not omit the part of original prompt while transcribing.
-        prefix = "The first words I spoke in the original phonograph"
-        client = remote_server.get_async_client()
-        transcription = await client.audio.transcriptions.create(
-            model=model_name,
-            file=mary_had_lamb,
-            language="en",
-            response_format="text",
-            temperature=0.0)
-        out = json.loads(transcription)['text']
-        assert prefix in out
-        transcription_wprompt = await client.audio.transcriptions.create(
-            model=model_name,
-            file=mary_had_lamb,
-            language="en",
-            response_format="text",
-            prompt=prompt,
-            temperature=0.0)
-        out_prompt = json.loads(transcription_wprompt)['text']
-        assert prefix in out_prompt
+    #Prompts should not omit the part of original prompt while transcribing.
+    prefix = "The first words I spoke in the original phonograph"
+    transcription = await client.audio.transcriptions.create(
+        model=MODEL_NAME,
+        file=mary_had_lamb,
+        language="en",
+        response_format="text",
+        temperature=0.0)
+    out = json.loads(transcription)['text']
+    assert prefix in out
+    transcription_wprompt = await client.audio.transcriptions.create(
+        model=MODEL_NAME,
+        file=mary_had_lamb,
+        language="en",
+        response_format="text",
+        prompt=prompt,
+        temperature=0.0)
+    out_prompt = json.loads(transcription_wprompt)['text']
+    assert prefix in out_prompt
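
A second simplification visible in the diff: the old tests could only exercise streaming by monkey-patching AsyncAPIClient.post, because the transcription API of the then-installed openai client did not expose streaming; the new tests pass stream=True directly. A rough sketch of consuming such a stream against an OpenAI-compatible vLLM endpoint — the base URL, API key, and audio path are placeholder assumptions, and the chunk access mirrors what test_streaming_response does:

```python
# Sketch: consuming a streaming transcription the way the new tests do.
# Assumes a vLLM server at the placeholder base_url and an openai client
# version whose transcriptions API accepts stream=True.
import asyncio

import openai


async def main():
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY")
    transcription = ""
    with open("audio.wav", "rb") as f:  # placeholder audio file
        res = await client.audio.transcriptions.create(
            model="openai/whisper-large-v3-turbo",
            file=f,
            language="en",
            temperature=0.0,
            stream=True)
        # Each chunk carries a text delta, as in test_streaming_response.
        async for chunk in res:
            transcription += chunk.choices[0]['delta']['content']
    print(transcription)


asyncio.run(main())
```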
