Skip to content

Commit cf0c976

Browse files
feat(transcription): add pruna/whisper-v3-large transcription model (#51)
* feat(transcription): add Whisper transcription model support - Add `transcriptionModel()` and `transcription()` methods to the provider - Support audio transcription via RunPod's pruna/whisper-v3-large endpoint - Accept audio as Uint8Array, base64 string, or URL via providerOptions - Return transcription text, segments with timing, detected language, and duration - Add unit tests for the transcription model - Update README with transcription documentation * docs: use real demo audio URL in README examples * fix: resolve lint errors in transcription model
1 parent ca8225e commit cf0c976

File tree

9 files changed

+1084
-29
lines changed

9 files changed

+1084
-29
lines changed
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
---
2+
"@runpod/ai-sdk-provider": minor
3+
---
4+
5+
Add transcription model support with `pruna/whisper-v3-large`
6+
7+
- Add `transcriptionModel()` and `transcription()` methods to the provider
8+
- Support audio transcription via RunPod's Whisper endpoint
9+
- Accept audio as `Uint8Array`, base64 string, or URL via providerOptions
10+
- Return transcription text, segments with timing, detected language, and duration

README.md

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -565,6 +565,94 @@ const result = await generateSpeech({
565565
});
566566
```
567567

568+
## Transcription Models
569+
570+
Transcribe audio using the AI SDK's `experimental_transcribe` and `runpod.transcription(...)`:
571+
572+
```ts
573+
import { runpod } from '@runpod/ai-sdk-provider';
574+
import { experimental_transcribe as transcribe } from 'ai';
575+
576+
const result = await transcribe({
577+
model: runpod.transcription('pruna/whisper-v3-large'),
578+
audio: new URL('https://image.runpod.ai/demo/transcription-demo.wav'),
579+
});
580+
581+
console.log(result.text);
582+
```
583+
584+
**Returns:**
585+
586+
- `result.text` - Full transcription text
587+
- `result.segments` - Array of segments with timing info
588+
- `segment.text` - Segment text
589+
- `segment.startSecond` - Start time in seconds
590+
- `segment.endSecond` - End time in seconds
591+
- `result.language` - Detected language code
592+
- `result.durationInSeconds` - Audio duration
593+
- `result.warnings` - Array of any warnings
594+
- `result.providerMetadata.runpod.jobId` - RunPod job ID
595+
596+
### Audio Input
597+
598+
You can provide audio in several ways:
599+
600+
```ts
601+
// URL (recommended for large files)
602+
const result = await transcribe({
603+
model: runpod.transcription('pruna/whisper-v3-large'),
604+
audio: new URL('https://image.runpod.ai/demo/transcription-demo.wav'),
605+
});
606+
607+
// Local file as Uint8Array
608+
import { readFileSync } from 'fs';
609+
const audioData = readFileSync('./audio.wav');
610+
611+
const result = await transcribe({
612+
model: runpod.transcription('pruna/whisper-v3-large'),
613+
audio: audioData,
614+
});
615+
```
616+
617+
### Examples
618+
619+
Check out our [examples](https://github.com/runpod/examples/tree/main/ai-sdk/getting-started) for more code snippets on how to use all the different models.
620+
621+
### Supported Models
622+
623+
- `pruna/whisper-v3-large`
624+
625+
### Provider Options
626+
627+
Use `providerOptions.runpod` for model-specific parameters:
628+
629+
| Option | Type | Default | Description |
630+
| ------------------- | --------- | ------- | ---------------------------------------------- |
631+
| `audio` | `string` | - | URL to audio file (alternative to binary data) |
632+
| `prompt` | `string` | - | Context prompt to guide transcription |
633+
| `language` | `string` | Auto | ISO-639-1 language code (e.g., 'en', 'es') |
634+
| `word_timestamps` | `boolean` | `false` | Include word-level timestamps |
635+
| `translate` | `boolean` | `false` | Translate audio to English |
636+
| `enable_vad` | `boolean` | `false` | Enable voice activity detection |
637+
| `maxPollAttempts` | `number` | `120` | Max polling attempts |
638+
| `pollIntervalMillis`| `number` | `2000` | Polling interval (ms) |
639+
640+
**Example (providerOptions):**
641+
642+
```ts
643+
const result = await transcribe({
644+
model: runpod.transcription('pruna/whisper-v3-large'),
645+
audio: new URL('https://image.runpod.ai/demo/transcription-demo.wav'),
646+
providerOptions: {
647+
runpod: {
648+
language: 'en',
649+
prompt: 'This is a demo of audio transcription',
650+
word_timestamps: true,
651+
},
652+
},
653+
});
654+
```
655+
568656
## About Runpod
569657

570658
[Runpod](https://runpod.io) is the foundation for developers to build, deploy, and scale custom AI systems.

src/index.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,9 @@ export type { RunpodProvider, RunpodProviderSettings } from './runpod-provider';
33
export type { RunpodChatModelId } from './runpod-chat-options';
44
export type { RunpodCompletionModelId } from './runpod-completion-options';
55
export type { RunpodImageModelId } from './runpod-image-options';
6+
export type {
7+
RunpodTranscriptionModelId,
8+
RunpodTranscriptionProviderOptions,
9+
} from './runpod-transcription-options';
610
export type { OpenAICompatibleErrorData as RunpodErrorData } from '@ai-sdk/openai-compatible';
711
export type { RunpodImageErrorData } from './runpod-error';

src/runpod-error.ts

Lines changed: 36 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -9,36 +9,44 @@ export const runpodImageErrorSchema = z.object({
99

1010
export type RunpodImageErrorData = z.infer<typeof runpodImageErrorSchema>;
1111

12-
export const runpodImageFailedResponseHandler = createJsonErrorResponseHandler({
13-
errorSchema: runpodImageErrorSchema as any,
14-
errorToMessage: (data: RunpodImageErrorData) => {
15-
// Prefer message if available (more descriptive)
16-
if (data.message) {
17-
return data.message;
18-
}
19-
20-
// If error field exists, try to extract nested JSON message
21-
if (data.error) {
22-
// Runpod sometimes returns nested JSON in the error field like:
23-
// "Error submitting task: 400, {\"code\":400,\"message\":\"...\"}"
24-
// Try to extract the inner message for cleaner error messages
25-
// Find the last occurrence of { which likely starts the JSON object
26-
const lastBraceIndex = data.error.lastIndexOf('{');
27-
if (lastBraceIndex !== -1) {
28-
try {
29-
const jsonStr = data.error.substring(lastBraceIndex);
30-
const nestedError = JSON.parse(jsonStr);
31-
if (nestedError.message && typeof nestedError.message === 'string') {
32-
return nestedError.message;
33-
}
34-
} catch {
35-
// If parsing fails, fall back to the original error string
12+
// Helper function to extract error message from Runpod error data
13+
function extractErrorMessage(data: RunpodImageErrorData): string {
14+
// Prefer message if available (more descriptive)
15+
if (data.message) {
16+
return data.message;
17+
}
18+
19+
// If error field exists, try to extract nested JSON message
20+
if (data.error) {
21+
// Runpod sometimes returns nested JSON in the error field like:
22+
// "Error submitting task: 400, {\"code\":400,\"message\":\"...\"}"
23+
// Try to extract the inner message for cleaner error messages
24+
// Find the last occurrence of { which likely starts the JSON object
25+
const lastBraceIndex = data.error.lastIndexOf('{');
26+
if (lastBraceIndex !== -1) {
27+
try {
28+
const jsonStr = data.error.substring(lastBraceIndex);
29+
const nestedError = JSON.parse(jsonStr);
30+
if (nestedError.message && typeof nestedError.message === 'string') {
31+
return nestedError.message;
3632
}
33+
} catch {
34+
// If parsing fails, fall back to the original error string
3735
}
38-
return data.error;
3936
}
40-
41-
return 'Unknown Runpod error';
42-
},
37+
return data.error;
38+
}
39+
40+
return 'Unknown Runpod error';
41+
}
42+
43+
export const runpodImageFailedResponseHandler = createJsonErrorResponseHandler({
44+
errorSchema: runpodImageErrorSchema as any,
45+
errorToMessage: extractErrorMessage,
46+
});
47+
48+
export const runpodTranscriptionFailedResponseHandler = createJsonErrorResponseHandler({
49+
errorSchema: runpodImageErrorSchema as any,
50+
errorToMessage: extractErrorMessage,
4351
});
4452

src/runpod-provider.test.ts

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import {
55
} from '@ai-sdk/openai-compatible';
66
import { RunpodImageModel } from './runpod-image-model';
77
import { RunpodSpeechModel } from './runpod-speech-model';
8+
import { RunpodTranscriptionModel } from './runpod-transcription-model';
89
import { loadApiKey } from '@ai-sdk/provider-utils';
910
import { createRunpod } from './runpod-provider';
1011
import { describe, it, expect, vi, beforeEach, Mock } from 'vitest';
@@ -26,6 +27,10 @@ vi.mock('./runpod-speech-model', () => ({
2627
RunpodSpeechModel: vi.fn(),
2728
}));
2829

30+
vi.mock('./runpod-transcription-model', () => ({
31+
RunpodTranscriptionModel: vi.fn(),
32+
}));
33+
2934
vi.mock('@ai-sdk/provider-utils', () => ({
3035
loadApiKey: vi.fn().mockReturnValue('mock-api-key'),
3136
withoutTrailingSlash: vi.fn((url) => url),
@@ -245,4 +250,58 @@ describe('RunpodProvider', () => {
245250
);
246251
});
247252
});
253+
254+
describe('transcriptionModel', () => {
255+
it('should use mapping for known transcription model IDs', () => {
256+
const provider = createRunpod();
257+
258+
provider.transcriptionModel('pruna/whisper-v3-large');
259+
260+
expect((RunpodTranscriptionModel as any).mock.calls[0][0]).toBe(
261+
'pruna/whisper-v3-large'
262+
);
263+
expect((RunpodTranscriptionModel as any).mock.calls[0][1].baseURL).toBe(
264+
'https://api.runpod.ai/v2/whisper-v3-large'
265+
);
266+
});
267+
268+
it('should construct a transcription model for a serverless endpoint id', () => {
269+
const provider = createRunpod();
270+
const modelId = 'uhyz0hnkemrk6r';
271+
272+
const model = provider.transcriptionModel(modelId);
273+
expect(model).toBeInstanceOf(RunpodTranscriptionModel);
274+
275+
expect((RunpodTranscriptionModel as any).mock.calls[0][0]).toBe(modelId);
276+
expect((RunpodTranscriptionModel as any).mock.calls[0][1].baseURL).toBe(
277+
`https://api.runpod.ai/v2/${modelId}`
278+
);
279+
});
280+
281+
it('should accept a Runpod Console endpoint URL', () => {
282+
const provider = createRunpod();
283+
const url =
284+
'https://console.runpod.io/serverless/user/endpoint/uhyz0hnkemrk6r';
285+
286+
provider.transcriptionModel(url);
287+
288+
expect((RunpodTranscriptionModel as any).mock.calls[0][0]).toBe(
289+
'uhyz0hnkemrk6r'
290+
);
291+
expect((RunpodTranscriptionModel as any).mock.calls[0][1].baseURL).toBe(
292+
'https://api.runpod.ai/v2/uhyz0hnkemrk6r'
293+
);
294+
});
295+
});
296+
297+
describe('transcription', () => {
298+
it('should be an alias for transcriptionModel', () => {
299+
const provider = createRunpod();
300+
const modelId = 'pruna/whisper-v3-large';
301+
302+
const model = provider.transcription(modelId);
303+
304+
expect(model).toBeInstanceOf(RunpodTranscriptionModel);
305+
});
306+
});
248307
});

