|
11 | 11 | requires_onnxscript, |
12 | 12 | requires_transformers, |
13 | 13 | ) |
14 | | -from onnx_diagnostic.torch_models.test_helper import ( |
| 14 | +from onnx_diagnostic.torch_models.validate import ( |
15 | 15 | get_inputs_for_task, |
16 | 16 | validate_model, |
17 | 17 | filter_inputs, |
|
21 | 21 | from onnx_diagnostic.tasks import supported_tasks |
22 | 22 |
|
23 | 23 |
|
24 | | -class TestTestHelper(ExtTestCase): |
| 24 | +class TestValidateWholeModels(ExtTestCase): |
25 | 25 | def test_get_inputs_for_task(self): |
26 | 26 | fcts = supported_tasks() |
27 | 27 | for task in self.subloop(sorted(fcts)): |
@@ -221,14 +221,39 @@ def test_validate_model_modelbuilder(self): |
221 | 221 | do_run=True, |
222 | 222 | verbose=10, |
223 | 223 | exporter="modelbuilder", |
224 | | - dump_folder="dump_test_validate_model_onnx_dynamo", |
| 224 | + dump_folder="dump_test_validate_model_modelbuilder", |
225 | 225 | ) |
226 | 226 | self.assertIsInstance(summary, dict) |
227 | 227 | self.assertIsInstance(data, dict) |
228 | 228 | self.assertLess(summary["disc_onnx_ort_run_abs"], 1e-4) |
229 | 229 | onnx_filename = data["onnx_filename"] |
230 | 230 | self.assertExists(onnx_filename) |
231 | 231 |
|
| 232 | + @requires_torch("2.7") |
| 233 | + @hide_stdout() |
| 234 | + @ignore_warnings(FutureWarning) |
| 235 | + @requires_transformers("4.51") |
| 236 | + def test_validate_model_vit_model(self): |
| 237 | + mid = "ydshieh/tiny-random-ViTForImageClassification" |
| 238 | + summary, data = validate_model( |
| 239 | + mid, |
| 240 | + do_run=True, |
| 241 | + verbose=10, |
| 242 | + exporter="onnx-dynamo", |
| 243 | + dump_folder="dump_test_validate_model_onnx_dynamo", |
| 244 | + inputs2=True, |
| 245 | + ) |
| 246 | + self.assertIsInstance(summary, dict) |
| 247 | + self.assertIsInstance(data, dict) |
| 248 | + self.assertLess(summary["disc_onnx_ort_run_abs"], 1e-3) |
| 249 | + self.assertLess(summary["disc_onnx_ort_run2_abs"], 1e-3) |
| 250 | + self.assertEqual("dict(pixel_values:A1s2x3x30x30)", summary["run_feeds_inputs"]) |
| 251 | + self.assertEqual("dict(pixel_values:A1s3x3x31x31)", summary["run_feeds_inputs2"]) |
| 252 | + self.assertEqual("#1[A1s2x2]", summary["run_output_inputs"]) |
| 253 | + self.assertEqual("#1[A1s3x2]", summary["run_output_inputs2"]) |
| 254 | + onnx_filename = data["onnx_filename"] |
| 255 | + self.assertExists(onnx_filename) |
| 256 | + |
232 | 257 |
|
233 | 258 | if __name__ == "__main__": |
234 | 259 | unittest.main(verbosity=2) |
0 commit comments