Commit 1dea278

Tutorial for AOTI Python runtime
1 parent f1c0b8a commit 1dea278

1 file changed: 110 additions, 0 deletions
# -*- coding: utf-8 -*-

"""
torch.export AOT Inductor Tutorial for Python runtime
========================================================
**Author:** Ankith Gunapal
"""

######################################################################
#
# .. warning::
#
#     ``torch._export.aot_compile`` and ``torch._export.aot_load`` are in Beta status and are subject to
#     backward compatibility breaking changes. This tutorial provides an example of how to use these APIs
#     for model deployment using Python runtime.
#
# It has been shown `previously <https://pytorch.org/docs/stable/torch.compiler_aot_inductor.html#>`__ how AOTInductor can be used
# to do Ahead-of-Time compilation of PyTorch exported models by creating
# a shared library that can be run in a non-Python environment.
#
# In this tutorial, you will work through an end-to-end example of how to use AOTInductor for Python runtime.
# We will look at how to use :func:`torch._export.aot_compile` to generate a shared library.
# We will also look at how to run the shared library in Python runtime using :func:`torch._export.aot_load`.
#
# **Contents**
#
# .. contents::
#     :local:

######################################################################
# Model Compilation
# -----------------
#
# We will use TorchVision's pretrained `ResNet18` model in this example and run TorchInductor on the
# exported PyTorch program using :func:`torch._export.aot_compile`.
#
# .. note::
#
#       This API also supports :func:`torch.compile` options like `mode`.
#       As an example, if used on a CUDA enabled device, we can set `"max_autotune": True`.
#
# We also specify `dynamic_shapes` for the batch dimension. In this example, min=2 is not a bug and is
# explained in `The 0/1 Specialization Problem <https://docs.google.com/document/d/16VPOa3d-Liikf48teAOmxLc92rgvJdfosIy-yoT38Io/edit?fbclid=IwAR3HNwmmexcitV0pbZm_x1a4ykdXZ9th_eJWK-3hBtVgKnrkmemz6Pm5jRQ#heading=h.ez923tomjvyk>`__.

import os
import torch
from torchvision.models import ResNet18_Weights, resnet18

model = resnet18(weights=ResNet18_Weights.DEFAULT)
model.eval()

with torch.inference_mode():

    # Specify the generated shared library path
    aot_compile_options = {
        "aot_inductor.output_path": os.path.join(os.getcwd(), "resnet18_pt2.so"),
    }
    if torch.cuda.is_available():
        device = "cuda"
        aot_compile_options.update({"max_autotune": True})
    else:
        device = "cpu"
        # We need to turn off the below optimizations to support batch_size = 16,
        # which is treated like a special case
        # https://github.com/pytorch/pytorch/pull/116152
        torch.backends.mkldnn.set_flags(False)
        torch.backends.nnpack.set_flags(False)

    model = model.to(device=device)
    example_inputs = (torch.randn(2, 3, 224, 224, device=device),)

    # min=2 is not a bug and is explained in the 0/1 Specialization Problem
    batch_dim = torch.export.Dim("batch", min=2, max=32)
    so_path = torch._export.aot_compile(
        model,
        example_inputs,
        # Specify the first dimension of the input x as dynamic
        dynamic_shapes={"x": {0: batch_dim}},
        # Specify the generated shared library path
        options=aot_compile_options
    )

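######################################################################
# As a minimal sanity check, we can confirm that the shared library was
# written to the path we specified above via ``aot_inductor.output_path``;
# ``so_path`` returned by :func:`torch._export.aot_compile` points to the
# generated artifact.

print(f"Generated shared library: {so_path}")
assert os.path.exists(so_path)
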
######################################################################
# Model Inference in Python
# -------------------------
#
# Typically, the shared object generated above is used in a non-Python environment. In PyTorch 2.3,
# we added a new API :func:`torch._export.aot_load` to load the shared library in Python runtime.
# The API follows a similar structure to the :func:`torch.jit.load` API. We specify the path
# of the shared library and the device where it should be loaded.
#
# .. note::
#
#       We specify batch_size=1 for inference and it works even though we specified min=2 in
#       :func:`torch._export.aot_compile`.

import os
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model_so_path = os.path.join(os.getcwd(), "resnet18_pt2.so")

model = torch._export.aot_load(model_so_path, device)
example_inputs = (torch.randn(1, 3, 224, 224, device=device),)

with torch.inference_mode():
    output = model(*example_inputs)
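######################################################################
# As a minimal sketch, assuming the loaded model preserves the eager model's
# output structure (a single logits tensor for ResNet18), we can inspect the
# output shape and run a second batch size that falls within the dynamic
# batch range (min=2, max=32) compiled above.

print(output.shape)  # expected: torch.Size([1, 1000])

with torch.inference_mode():
    # A batch size of 4 lies within the dynamic range specified at compile time
    batched_inputs = (torch.randn(4, 3, 224, 224, device=device),)
    print(model(*batched_inputs).shape)  # expected: torch.Size([4, 1000])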

0 commit comments
