Tutorial for AOTI Python runtime #2997
@@ -0,0 +1,110 @@
# -*- coding: utf-8 -*-

"""
torch.export AOT Inductor Tutorial for Python runtime
======================================================
**Author:** Ankith Gunapal
"""

######################################################################
#
# .. warning::
#
#     ``torch._export.aot_compile`` and ``torch._export.aot_load`` are in Beta status and are subject to
#     backward-compatibility breaking changes. This tutorial provides an example of how to use these APIs
#     for model deployment using the Python runtime.
#

# It has been shown `previously <https://pytorch.org/docs/stable/torch.compiler_aot_inductor.html#>`__ how AOTInductor can be used
# to do Ahead-of-Time compilation of PyTorch exported models by creating
# a shared library that can be run in a non-Python environment.
#
# In this tutorial, you will learn an end-to-end example of how to use AOTInductor for the Python runtime,
# which lets you apply optimizations such as ``max-autotune`` ahead of time.
#
# We will look at how to use :func:`torch._export.aot_compile` to generate a shared library.
# We will also look at how to run the shared library in the Python runtime using :func:`torch._export.aot_load`.
||||||
# | ||||||
# **Contents** | ||||||
# | ||||||
# .. contents:: | ||||||
# :local: | ||||||
|
||||||
|
||||||
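
######################################################################
# The :func:`torch._export.aot_load` API used later in this tutorial was added
# in PyTorch 2.3, so we assume a reasonably recent PyTorch build. A quick check
# of the installed version (a convenience snippet, not part of the original
# tutorial):

import torch

print(torch.__version__)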

######################################################################
# Model Compilation
# -----------------
#
# We will use TorchVision's pretrained ``ResNet18`` model in this example and compile the
# exported PyTorch program with TorchInductor using :func:`torch._export.aot_compile`.
#
# .. note::
#
#       This API also supports :func:`torch.compile` options like ``mode``.
#       As an example, if used on a CUDA-enabled device, we can set ``"max_autotune": True``.
#
# We also specify ``dynamic_shapes`` for the batch dimension. In this example, ``min=2`` is not a bug and is
# explained in `The 0/1 Specialization Problem <https://docs.google.com/document/d/16VPOa3d-Liikf48teAOmxLc92rgvJdfosIy-yoT38Io/edit?fbclid=IwAR3HNwmmexcitV0pbZm_x1a4ykdXZ9th_eJWK-3hBtVgKnrkmemz6Pm5jRQ#heading=h.ez923tomjvyk>`__.

import os
import torch
from torchvision.models import ResNet18_Weights, resnet18

model = resnet18(weights=ResNet18_Weights.DEFAULT)
model.eval()

with torch.inference_mode():

    # Specify the generated shared library path
    aot_compile_options = {
        "aot_inductor.output_path": os.path.join(os.getcwd(), "resnet18_pt2.so"),
    }
    if torch.cuda.is_available():
        device = "cuda"
        aot_compile_options.update({"max_autotune": True})
    else:
        device = "cpu"
        # We need to turn off the below optimizations to support batch_size = 16,
        # which is treated like a special case
        # https://github.com/pytorch/pytorch/pull/116152
        torch.backends.mkldnn.set_flags(False)
        torch.backends.nnpack.set_flags(False)

    model = model.to(device=device)
    example_inputs = (torch.randn(2, 3, 224, 224, device=device),)

    # min=2 is not a bug and is explained in the 0/1 Specialization Problem
    batch_dim = torch.export.Dim("batch", min=2, max=32)

    so_path = torch._export.aot_compile(
        model,
        example_inputs,
        # Specify the first dimension of the input x as dynamic
        dynamic_shapes={"x": {0: batch_dim}},
        # Specify the generated shared library path
        options=aot_compile_options,
    )
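
######################################################################
# As a quick sanity check (a small addition, not part of the original
# tutorial flow), we can confirm that the shared library exists at the path
# returned by :func:`torch._export.aot_compile`, which is the path we
# requested through ``aot_inductor.output_path``.

print(f"Shared library written to: {so_path}")
print("File exists:", os.path.exists(so_path))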

######################################################################
# Model Inference in Python
# -------------------------
#
# Typically, the shared object generated above is used in a non-Python environment. In PyTorch 2.3,
# we added a new API, :func:`torch._export.aot_load`, to load the shared library in the Python runtime.
# The API follows a structure similar to the :func:`torch.jit.load` API. We specify the path
# of the shared library and the device where it should be loaded.
#
# .. note::
#
#      We specify ``batch_size=1`` for inference and it works even though we specified ``min=2`` in
#      :func:`torch._export.aot_compile`.

import os
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model_so_path = os.path.join(os.getcwd(), "resnet18_pt2.so")

model = torch._export.aot_load(model_so_path, device)
example_inputs = (torch.randn(1, 3, 224, 224, device=device),)

with torch.inference_mode():
    output = model(example_inputs)
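
######################################################################
# As an optional follow-up (a sketch added here, not part of the original
# tutorial), we can compare the AOTInductor result against the eager
# TorchVision model on the same input. This assumes the loaded model returns
# either a single tensor or a sequence containing one, matching the eager
# ``ResNet18`` output.

from torchvision.models import ResNet18_Weights, resnet18

eager_model = resnet18(weights=ResNet18_Weights.DEFAULT).eval().to(device)

with torch.inference_mode():
    eager_output = eager_model(example_inputs[0])

# Unwrap a possible single-element container before comparing
aoti_output = output[0] if isinstance(output, (list, tuple)) else output

# The compiled and eager outputs should agree within a small numerical tolerance
print(torch.allclose(aoti_output, eager_output, atol=1e-3, rtol=1e-3))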