Commit c8b3b29

Authored by Russell Bryant
[tests] Improve speed and reliability of test_transcription_api_correctness (#23854)
Signed-off-by: Russell Bryant <[email protected]>
1 parent: 006477e

File tree

1 file changed: +7 −4 lines

tests/entrypoints/openai/correctness/test_transcription_api_correctness.py

Lines changed: 7 additions & 4 deletions
@@ -49,8 +49,7 @@ async def transcribe_audio(client, tokenizer, y, sr):
     return latency, num_output_tokens, transcription.text
 
 
-async def bound_transcribe(model_name, sem, client, audio, reference):
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
+async def bound_transcribe(sem, client, tokenizer, audio, reference):
     # Use semaphore to limit concurrent requests.
     async with sem:
         result = await transcribe_audio(client, tokenizer, *audio)
@@ -63,15 +62,19 @@ async def bound_transcribe(model_name, sem, client, audio, reference):
 async def process_dataset(model, client, data, concurrent_request):
     sem = asyncio.Semaphore(concurrent_request)
 
+    # Load tokenizer once outside the loop
+    tokenizer = AutoTokenizer.from_pretrained(model)
+
     # Warmup call as the first `librosa.load` server-side is quite slow.
     audio, sr = data[0]["audio"]["array"], data[0]["audio"]["sampling_rate"]
-    _ = await bound_transcribe(model, sem, client, (audio, sr), "")
+    _ = await bound_transcribe(sem, client, tokenizer, (audio, sr), "")
 
     tasks: list[asyncio.Task] = []
     for sample in data:
         audio, sr = sample["audio"]["array"], sample["audio"]["sampling_rate"]
         task = asyncio.create_task(
-            bound_transcribe(model, sem, client, (audio, sr), sample["text"]))
+            bound_transcribe(sem, client, tokenizer, (audio, sr),
+                             sample["text"]))
         tasks.append(task)
     return await asyncio.gather(*tasks)
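
For context on why this speeds the test up: the old code called AutoTokenizer.from_pretrained() inside bound_transcribe, i.e. once per transcription request, while the new code loads the tokenizer once in process_dataset and passes it to every task. Below is a minimal standalone sketch of that pattern, assuming only asyncio and transformers; fake_transcribe and the example model name are hypothetical placeholders, not part of the test.

import asyncio

from transformers import AutoTokenizer


async def fake_transcribe(tokenizer, text):
    # Hypothetical stand-in for the real transcription API call; it just
    # counts tokens so the sketch stays self-contained.
    await asyncio.sleep(0)
    return len(tokenizer(text).input_ids)


async def run_all(model, samples, concurrent_requests=2):
    # Load the tokenizer once, outside the per-sample tasks, mirroring the
    # commit: from_pretrained() is too slow to repeat for every request.
    tokenizer = AutoTokenizer.from_pretrained(model)
    sem = asyncio.Semaphore(concurrent_requests)

    async def bounded(sample):
        async with sem:  # limit concurrent requests
            return await fake_transcribe(tokenizer, sample)

    return await asyncio.gather(*(bounded(s) for s in samples))


# Example usage (model name is illustrative):
# asyncio.run(run_all("openai/whisper-large-v3", ["first clip", "second clip"]))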