src/runpod-provider.ts

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
1-
import { ImageModelV3, LanguageModelV3, SpeechModelV3 } from '@ai-sdk/provider';
1+
import {
2+
ImageModelV3,
3+
LanguageModelV3,
4+
SpeechModelV3,
5+
TranscriptionModelV3,
6+
} from '@ai-sdk/provider';
27
import {
38
OpenAICompatibleChatLanguageModel,
49
OpenAICompatibleCompletionLanguageModel,
@@ -10,6 +15,7 @@ import {
1015
} from '@ai-sdk/provider-utils';
1116
import { RunpodImageModel } from './runpod-image-model';
1217
import { RunpodSpeechModel } from './runpod-speech-model';
18+
import { RunpodTranscriptionModel } from './runpod-transcription-model';
1319

1420
export interface RunpodProviderSettings {
1521
/**
@@ -72,6 +78,16 @@ Creates a speech model for speech generation.
7278
Creates a speech model for speech generation.
7379
*/
7480
speech(modelId: string): SpeechModelV3;
81+
82+
/**
83+
Creates a transcription model for audio transcription.
84+
*/
85+
transcriptionModel(modelId: string): TranscriptionModelV3;
86+
87+
/**
88+
Creates a transcription model for audio transcription.
89+
*/
90+
transcription(modelId: string): TranscriptionModelV3;
7591
}
7692

7793
// Mapping of Runpod model IDs to their endpoint URLs
@@ -123,6 +139,11 @@ const SPEECH_MODEL_ID_TO_ENDPOINT_URL: Record<string, string> = {
123139
'resembleai/chatterbox-turbo': 'https://api.runpod.ai/v2/chatterbox-turbo/',
124140
};
125141

142+
// Mapping of Runpod transcription model IDs to their serverless endpoint URLs
143+
const TRANSCRIPTION_MODEL_ID_TO_ENDPOINT_URL: Record<string, string> = {
144+
'pruna/whisper-v3-large': 'https://api.runpod.ai/v2/whisper-v3-large',
145+
};
146+
126147
// Mapping of Runpod model IDs to their OpenAI model names
127148
const MODEL_ID_TO_OPENAI_NAME: Record<string, string> = {
128149
'qwen/qwen3-32b-awq': 'Qwen/Qwen3-32B-AWQ',
@@ -272,6 +293,28 @@ export function createRunpod(
272293
});
273294
};
274295

