|
6 | 6 | has_transformers, |
7 | 7 | requires_transformers, |
8 | 8 | ) |
| 9 | +from onnx_diagnostic.helpers.torch_helper import to_any |
9 | 10 | from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs |
10 | 11 | from onnx_diagnostic.torch_export_patches import torch_export_patches |
11 | 12 | from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str |
@@ -42,12 +43,13 @@ def test_text_generation(self): |
42 | 43 | ) |
43 | 44 |
|
44 | 45 | @hide_stdout() |
45 | | - def test_automatic_speech_recognition(self): |
| 46 | + def test_automatic_speech_recognition_float32(self): |
46 | 47 | mid = "openai/whisper-tiny" |
47 | 48 | data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True) |
48 | 49 | self.assertEqual(data["task"], "automatic-speech-recognition") |
49 | 50 | self.assertIn((data["size"], data["n_weights"]), [(132115968, 33028992)]) |
50 | 51 | model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"] |
| 52 | + model(**data["inputs"]) |
51 | 53 | model(**data["inputs2"]) |
52 | 54 | Dim = torch.export.Dim |
53 | 55 | self.maxDiff = None |
@@ -113,6 +115,83 @@ def test_automatic_speech_recognition(self): |
113 | 115 | model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False |
114 | 116 | ) |
115 | 117 |
|
| 118 | + @hide_stdout() |
| 119 | + def test_automatic_speech_recognition_float16(self): |
| 120 | + mid = "openai/whisper-tiny" |
| 121 | + data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True) |
| 122 | + self.assertEqual(data["task"], "automatic-speech-recognition") |
| 123 | + self.assertIn((data["size"], data["n_weights"]), [(132115968, 33028992)]) |
| 124 | + self.assertIn("encoder_outputs:BaseModelOutput", self.string_type(data["inputs"])) |
| 125 | + data["inputs"] = to_any(data["inputs"], torch.float16) |
| 126 | + self.assertIn("encoder_outputs:BaseModelOutput", self.string_type(data["inputs"])) |
| 127 | + data["inputs2"] = to_any(data["inputs2"], torch.float16) |
| 128 | + model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"] |
| 129 | + model = to_any(model, torch.float16) |
| 130 | + model(**data["inputs2"]) |
| 131 | + Dim = torch.export.Dim |
| 132 | + self.maxDiff = None |
| 133 | + self.assertIn("{0:Dim(batch),1:DYN(seq_length)}", self.string_type(ds)) |
| 134 | + self.assertEqualAny( |
| 135 | + { |
| 136 | + "decoder_input_ids": { |
| 137 | + 0: Dim("batch", min=1, max=1024), |
| 138 | + 1: "seq_length", |
| 139 | + }, |
| 140 | + "cache_position": {0: "seq_length"}, |
| 141 | + "encoder_outputs": [{0: Dim("batch", min=1, max=1024)}], |
| 142 | + "past_key_values": [ |
| 143 | + [ |
| 144 | + [ |
| 145 | + {0: Dim("batch", min=1, max=1024)}, |
| 146 | + {0: Dim("batch", min=1, max=1024)}, |
| 147 | + ], |
| 148 | + [ |
| 149 | + {0: Dim("batch", min=1, max=1024)}, |
| 150 | + {0: Dim("batch", min=1, max=1024)}, |
| 151 | + ], |
| 152 | + ], |
| 153 | + [ |
| 154 | + [ |
| 155 | + {0: Dim("batch", min=1, max=1024)}, |
| 156 | + {0: Dim("batch", min=1, max=1024)}, |
| 157 | + ], |
| 158 | + [ |
| 159 | + {0: Dim("batch", min=1, max=1024)}, |
| 160 | + {0: Dim("batch", min=1, max=1024)}, |
| 161 | + ], |
| 162 | + ], |
| 163 | + ], |
| 164 | + }, |
| 165 | + ds, |
| 166 | + ) |
| 167 | + self.assertEqual( |
| 168 | + "#1[T10r3]", |
| 169 | + self.string_type(torch.utils._pytree.tree_flatten(inputs["encoder_outputs"])[0]), |
| 170 | + ) |
| 171 | + with torch_export_patches(patch_transformers=True, verbose=10): |
| 172 | + model(**inputs) |
| 173 | + flat = torch.utils._pytree.tree_flatten(inputs["past_key_values"])[0] |
| 174 | + self.assertIsInstance(flat, list) |
| 175 | + self.assertIsInstance(flat[0], torch.Tensor) |
| 176 | + self.assertEqual( |
| 177 | + "#8[T10r4,T10r4,T10r4,T10r4,T10r4,T10r4,T10r4,T10r4]", |
| 178 | + self.string_type(flat), |
| 179 | + ) |
| 180 | + torch.export.export( |
| 181 | + model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False |
| 182 | + ) |
| 183 | + with torch_export_patches(patch_transformers=True, verbose=10): |
| 184 | + flat = torch.utils._pytree.tree_flatten(inputs["past_key_values"])[0] |
| 185 | + self.assertIsInstance(flat, list) |
| 186 | + self.assertIsInstance(flat[0], torch.Tensor) |
| 187 | + self.assertEqual( |
| 188 | + "#8[T10r4,T10r4,T10r4,T10r4,T10r4,T10r4,T10r4,T10r4]", |
| 189 | + self.string_type(flat), |
| 190 | + ) |
| 191 | + torch.export.export( |
| 192 | + model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False |
| 193 | + ) |
| 194 | + |
116 | 195 | @hide_stdout() |
117 | 196 | def test_fill_mask(self): |
118 | 197 | mid = "google-bert/bert-base-multilingual-cased" |
|
0 commit comments