@@ -82,6 +82,72 @@ def test_imagetext2text_generation(self):
8282
8383 print (generated_text [0 ])
8484
85+ @never_test ()
86+ def test_automatic_speech_recognition (self ):
87+ # clear&&NEVERTEST=1 python _unittests/ut_torch_models/try_tasks.py -k automatic_speech
88+ # https://huggingface.co/openai/whisper-tiny
89+
90+ from transformers import WhisperProcessor , WhisperForConditionalGeneration
91+ from datasets import load_dataset
92+
93+ """
94+ kwargs=dict(
95+ cache_position:T7s4,
96+ past_key_values:EncoderDecoderCache(
97+ self_attention_cache=DynamicCache[serialized](#2[#0[],#0[]]),
98+ cross_attention_cache=DynamicCache[serialized](#2[#0[],#0[]])
99+ ),
100+ decoder_input_ids:T7s1x4,
101+ encoder_outputs:dict(last_hidden_state:T1s1x1500x384),
102+ use_cache:bool,return_dict:bool
103+ )
104+ kwargs=dict(
105+ cache_position:T7s1,
106+ past_key_values:EncoderDecoderCache(
107+ self_attention_cache=DynamicCache[serialized](#2[
108+ #4[T1s1x6x4x64,T1s1x6x4x64,T1s1x6x4x64,T1s1x6x4x64],
109+ #4[T1s1x6x4x64,T1s1x6x4x64,T1s1x6x4x64,T1s1x6x4x64]
110+ ]),
111+ cross_attention_cache=DynamicCache[serialized](#2[
112+ #4[T1s1x6x1500x64,T1s1x6x1500x64,T1s1x6x1500x64,T1s1x6x1500x64],
113+ #4[T1s1x6x1500x64,T1s1x6x1500x64,T1s1x6x1500x64,T1s1x6x1500x64]
114+ ]),
115+ ),
116+ decoder_input_ids:T7s1x1,
117+ encoder_outputs:dict(last_hidden_state:T1s1x1500x384),
118+ use_cache:bool,return_dict:bool
119+ )
120+ """
121+
122+ # load model and processor
123+ processor = WhisperProcessor .from_pretrained ("openai/whisper-tiny" )
124+ model = WhisperForConditionalGeneration .from_pretrained ("openai/whisper-tiny" )
125+ forced_decoder_ids = processor .get_decoder_prompt_ids (
126+ language = "english" , task = "transcribe"
127+ )
128+
129+ # load streaming dataset and read first audio sample
130+ ds = load_dataset (
131+ "hf-internal-testing/librispeech_asr_dummy" , "clean" , split = "validation"
132+ )
133+ sample = ds [0 ]["audio" ]
134+ input_features = processor (
135+ sample ["array" ], sampling_rate = sample ["sampling_rate" ], return_tensors = "pt"
136+ ).input_features
137+
138+ # generate token ids
139+ print ()
140+ with steel_forward (model ):
141+ predicted_ids = model .generate (
142+ input_features , forced_decoder_ids = forced_decoder_ids
143+ )
144+
145+ # decode token ids to text
146+ transcription = processor .batch_decode (predicted_ids , skip_special_tokens = False )
147+ print ("--" , transcription )
148+ transcription = processor .batch_decode (predicted_ids , skip_special_tokens = True )
149+ print ("--" , transcription )
150+
85151
# Run this module's tests directly with verbose per-test output.
if __name__ == "__main__":
    unittest.main(verbosity=2)
0 commit comments