
Commit 19b025c: add whisper
Parent: ddb8c90
7 files changed: +401 -23 lines

_unittests/ut_torch_models/test_hghub_model.py

Lines changed: 8 additions & 0 deletions

@@ -104,6 +104,14 @@ def test_get_untrained_model_with_inputs_text2text_generation(self):
             raise unittest.SkipTest(f"not working for {mid!r}")
         model(**inputs)
 
+    @hide_stdout()
+    def test_get_untrained_model_with_inputs_automatic_speech_recognition(self):
+        mid = "openai/whisper-tiny"
+        data = get_untrained_model_with_inputs(mid, verbose=1)
+        self.assertIn((data["size"], data["n_weights"]), [(132115968, 33028992)])
+        model, inputs = data["model"], data["inputs"]
+        model(**inputs)
+
     @hide_stdout()
     def test_get_untrained_model_with_inputs_imagetext2text_generation(self):
         mid = "HuggingFaceM4/tiny-random-idefics"

_unittests/ut_torch_models/try_tasks.py

Lines changed: 66 additions & 0 deletions

@@ -82,6 +82,72 @@ def test_imagetext2text_generation(self):
 
         print(generated_text[0])
 
+    @never_test()
+    def test_automatic_speech_recognition(self):
+        # clear&&NEVERTEST=1 python _unittests/ut_torch_models/try_tasks.py -k automatic_speech
+        # https://huggingface.co/openai/whisper-tiny
+
+        from transformers import WhisperProcessor, WhisperForConditionalGeneration
+        from datasets import load_dataset
+
+        """
+        kwargs=dict(
+            cache_position:T7s4,
+            past_key_values:EncoderDecoderCache(
+                self_attention_cache=DynamicCache[serialized](#2[#0[],#0[]]),
+                cross_attention_cache=DynamicCache[serialized](#2[#0[],#0[]])
+            ),
+            decoder_input_ids:T7s1x4,
+            encoder_outputs:dict(last_hidden_state:T1s1x1500x384),
+            use_cache:bool,return_dict:bool
+        )
+        kwargs=dict(
+            cache_position:T7s1,
+            past_key_values:EncoderDecoderCache(
+                self_attention_cache=DynamicCache[serialized](#2[
+                    #4[T1s1x6x4x64,T1s1x6x4x64,T1s1x6x4x64,T1s1x6x4x64],
+                    #4[T1s1x6x4x64,T1s1x6x4x64,T1s1x6x4x64,T1s1x6x4x64]
+                ]),
+                cross_attention_cache=DynamicCache[serialized](#2[
+                    #4[T1s1x6x1500x64,T1s1x6x1500x64,T1s1x6x1500x64,T1s1x6x1500x64],
+                    #4[T1s1x6x1500x64,T1s1x6x1500x64,T1s1x6x1500x64,T1s1x6x1500x64]
+                ]),
+            ),
+            decoder_input_ids:T7s1x1,
+            encoder_outputs:dict(last_hidden_state:T1s1x1500x384),
+            use_cache:bool,return_dict:bool
+        )
+        """
+
+        # load model and processor
+        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
+        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
+        forced_decoder_ids = processor.get_decoder_prompt_ids(
+            language="english", task="transcribe"
+        )
+
+        # load the dummy dataset and read the first audio sample
+        ds = load_dataset(
+            "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
+        )
+        sample = ds[0]["audio"]
+        input_features = processor(
+            sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt"
+        ).input_features
+
+        # generate token ids
+        print()
+        with steel_forward(model):
+            predicted_ids = model.generate(
+                input_features, forced_decoder_ids=forced_decoder_ids
+            )
+
+        # decode token ids to text
+        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=False)
+        print("--", transcription)
+        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
+        print("--", transcription)
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
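The docstring pasted into the test records the keyword arguments Whisper's forward receives: once on the first decoding step (empty caches, four forced decoder tokens) and once on later steps (filled self- and cross-attention caches, one token at a time). The notation is the one produced by string_type: T<d> is an ONNX element type (1 = float32, 7 = int64) followed by the shape, so T1s1x1500x384 is a float32 tensor of shape 1x1500x384, and #2[...] is a list of two elements. A quick way to reproduce it; the import path matches the ext_test_case.py diff below, and the expected outputs are hedged in the comments:

import torch
from onnx_diagnostic.helpers import string_type

# int64 token ids render with element type 7, float32 hidden states with type 1
print(string_type(torch.zeros((1, 4), dtype=torch.int64), with_shape=True))  # expected: T7s1x4
print(string_type(torch.zeros((1, 1500, 384)), with_shape=True))             # expected: T1s1x1500x384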

onnx_diagnostic/ext_test_case.py

Lines changed: 5 additions & 0 deletions

@@ -1137,6 +1137,11 @@ def _debug(self):
         "Tells if DEBUG=1 is set up."
         return os.environ.get("DEBUG") in BOOLEAN_VALUES
 
+    def string_type(self, *args, **kwargs):
+        from .helpers import string_type
+
+        return string_type(*args, **kwargs)
+
     def subloop(self, *args, verbose: int = 0):
         "Loops over elements and calls :meth:`unittests.TestCase.subTest`."
         if len(args) == 1:
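This forwards to onnx_diagnostic.helpers.string_type so any test can summarize nested inputs without an extra import. A hypothetical use, assuming the enclosing class is ExtTestCase (the test class below is made up for illustration):

import unittest
from onnx_diagnostic.ext_test_case import ExtTestCase
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs

class TestWhisperInputs(ExtTestCase):  # hypothetical test class
    def test_show_inputs(self):
        data = get_untrained_model_with_inputs("openai/whisper-tiny")
        # One compact line of types and shapes instead of pages of tensor dumps.
        print(self.string_type(data["inputs"], with_shape=True))

if __name__ == "__main__":
    unittest.main(verbosity=2)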

onnx_diagnostic/helpers/helper.py

Lines changed: 14 additions & 0 deletions

@@ -1,6 +1,7 @@
 import ast
 import enum
 import inspect
+from dataclasses import is_dataclass, fields
 from typing import Any, Callable, Dict, List, Optional, Set
 import numpy as np
 
@@ -140,6 +141,19 @@ def string_type(
     """
     if obj is None:
         return "None"
+    if is_dataclass(obj):
+        values = {f.name: getattr(obj, f.name, None) for f in fields(obj)}
+        values = {k: v for k, v in values.items() if v is not None}
+        s = string_type(
+            values,
+            with_shape=with_shape,
+            with_min_max=with_min_max,
+            with_device=with_device,
+            ignore=ignore,
+            limit=limit,
+        )
+        return f"{obj.__class__.__name__}{s[4:]}"
+
     # tuple
     if isinstance(obj, tuple):
         if len(obj) == 1:
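With this branch, string_type renders a dataclass like a dict, drops fields left at None, and swaps the "dict" prefix (s[4:] strips those four characters) for the class name; that covers the ModelOutput dataclasses transformers returns, such as Whisper's encoder outputs. A small sketch; the dataclass here is made up for illustration:

from dataclasses import dataclass
from typing import Optional

import torch
from onnx_diagnostic.helpers import string_type

@dataclass
class EncoderOutput:  # hypothetical, stands in for transformers' dataclass outputs
    last_hidden_state: torch.Tensor
    attentions: Optional[torch.Tensor] = None

out = EncoderOutput(last_hidden_state=torch.zeros((1, 1500, 384)))
# attentions is None and gets filtered out; expected output looks like:
# EncoderOutput(last_hidden_state:T1s1x1500x384)
print(string_type(out, with_shape=True))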

onnx_diagnostic/torch_models/hghub/hub_data.py

Lines changed: 1 addition & 1 deletion

@@ -119,7 +119,7 @@
 VitsModel,text-to-audio
 Wav2Vec2ConformerForCTC,automatic-speech-recognition
 Wav2Vec2Model,feature-extraction
-WhisperForConditionalGeneration,no-pipeline-tag
+WhisperForConditionalGeneration,automatic-speech-recognition
 XLMModel,feature-extraction
 XLMRobertaForCausalLM,text-generation
 YolosForObjectDetection,object-detection
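hub_data.py stores these architecture-to-pipeline-tag pairs as comma-separated lines; the fix replaces the no-pipeline-tag placeholder so Whisper resolves to a real task. A sketch of how such a table can be turned into a lookup; the names here are illustrative, not the module's actual API:

# Illustrative only: parse "Architecture,task" lines into a dict.
RAW_TASKS = """\
Wav2Vec2ConformerForCTC,automatic-speech-recognition
WhisperForConditionalGeneration,automatic-speech-recognition
XLMRobertaForCausalLM,text-generation"""

ARCH_TO_TASK = dict(line.split(",", 1) for line in RAW_TASKS.splitlines())
assert ARCH_TO_TASK["WhisperForConditionalGeneration"] == "automatic-speech-recognition"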
