|
| 1 | +# SPDX-FileCopyrightText: (c) 2026 Tenstorrent AI ULC |
| 2 | +# |
| 3 | +# SPDX-License-Identifier: Apache-2.0 |
| 4 | +import tempfile |
| 5 | +import forge |
| 6 | +from third_party.tt_forge_models.resnet.image_classification.onnx import ModelLoader, ModelVariant |
| 7 | + |
| 8 | + |
| 9 | +def run_resnet_onnx(variant): |
| 10 | + loader = ModelLoader(variant=variant) |
| 11 | + with tempfile.TemporaryDirectory() as tmpdir: |
| 12 | + |
| 13 | + # Load model and input |
| 14 | + onnx_model = loader.load_model(onnx_tmp_path=tmpdir) |
| 15 | + inputs = loader.load_inputs().contiguous() |
| 16 | + framework_model = forge.OnnxModule(variant.value, onnx_model) |
| 17 | + |
| 18 | + # Compile the model using Forge |
| 19 | + compiled_model = forge.compile(framework_model, [inputs]) |
| 20 | + |
| 21 | + # Run inference on Tenstorrent device |
| 22 | + output = compiled_model(inputs) |
| 23 | + |
| 24 | + # Print the results |
| 25 | + loader.print_cls_results(output) |
| 26 | + print("=" * 60, flush=True) |
| 27 | + |
| 28 | + |
| 29 | +if __name__ == "__main__": |
| 30 | + demo_cases = [ |
| 31 | + ModelVariant.RESNET_50_HF, |
| 32 | + ModelVariant.RESNET_50_HF_HIGH_RES, |
| 33 | + ModelVariant.RESNET_50_TIMM, |
| 34 | + ModelVariant.RESNET_50_TIMM_HIGH_RES, |
| 35 | + ModelVariant.RESNET_18, |
| 36 | + ModelVariant.RESNET_34, |
| 37 | + ModelVariant.RESNET_50, |
| 38 | + ModelVariant.RESNET_101, |
| 39 | + ModelVariant.RESNET_152, |
| 40 | + ] |
| 41 | + |
| 42 | + # Run each demo case |
| 43 | + for variant in demo_cases: |
| 44 | + run_resnet_onnx(variant) |
0 commit comments