Skip to content

Commit 29cedfc

Browse files
Add onnx demos
1 parent 79d143d commit 29cedfc

File tree

10 files changed

+300
-0
lines changed

10 files changed

+300
-0
lines changed
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# SPDX-FileCopyrightText: (c) 2026 Tenstorrent AI ULC
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
import tempfile
5+
import forge
6+
from third_party.tt_forge_models.alexnet.image_classification.onnx import ModelLoader
7+
8+
9+
def run_alexnet_demo_case():
10+
loader = ModelLoader()
11+
with tempfile.TemporaryDirectory() as tmpdir:
12+
13+
# Load model and input
14+
onnx_model = loader.load_model(onnx_tmp_path=tmpdir)
15+
inputs = loader.load_inputs()
16+
framework_model = forge.OnnxModule("alexnet", onnx_model)
17+
18+
# Compile the model using Forge
19+
compiled_model = forge.compile(framework_model, [inputs])
20+
21+
# Run inference on Tenstorrent device
22+
output = compiled_model(inputs)
23+
24+
# Print the results
25+
26+
27+
if __name__ == "__main__":
28+
run_alexnet_demo_case()
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# SPDX-FileCopyrightText: (c) 2026 Tenstorrent AI ULC
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
import tempfile
5+
import forge
6+
from third_party.tt_forge_models.densenet.image_classification.onnx import ModelLoader
7+
8+
9+
def run_densenet_demo_case():
10+
loader = ModelLoader()
11+
with tempfile.TemporaryDirectory() as tmpdir:
12+
13+
# Load model and input
14+
onnx_model = loader.load_model(onnx_tmp_path=tmpdir)
15+
inputs = loader.load_inputs()
16+
framework_model = forge.OnnxModule("densenet121", onnx_model)
17+
18+
# Compile the model using Forge
19+
compiled_model = forge.compile(framework_model, [inputs])
20+
21+
# Run inference on Tenstorrent device
22+
output = compiled_model(inputs)
23+
24+
# Print the results
25+
loader.print_cls_results(output)
26+
27+
28+
if __name__ == "__main__":
29+
run_densenet_demo_case()
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# SPDX-FileCopyrightText: (c) 2026 Tenstorrent AI ULC
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
import tempfile
5+
import forge
6+
from third_party.tt_forge_models.googlenet.image_classification.onnx import ModelLoader
7+
8+
9+
def run_googlenet_demo_case():
10+
loader = ModelLoader()
11+
with tempfile.TemporaryDirectory() as tmpdir:
12+
13+
# Load model and input
14+
onnx_model = loader.load_model(onnx_tmp_path=tmpdir)
15+
inputs = loader.load_inputs()
16+
framework_model = forge.OnnxModule("googlenet", onnx_model)
17+
18+
# Compile the model using Forge
19+
compiled_model = forge.compile(framework_model, [inputs])
20+
21+
# Run inference on Tenstorrent device
22+
output = compiled_model(inputs)
23+
24+
# Print the results
25+
loader.print_cls_results(output)
26+
27+
28+
if __name__ == "__main__":
29+
run_googlenet_demo_case()
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# SPDX-FileCopyrightText: (c) 2026 Tenstorrent AI ULC
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
import tempfile
5+
import forge
6+
from third_party.tt_forge_models.mobilenetv1.image_classification.onnx import ModelLoader
7+
8+
9+
def run_mobilenetv1_demo_case():
10+
loader = ModelLoader()
11+
with tempfile.TemporaryDirectory() as tmpdir:
12+
13+
# Load model and input
14+
onnx_model = loader.load_model(onnx_path=tmpdir)
15+
inputs = loader.load_inputs()
16+
framework_model = forge.OnnxModule("mobilenetv1", onnx_model)
17+
18+
# Compile the model using Forge
19+
compiled_model = forge.compile(framework_model, [inputs])
20+
21+
# Run inference on Tenstorrent device
22+
output = compiled_model(inputs)
23+
24+
# Print the results
25+
loader.print_cls_results(output)
26+
27+
28+
if __name__ == "__main__":
29+
run_mobilenetv1_demo_case()
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# SPDX-FileCopyrightText: (c) 2026 Tenstorrent AI ULC
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
import tempfile
5+
import forge
6+
from third_party.tt_forge_models.resnet.image_classification.onnx import ModelLoader, ModelVariant
7+
8+
9+
def run_resnet_onnx(variant):
10+
loader = ModelLoader(variant=variant)
11+
with tempfile.TemporaryDirectory() as tmpdir:
12+
13+
# Load model and input
14+
onnx_model = loader.load_model(onnx_tmp_path=tmpdir)
15+
inputs = loader.load_inputs().contiguous()
16+
framework_model = forge.OnnxModule(variant.value, onnx_model)
17+
18+
# Compile the model using Forge
19+
compiled_model = forge.compile(framework_model, [inputs])
20+
21+
# Run inference on Tenstorrent device
22+
output = compiled_model(inputs)
23+
loader.print_cls_results(output)
24+
25+
print("=" * 60, flush=True)
26+
27+
28+
if __name__ == "__main__":
29+
demo_cases = [
30+
ModelVariant.RESNET_50_HF,
31+
ModelVariant.RESNET_50_HF_HIGH_RES,
32+
ModelVariant.RESNET_50_TIMM,
33+
ModelVariant.RESNET_50_TIMM_HIGH_RES,
34+
ModelVariant.RESNET_18,
35+
ModelVariant.RESNET_34,
36+
ModelVariant.RESNET_50,
37+
ModelVariant.RESNET_50_HIGH_RES,
38+
ModelVariant.RESNET_101,
39+
ModelVariant.RESNET_152,
40+
]
41+
42+
# Run each demo case
43+
for variant in demo_cases:
44+
run_resnet_onnx(variant)
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# SPDX-FileCopyrightText: (c) 2026 Tenstorrent AI ULC
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
import tempfile
5+
import forge
6+
from third_party.tt_forge_models.vgg.image_classification.onnx import ModelLoader
7+
8+
9+
def run_vgg_demo_case():
10+
loader = ModelLoader()
11+
with tempfile.TemporaryDirectory() as tmpdir:
12+
13+
# Load model and input
14+
onnx_model = loader.load_model(onnx_tmp_path=tmpdir)
15+
inputs = loader.load_inputs().contiguous()
16+
framework_model = forge.OnnxModule("vgg11", onnx_model)
17+
18+
# Compile the model using Forge
19+
compiled_model = forge.compile(framework_model, [inputs])
20+
21+
# Run inference on Tenstorrent device
22+
output = compiled_model(inputs)
23+
24+
# Print the results
25+
loader.print_cls_results(output)
26+
27+
28+
if __name__ == "__main__":
29+
run_vgg_demo_case()
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# SPDX-FileCopyrightText: (c) 2026 Tenstorrent AI ULC
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
import tempfile
5+
import forge
6+
from third_party.tt_forge_models.nbeats.time_series_forecasting.onnx import ModelLoader
7+
8+
9+
def run_nbeats_demo_case():
10+
loader = ModelLoader()
11+
with tempfile.TemporaryDirectory() as tmpdir:
12+
13+
# Load model and input
14+
onnx_model = loader.load_model(onnx_tmp_path=tmpdir)
15+
inputs = loader.load_inputs()
16+
framework_model = forge.OnnxModule("nbeats", onnx_model)
17+
18+
# Compile the model using Forge
19+
compiled_model = forge.compile(framework_model, inputs)
20+
21+
# Run inference on Tenstorrent device
22+
compiled_model(*inputs)
23+
24+
25+
if __name__ == "__main__":
26+
run_nbeats_demo_case()
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# SPDX-FileCopyrightText: (c) 2026 Tenstorrent AI ULC
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
import tempfile
5+
import forge
6+
from third_party.tt_forge_models.roberta.sequence_classification.onnx.loader import ModelLoader
7+
8+
9+
def run_roberta_demo_case():
10+
loader = ModelLoader()
11+
with tempfile.TemporaryDirectory() as tmpdir:
12+
13+
# Load model and input
14+
onnx_model = loader.load_model(onnx_tmp_path=tmpdir)
15+
inputs = loader.load_inputs()
16+
framework_model = forge.OnnxModule("roberta_sequence_classification", onnx_model)
17+
18+
# Compile the model using Forge
19+
compiled_model = forge.compile(framework_model, [inputs])
20+
21+
# Run inference on Tenstorrent device
22+
output = compiled_model(inputs)
23+
24+
# Decode the output
25+
predicted_value = loader.decode_output(output)
26+
print(f"Predicted Sentiment: {predicted_value}")
27+
28+
29+
if __name__ == "__main__":
30+
run_roberta_demo_case()
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# SPDX-FileCopyrightText: (c) 2026 Tenstorrent AI ULC
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
import tempfile
5+
import forge
6+
from third_party.tt_forge_models.squeezebert.sequence_classification.onnx.loader import ModelLoader
7+
8+
9+
def run_squeezebert_demo_case():
10+
loader = ModelLoader()
11+
with tempfile.TemporaryDirectory() as tmpdir:
12+
13+
# Load model and input
14+
onnx_model = loader.load_model(onnx_tmp_path=tmpdir)
15+
inputs = loader.load_inputs()
16+
framework_model = forge.OnnxModule("squeezebert_sequence_classification", onnx_model)
17+
18+
# Compile the model using Forge
19+
compiled_model = forge.compile(framework_model, [inputs])
20+
21+
# Run inference on Tenstorrent device
22+
output = compiled_model(inputs)
23+
24+
# Decode the output
25+
predicted_value = loader.decode_output(output)
26+
print(f"Predicted Category: {predicted_value}")
27+
28+
29+
if __name__ == "__main__":
30+
run_squeezebert_demo_case()
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# SPDX-FileCopyrightText: (c) 2026 Tenstorrent AI ULC
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
import tempfile
5+
import forge
6+
from third_party.tt_forge_models.t5.causal_lm.onnx.loader import ModelLoader
7+
8+
9+
def run_t5_demo_case():
10+
loader = ModelLoader()
11+
with tempfile.TemporaryDirectory() as tmpdir:
12+
13+
# Load model and input
14+
onnx_model = loader.load_model(onnx_tmp_path=tmpdir)
15+
inputs = loader.load_inputs()
16+
framework_model = forge.OnnxModule("t5_causal_lm", onnx_model)
17+
18+
# Compile the model using Forge
19+
compiled_model = forge.compile(framework_model, inputs)
20+
21+
# Run inference on Tenstorrent device
22+
output = compiled_model(*inputs)
23+
24+
25+
if __name__ == "__main__":
26+
run_t5_demo_case()

0 commit comments

Comments
 (0)