
Commit 4f86ecc

Files for Fine Tuning Whisper on Custom Dataset
1 parent a4d7454 commit 4f86ecc

8 files changed, +39032 −0 lines changed

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
# Fine Tuning Whisper on Custom Dataset

This folder contains the Jupyter Notebooks and scripts for the LearnOpenCV article - **[Fine Tuning Whisper on Custom Dataset](https://learnopencv.com/fine-tuning-whisper-on-custom-dataset/)**.

We provide notebooks for fine tuning the Whisper Tiny, Base, and Small models, along with scripts for inference-time comparison and a Gradio UI.

You can download the trained weights from the link below.

[<img src="https://learnopencv.com/wp-content/uploads/2022/07/download-button-e1657285155454.png" alt="Download Code" width="200">](https://www.dropbox.com/scl/fo/f13dkz0373eq6hfdm0eoe/AG9cM_dKoCdfyhNPHYNC9eM?rlkey=3omlte477caelquynoms3xieg&st=qzzkrs3g&dl=1)

![](readme_images/whisper_fine_tuning_small_gradio.gif)
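The Gradio demo shown in the GIF above wires one of the fine-tuned checkpoints into a simple transcription UI. A minimal sketch of such an app is given here, assuming the Whisper Small checkpoint path used elsewhere in this repository ('whisper_small_atco2/best_model'); the actual Gradio script in this folder may differ.

# Minimal Gradio ASR demo sketch (assumed, not the repository's exact script).
import gradio as gr
from transformers import pipeline

# Checkpoint path taken from this repository's timing script; adjust as needed.
asr = pipeline(
    'automatic-speech-recognition',
    model='whisper_small_atco2/best_model'
)

def transcribe(audio_path):
    # Gradio passes the uploaded/recorded file path; the pipeline returns a dict.
    return asr(audio_path)['text']

demo = gr.Interface(
    fn=transcribe,
    inputs=gr.Audio(type='filepath'),
    outputs='text',
    title='Fine-Tuned Whisper ASR'
)
demo.launch()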
## AI Courses by OpenCV

Want to become an expert in AI? [AI Courses by OpenCV](https://opencv.org/courses/) is a great place to start.

[![img](https://learnopencv.com/wp-content/uploads/2023/01/AI-Courses-By-OpenCV-Github.png)](https://opencv.org/courses/)
Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
1+
"""
2+
Script to compare time for fine-tuned Whisper models.
3+
"""
4+
5+
import torch
6+
import time
7+
import os
8+
9+
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
10+
11+
model_dirs = [
12+
'whisper_tiny_atco2_v2/best_model',
13+
'whisper_base_atco2/best_model',
14+
'whisper_small_atco2/best_model'
15+
]
16+
17+
input_dir = 'inference_data'
18+
19+
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
20+
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
21+
22+
for model_id in model_dirs:
23+
print(f"\nEvaluating model: {model_id}")
24+
25+
model = AutoModelForSpeechSeq2Seq.from_pretrained(
26+
model_id, torch_dtype=torch_dtype,
27+
low_cpu_mem_usage=True,
28+
use_safetensors=True
29+
)
30+
model.to(device)
31+
32+
processor = AutoProcessor.from_pretrained(model_id)
33+
34+
pipe = pipeline(
35+
'automatic-speech-recognition',
36+
model=model,
37+
tokenizer=processor.tokenizer,
38+
feature_extractor=processor.feature_extractor,
39+
torch_dtype=torch_dtype,
40+
device=device
41+
)
42+
43+
total_time = 0
44+
num_runs = 0
45+
46+
for _ in range(10):
47+
for filename in os.listdir(input_dir):
48+
if filename.endswith('.wav'):
49+
start_time = time.time()
50+
result = pipe(os.path.join(input_dir, filename))
51+
end_time = time.time()
52+
total_time += (end_time - start_time)
53+
num_runs += 1
54+
55+
average_time = total_time / num_runs
56+
print(f"\nAverage time taken for {model_id}: {average_time} seconds")
