# .. contents::
#    :local:

######################################################################
# Prerequisites
# -------------
# * PyTorch 2.4 or later
# * Basic understanding of ``torch._export`` and AOT Inductor
# * Complete the `AOTInductor: Ahead-Of-Time Compilation for Torch.Export-ed Models <https://pytorch.org/docs/stable/torch.compiler_aot_inductor.html#>`_ tutorial

######################################################################
# What you will learn
# -------------------
# * How to use AOT Inductor for the Python runtime.
# * How to use :func:`torch._export.aot_compile` to generate a shared library.
# * How to run a shared library in the Python runtime using :func:`torch._export.aot_load`.

######################################################################
# Model Compilation
# -----------------
#
# .. note::
#
#      This API also supports :func:`torch.compile` options like ``mode``.
#      For example, if used on a CUDA-enabled device, you can set ``"max_autotune": True``,
#      which leverages Triton-based matrix multiplications and convolutions, and enables CUDA graphs by default.
#
# We also specify ``dynamic_shapes`` for the batch dimension. In this example, ``min=2`` is not a bug and is
# explained in `The 0/1 Specialization Problem <https://docs.google.com/document/d/16VPOa3d-Liikf48teAOmxLc92rgvJdfosIy-yoT38Io/edit?fbclid=IwAR3HNwmmexcitV0pbZm_x1a4ykdXZ9th_eJWK-3hBtVgKnrkmemz6Pm5jRQ#heading=h.ez923tomjvyk>`__
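#
# A minimal sketch of the compilation step is shown below; the output file name
# ``resnet18_pt2.so`` and the upper bound ``max=32`` for the batch dimension are
# illustrative choices for this sketch rather than requirements.

import os

import torch
from torchvision.models import ResNet18_Weights, resnet18

model = resnet18(weights=ResNet18_Weights.DEFAULT)
model.eval()

with torch.inference_mode():
    # Write the compiled shared library into the current directory; enable
    # Triton autotuning only when a CUDA device is available.
    aot_compile_options = {
        "aot_inductor.output_path": os.path.join(os.getcwd(), "resnet18_pt2.so"),
    }
    if torch.cuda.is_available():
        device = "cuda"
        aot_compile_options.update({"max_autotune": True})
    else:
        device = "cpu"

    model = model.to(device=device)
    example_inputs = (torch.randn(2, 3, 224, 224, device=device),)

    # ``x`` is the name of ResNet18's forward argument; ``min=2`` avoids the
    # 0/1 specialization problem discussed above.
    batch_dim = torch.export.Dim("batch", min=2, max=32)
    model_so_path = torch._export.aot_compile(
        model,
        example_inputs,
        dynamic_shapes={"x": {0: batch_dim}},
        options=aot_compile_options,
    )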

######################################################################
#
# Typically, the shared object generated above is used in a non-Python environment. In PyTorch 2.3,
# we added a new API called :func:`torch._export.aot_load` to load the shared library in the Python runtime.
# The API follows a structure similar to the :func:`torch.jit.load` API. You need to specify the path
# of the shared library and the device where it should be loaded.
#
# .. note::
#
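# First, load the compiled shared library (``model_so_path``, produced by the
# compilation step above) onto the target device:

model = torch._export.aot_load(model_so_path, device)
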
example_inputs = (torch.randn(1, 3, 224, 224, device=device),)

with torch.inference_mode():
    output = model(example_inputs)
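
# Because the batch dimension was exported as dynamic, the same shared library can also
# serve other batch sizes within the range declared at compile time (up to ``max=32`` in
# the sketch above); for example, a batch of four:
with torch.inference_mode():
    output_batch_4 = model((torch.randn(4, 3, 224, 224, device=device),))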

######################################################################
# When to use AOT Inductor Python Runtime
# ---------------------------------------
#
# One of the requirements for using AOT Inductor is that the model shouldn't have any graph breaks.
# Once this requirement is met, the primary use case for the AOT Inductor Python runtime is
# model deployment using Python.
# There are mainly two reasons why you would use the AOT Inductor Python runtime:
#
# - ``torch._export.aot_compile`` generates a shared library. This is useful for model
#   versioning in deployments and for tracking model performance over time.
# - With :func:`torch.compile` being a JIT compiler, there is a warmup
#   cost associated with the first compilation. Your deployment needs to account for the
#   compilation time taken for the first inference. With AOT Inductor, the compilation is
#   done offline using ``torch._export.aot_compile``. The deployment only needs to load the
#   shared library using ``torch._export.aot_load`` and run inference.
#
# The section below shows the speedup achieved with AOT Inductor for the first inference.
#
# We define a utility function ``timed`` to measure the time taken for inference.
#

import time


def timed(fn):
    # Returns the result of running `fn()` and the time it took for `fn()` to run,
    # in milliseconds. We use CUDA events and synchronization for accurate
    # measurement on CUDA-enabled devices.
    if torch.cuda.is_available():
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
    else:
        start = time.time()

    result = fn()
    if torch.cuda.is_available():
        end.record()
        torch.cuda.synchronize()
    else:
        end = time.time()

    # Measure the time taken to execute the function in milliseconds
    if torch.cuda.is_available():
        duration = start.elapsed_time(end)
    else:
        duration = (end - start) * 1000

    return result, duration


######################################################################
# Let's measure the time for first inference using AOT Inductor

torch._dynamo.reset()

model = torch._export.aot_load(model_so_path, device)
example_inputs = (torch.randn(1, 3, 224, 224, device=device),)

with torch.inference_mode():
    _, time_taken = timed(lambda: model(example_inputs))
    print(f"Time taken for first inference for AOT Inductor is {time_taken:.2f} ms")

######################################################################
# Let's measure the time for first inference using ``torch.compile``

torch._dynamo.reset()

model = resnet18(weights=ResNet18_Weights.DEFAULT).to(device)
model.eval()

model = torch.compile(model)
example_inputs = torch.randn(1, 3, 224, 224, device=device)

with torch.inference_mode():
    _, time_taken = timed(lambda: model(example_inputs))
    print(f"Time taken for first inference for torch.compile is {time_taken:.2f} ms")

######################################################################
# We see that there is a drastic speedup in first inference time using AOT Inductor compared
# to ``torch.compile``.

######################################################################
# Conclusion
# ----------
#
# In this tutorial, we have learned how to effectively use AOTInductor for the Python runtime by
# compiling and loading a pretrained ``ResNet18`` model using the ``torch._export.aot_compile``
# and ``torch._export.aot_load`` APIs. This process demonstrates the practical application of
# generating a shared library and running it within a Python environment, even with dynamic shape
# considerations and device-specific optimizations. We also looked at the advantage of using
# AOT Inductor in model deployments, with regard to the speedup in first inference time.