
Commit 4aa8399

Moved tutorial to recipe
1 parent 53f5965 commit 4aa8399

File tree

1 file changed: +12 -12 lines changed

intermediate_source/torch_export_aoti_python.py renamed to recipes_source/torch_export_aoti_python.py

Lines changed: 12 additions & 12 deletions
@@ -3,14 +3,14 @@
 """
 (Beta) ``torch.export`` AOTInductor Tutorial for Python runtime
 ===================================================
-**Author:** Ankith Gunapal, Bin Bao
+**Author:** Ankith Gunapal, Bin Bao, Angela Yi
 """

 ######################################################################
 #
 # .. warning::
 #
-#     ``torch._export.aot_compile`` and ``torch._export.aot_load`` are in Beta status and are subject to backwards compatibility
+#     ``torch._inductor.aot_compile`` and ``torch._export.aot_load`` are in Beta status and are subject to backwards compatibility
 #     breaking changes. This tutorial provides an example of how to use these APIs for model deployment using Python runtime.
 #
 # It has been shown `previously <https://pytorch.org/docs/stable/torch.compiler_aot_inductor.html#>`__ how AOTInductor can be used
@@ -19,8 +19,8 @@
 #
 #
 # In this tutorial, you will learn an end-to-end example of how to use AOTInductor for python runtime.
-# We will look at how to use :func:`torch._export.aot_compile` to generate a shared library.
-# Additionally, we will examine how to execute the shared library in Python runtime using :func:`torch._export.aot_load`.
+# We will look at how to use :func:`torch._inductor.aot_compile` along with :func:`torch.export.export` to generate a
+# shared library. Additionally, we will examine how to execute the shared library in Python runtime using :func:`torch._export.aot_load`.
 # You will learn about the speed up seen in the first inference time using AOTInductor, especially when using
 # ``max-autotune`` mode which can take some time to execute.
 #
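At a glance, the end-to-end flow described above can be sketched roughly as follows (a minimal sketch assuming PyTorch 2.4; the toy module and default output path are illustrative, not part of the recipe):

import torch
import torch._export
import torch._inductor

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(8, 4)

    def forward(self, x):
        return self.fc(x)

model = TinyModel().eval()
example_inputs = (torch.randn(2, 8),)

# 1. Capture the eager model as an ExportedProgram.
exported = torch.export.export(model, example_inputs)

# 2. Ahead-of-time compile the exported graph into a shared library (.so).
so_path = torch._inductor.aot_compile(exported.module(), example_inputs)

# 3. Load the shared library in a Python runtime and run inference.
loaded = torch._export.aot_load(so_path, "cpu")
print(loaded(torch.randn(2, 8)).shape)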
@@ -33,14 +33,14 @@
 # Prerequisites
 # -------------
 # * PyTorch 2.4 or later
-# * Basic understanding of ``torch._export`` and AOTInductor
+# * Basic understanding of ``torch.export`` and AOTInductor
 # * Complete the `AOTInductor: Ahead-Of-Time Compilation for Torch.Export-ed Models <https://pytorch.org/docs/stable/torch.compiler_aot_inductor.html#>`_ tutorial

 ######################################################################
 # What you will learn
 # ----------------------
 # * How to use AOTInductor for python runtime.
-# * How to use :func:`torch._export.aot_compile` to generate a shared library
+# * How to use :func:`torch._inductor.aot_compile` along with :func:`torch.export.export` to generate a shared library
 # * How to run a shared library in Python runtime using :func:`torch._export.aot_load`.
 # * When do you use AOTInductor for python runtime

@@ -49,7 +49,7 @@
 # ------------
 #
 # We will use the TorchVision pretrained `ResNet18` model and TorchInductor on the
-# exported PyTorch program using :func:`torch._export.aot_compile`.
+# exported PyTorch program using :func:`torch._inductor.aot_compile`.
 #
 # .. note::
 #
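As a concrete illustration of this compilation step, here is a hedged sketch of exporting the pretrained ResNet18 with a dynamic batch dimension and AOT-compiling it to a shared library (the output file name and the Dim bounds are illustrative assumptions):

import os
import torch
import torch._inductor
from torchvision.models import ResNet18_Weights, resnet18

model = resnet18(weights=ResNet18_Weights.DEFAULT).eval()

with torch.inference_mode():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device=device)
    example_inputs = (torch.randn(2, 3, 224, 224, device=device),)

    # Mark the batch dimension of the input ``x`` as dynamic.
    batch_dim = torch.export.Dim("batch", min=2, max=32)
    exported_program = torch.export.export(
        model, example_inputs, dynamic_shapes={"x": {0: batch_dim}}
    )

    # AOT-compile the exported graph into a shared library; ``max_autotune``
    # enables the slower, more thorough autotuning mode on GPU.
    so_path = torch._inductor.aot_compile(
        exported_program.module(),
        example_inputs,
        options={
            "aot_inductor.output_path": os.path.join(os.getcwd(), "resnet18_pt2.so"),
            "max_autotune": device == "cuda",
        },
    )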
@@ -115,7 +115,7 @@
 # .. note::
 #
 #     In the example above, we specified ``batch_size=1`` for inference and it still functions correctly even though we specified ``min=2`` in
-#     :func:`torch._export.aot_compile`.
+#     :func:`torch.export.export`.


 import os
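A companion sketch of the loading side, assuming the shared library produced in the compilation sketch above, running inference with ``batch_size=1``:

import os
import torch
import torch._export

device = "cuda" if torch.cuda.is_available() else "cpu"
model_so_path = os.path.join(os.getcwd(), "resnet18_pt2.so")

# Load the ahead-of-time compiled shared library as a callable.
model = torch._export.aot_load(model_so_path, device)

with torch.inference_mode():
    x = torch.randn(1, 3, 224, 224, device=device)
    output = model(x)
    print(output.shape)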
@@ -139,13 +139,13 @@
 # model deployment using Python.
 # There are mainly two reasons why you would use AOTInductor Python Runtime:
 #
-# - ``torch._export.aot_compile`` generates a shared library. This is useful for model
+# - ``torch._inductor.aot_compile`` generates a shared library. This is useful for model
 #   versioning for deployments and tracking model performance over time.
 # - With :func:`torch.compile` being a JIT compiler, there is a warmup
 #   cost associated with the first compilation. Your deployment needs to account for the
 #   compilation time taken for the first inference. With AOTInductor, the compilation is
-#   done offline using ``torch._export.aot_compile``. The deployment would only load the
-#   shared library using ``torch._export.aot_load`` and run inference.
+#   done offline using ``torch.export.export`` & ``torch._inductor.aot_compile``. The deployment
+#   would only load the shared library using ``torch._export.aot_load`` and run inference.
 #
 #
 # The section below shows the speedup achieved with AOTInductor for first inference
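As a rough way to see this warmup cost, one could time the first call with a helper along the lines of the recipe's timed function (this is a simplified wall-clock sketch; ``compiled_model`` and ``aot_model`` are hypothetical callables standing in for the torch.compile'd model and the aot_load'ed library):

import time
import torch

def timed(fn):
    # Wall-clock a callable; returns (result, elapsed_seconds).
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    result = fn()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return result, time.perf_counter() - start

# First inference with torch.compile pays the JIT compilation cost:
#   _, t_jit = timed(lambda: compiled_model(x))
# whereas the AOT-compiled library only pays load and launch cost:
#   _, t_aot = timed(lambda: aot_model(x))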
@@ -218,7 +218,7 @@ def timed(fn):
 # ----------
 #
 # In this recipe, we have learned how to effectively use the AOTInductor for Python runtime by
-# compiling and loading a pretrained ``ResNet18`` model using the ``torch._export.aot_compile``
+# compiling and loading a pretrained ``ResNet18`` model using the ``torch._inductor.aot_compile``
 # and ``torch._export.aot_load`` APIs. This process demonstrates the practical application of
 # generating a shared library and running it within a Python environment, even with dynamic shape
 # considerations and device-specific optimizations. We also looked at the advantage of using