296+
const createTranscriptionModel = (modelId: string) => {
297+
const endpointIdFromConsole = parseRunpodConsoleEndpointId(modelId);
298+
const normalizedModelId = endpointIdFromConsole ?? modelId;
299+
300+
// Prefer explicit mapping for known transcription model IDs.
301+
const mappedBaseURL =
302+
TRANSCRIPTION_MODEL_ID_TO_ENDPOINT_URL[normalizedModelId];
303+
304+
const baseURL =
305+
mappedBaseURL ??
306+
(normalizedModelId.startsWith('http')
307+
? normalizedModelId
308+
: `https://api.runpod.ai/v2/${normalizedModelId}`);
309+
310+
return new RunpodTranscriptionModel(normalizedModelId, {
311+
provider: 'runpod.transcription',
312+
baseURL,
313+
headers: getHeaders,
314+
fetch: runpodFetch,
315+
});
316+
};
317+
275318
const provider = (modelId: string) => createChatModel(modelId);
276319

277320
provider.completionModel = createCompletionModel;
@@ -281,6 +324,8 @@ export function createRunpod(
281324
provider.image = createImageModel;
282325
provider.speechModel = createSpeechModel;
283326
provider.speech = createSpeechModel;
327+
provider.transcriptionModel = createTranscriptionModel;
328+
provider.transcription = createTranscriptionModel;
284329

285330
return provider;
286331
}

0 commit comments

Comments
 (0)