# -*- coding: utf-8 -*-

"""
torch.export AOT Inductor Tutorial for Python runtime
===========================================================
**Author:** Ankith Gunapal
"""

######################################################################
#
# .. warning::
#
#     ``torch._export.aot_compile`` and ``torch._export.aot_load`` are in Beta status and are subject to backwards compatibility
#     breaking changes. This tutorial provides an example of how to use these APIs for model deployment using Python runtime.
#
# It has been shown `previously <https://pytorch.org/docs/stable/torch.compiler_aot_inductor.html#>`__ how AOTInductor can be used
# to do Ahead-of-Time compilation of PyTorch exported models by creating
# a shared library that can be run in a non-Python environment.
#
#
# In this tutorial, you will work through an end-to-end example of how to use AOTInductor for Python runtime.
# We will look at how to use :func:`torch._export.aot_compile` to generate a shared library.
# We will also look at how to run that shared library in Python runtime using :func:`torch._export.aot_load`.
#
# **Contents**
#
# .. contents::
#     :local:


######################################################################
# Model Compilation
# -----------------
#
# We will use TorchVision's pretrained ``ResNet18`` model in this example and run TorchInductor on the
# exported PyTorch program using :func:`torch._export.aot_compile`.
#
# .. note::
#
#     This API also supports :func:`torch.compile` options like ``mode``.
#     As an example, if used on a CUDA-enabled device, we can set ``"max_autotune": True``.
#
# We also specify ``dynamic_shapes`` for the batch dimension. In this example, ``min=2`` is not a bug and is
# explained in `The 0/1 Specialization Problem <https://docs.google.com/document/d/16VPOa3d-Liikf48teAOmxLc92rgvJdfosIy-yoT38Io/edit?fbclid=IwAR3HNwmmexcitV0pbZm_x1a4ykdXZ9th_eJWK-3hBtVgKnrkmemz6Pm5jRQ#heading=h.ez923tomjvyk>`__.


import os
import torch
from torchvision.models import ResNet18_Weights, resnet18

model = resnet18(weights=ResNet18_Weights.DEFAULT)
model.eval()

with torch.inference_mode():

    # Specify the generated shared library path
    aot_compile_options = {
        "aot_inductor.output_path": os.path.join(os.getcwd(), "resnet18_pt2.so"),
    }
    if torch.cuda.is_available():
        device = "cuda"
        aot_compile_options.update({"max_autotune": True})
    else:
        device = "cpu"
        # We need to turn off the below optimizations to support batch_size = 16,
        # which is treated like a special case
        # https://github.com/pytorch/pytorch/pull/116152
        torch.backends.mkldnn.set_flags(False)
        torch.backends.nnpack.set_flags(False)

    model = model.to(device=device)
    example_inputs = (torch.randn(2, 3, 224, 224, device=device),)

    # min=2 is not a bug and is explained in the 0/1 Specialization Problem
    batch_dim = torch.export.Dim("batch", min=2, max=32)
    so_path = torch._export.aot_compile(
        model,
        example_inputs,
        # Specify the first dimension of the input x as dynamic
        dynamic_shapes={"x": {0: batch_dim}},
        # Specify the generated shared library path
        options=aot_compile_options,
    )
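
######################################################################
# :func:`torch._export.aot_compile` returns the path of the generated shared
# library, which should correspond to the ``aot_inductor.output_path`` we passed
# in the options. A minimal sanity check of the artifact on disk:

print(so_path)
assert os.path.exists(so_path)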


######################################################################
# Model Inference in Python
# -------------------------
#
# Typically, the shared object generated above is used in a non-Python environment. In PyTorch 2.3,
# we added a new API, :func:`torch._export.aot_load`, to load the shared library in the Python runtime.
# The API follows a structure similar to :func:`torch.jit.load`. We specify the path
# of the shared library and the device where it should be loaded.
#
# .. note::
#
#     We specify ``batch_size=1`` for inference, and it works even though we specified ``min=2`` in
#     :func:`torch._export.aot_compile`.


import os
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model_so_path = os.path.join(os.getcwd(), "resnet18_pt2.so")

model = torch._export.aot_load(model_so_path, device)
example_inputs = (torch.randn(1, 3, 224, 224, device=device),)

with torch.inference_mode():
    output = model(*example_inputs)
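
######################################################################
# Since the batch dimension was exported as dynamic with ``min=2`` and ``max=32``,
# the same loaded model should also accept other batch sizes in that range.
# Below is a minimal sketch of such a check; it reuses the ``model`` and
# ``device`` defined in the previous step.

with torch.inference_mode():
    for batch_size in (2, 8, 32):
        batch = torch.randn(batch_size, 3, 224, 224, device=device)
        # ResNet18 with the default weights has 1000 output classes, so each
        # call should return a tensor of shape [batch_size, 1000]
        print(model(batch).shape)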