|
| 1 | +import os |
| 2 | +import unittest |
| 3 | +from onnx_diagnostic.ext_test_case import ( |
| 4 | + ExtTestCase, |
| 5 | + requires_torch, |
| 6 | + requires_transformers, |
| 7 | + hide_stdout, |
| 8 | +) |
| 9 | +from onnx_diagnostic.helpers.model_builder_helper import ( |
| 10 | + download_model_builder_to_cache, |
| 11 | + import_model_builder, |
| 12 | + create_model_builder, |
| 13 | + save_model_builder, |
| 14 | +) |
| 15 | +from onnx_diagnostic.torch_models.hghub import ( |
| 16 | + get_untrained_model_with_inputs, |
| 17 | +) |
| 18 | +from onnx_diagnostic.helpers.rt_helper import make_feeds |
| 19 | + |
| 20 | + |
class TestModelBuilderHelper(ExtTestCase):
    """Tests for the ModelBuilder helper wrappers: download/import the builder
    script, create a model with it, save it, and run it with onnxruntime."""

    # Version gates limit impact on CI.
    @requires_transformers("4.52")
    @requires_torch("2.7.99")
    def test_download_model_builder(self):
        """The builder script downloads into the cache and is importable."""
        path = download_model_builder_to_cache()
        self.assertExists(path)
        builder = import_model_builder()
        # The imported module must expose the entry point we rely on.
        self.assertHasAttr(builder, "create_model")

    # Version gates limit impact on CI.
    @requires_transformers("4.52")
    @requires_torch("2.7.99")
    @hide_stdout()
    def test_model_builder_id(self):
        """End-to-end: build a tiny model with ModelBuilder, save it, and
        check onnxruntime can load and run it.

        Command-line equivalent:
        clear&&python ~/.cache/onnx-diagnostic/builder.py
        --model arnir0/Tiny-LLM -p fp16 -c dump_cache -e cpu -o dump_model
        """
        folder = self.get_dump_folder("test_model_builder_id")
        data = get_untrained_model_with_inputs("arnir0/Tiny-LLM")
        onnx_model = create_model_builder(
            data["configuration"],
            data["model"],
            precision="fp32",
            execution_provider="cpu",
            cache_dir=folder,
            verbose=1,
        )
        # Sanity check that the builder actually produced a graph.
        self.assertGreater(len(onnx_model.nodes), 5)

        proto = save_model_builder(onnx_model, verbose=1)
        import onnxruntime

        # The serialized proto must at least load into an inference session.
        onnxruntime.InferenceSession(
            proto.SerializeToString(), providers=["CPUExecutionProvider"]
        )

        # We need to start again: rebuild before saving to disk
        # (presumably save_model_builder consumes the builder state — see helper).
        onnx_model = create_model_builder(
            data["configuration"],
            data["model"],
            precision="fp32",
            execution_provider="cpu",
            cache_dir=folder,
            verbose=1,
        )
        save_model_builder(onnx_model, folder, verbose=1)
        model_name = os.path.join(folder, "model.onnx")
        self.assertExists(model_name)

        feeds = make_feeds(proto, data["inputs"], use_numpy=True)
        expected = data["model"](**data["inputs"])

        sess = onnxruntime.InferenceSession(model_name, providers=["CPUExecutionProvider"])
        try:
            got = sess.run(None, feeds)
        except onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument as e:
            # Known limitation of the generated model; skip instead of failing.
            if "batch_size must be 1 when sequence_length > 1" in str(e):
                raise unittest.SkipTest(
                    "batch_size must be 1 when sequence_length > 1"
                ) from e
            # Any other InvalidArgument is a real failure: re-raise it.
            # (The original swallowed it, leaving ``got`` unbound and turning
            # the real error into a misleading NameError below.)
            raise
        self.assertEqualAny(expected, got)
| 81 | + |
if __name__ == "__main__":
    # Allow running this test file directly; verbosity=2 prints one line per test.
    unittest.main(verbosity=2)
0 commit comments