Commit 94d0612

address comments
1 parent a4326eb commit 94d0612

1 file changed: +19 -18 lines changed

intermediate_source/compiled_autograd_tutorial.rst

Lines changed: 19 additions & 18 deletions
@@ -16,6 +16,7 @@ Compiled Autograd: Capturing a larger backward graph for ``torch.compile``
 
 * PyTorch 2.4
 * Complete the `Introduction to torch.compile <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`_
+* Read through the TorchDynamo and AOTAutograd sections of `Get Started with PyTorch 2.x <https://pytorch.org/get-started/pytorch-2.0/>`_
 
 Overview
 --------
@@ -41,7 +42,7 @@ However, Compiled Autograd introduces its own limitations:
 Setup
 -----
 In this tutorial, we will base our examples on this simple neural network model.
-It takes a a 10-dimensional input vector, processes it through a single linear layer, and outputs another 10-dimensional vector.
+It takes a 10-dimensional input vector, processes it through a single linear layer, and outputs another 10-dimensional vector.
 
 .. code:: python
 
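
The body of this code block sits outside the hunk. For reference, a minimal sketch of the model the prose describes, with the class and attribute names inferred from the ``Model`` and ``model.linear.weight`` / ``model.linear.bias`` references elsewhere in this diff:

.. code:: python

    import torch

    class Model(torch.nn.Module):
        """Assumed shape of the tutorial's model: a single linear layer, 10-dimensional in and out."""
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(10, 10)

        def forward(self, x):
            return self.linear(x)
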
@@ -57,7 +58,7 @@ It takes a a 10-dimensional input vector, processes it through a single linear l
 
 Basic usage
 ------------
-Before calling the torch.compile API, make sure to set ``torch._dynamo.config.compiled_autograd`` to ``True``:
+Before calling the ``torch.compile`` API, make sure to set ``torch._dynamo.config.compiled_autograd`` to ``True``:
 
 .. code:: python
 
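
This block's body is also outside the hunk; a sketch of the basic-usage pattern assembled from the surrounding prose, assuming the ``Model`` class sketched under Setup:

.. code:: python

    import torch

    # Enable Compiled Autograd before calling torch.compile, as the line above instructs.
    torch._dynamo.config.compiled_autograd = True

    model = Model()        # instance of the Model class (assumed from the Setup sketch)
    x = torch.randn(10)    # random 10-dimensional input tensor

    @torch.compile
    def train(model, x):
        loss = model(x).sum()
        loss.backward()

    train(model, x)
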
@@ -72,19 +73,19 @@ Before calling the torch.compile API, make sure to set ``torch._dynamo.config.co
 
     train(model, x)
 
-In the code above, we create an instance of the ``Model`` class and generate a random 10-dimensional tensor ``x`` by using torch.randn(10).
+In the code above, we create an instance of the ``Model`` class and generate a random 10-dimensional tensor ``x`` by using ``torch.randn(10)``.
 We define the training loop function ``train`` and decorate it with @torch.compile to optimize its execution.
 When ``train(model, x)`` is called:
 
-* Python Interpreter calls Dynamo, since this call was decorated with ``@torch.compile``
-* Dynamo intercepts the python bytecode, simulates their execution and records the operations into a graph
-* AOTDispatcher disables hooks and calls the autograd engine to compute gradients for ``model.linear.weight`` and ``model.linear.bias``, and records the operations into a graph. Using ``torch.autograd.Function``, AOTDispatcher rewrites the forward and backward implementation of ``train``.
-* Inductor generates a function corresponding to an optimized implementation of the AOTDispatcher forward and backward
-* Dynamo sets the optimized function to be evaluated next by Python Interpreter
-* Python Interpreter executes the optimized function, which basically executes ``loss = model(x).sum()``
-* Python Interpreter executes ``loss.backward()``, calling into the autograd engine, which routes to the Compiled Autograd engine since we enabled the config: ``torch._dynamo.config.compiled_autograd = True``
-* Compiled Autograd computes the gradients for ``model.linear.weight`` and ``model.linear.bias``, and records the operations into a graph, including any hooks it encounters. During this, it will record the backward previously rewritten by AOTDispatcher. Compiled Autograd then generates a new function which corresponds to a fully traced implementation of ``loss.backward()``, and executes it with ``torch.compile`` in inference mode
-* The same steps recursively apply to the Compiled Autograd graph, but this time AOTDispatcher does not need to partition this graph into a forward and backward
+* Python Interpreter calls Dynamo, since this call was decorated with ``@torch.compile``.
+* Dynamo intercepts the Python bytecode, simulates their execution and records the operations into a graph.
+* ``AOTDispatcher`` disables hooks and calls the autograd engine to compute gradients for ``model.linear.weight`` and ``model.linear.bias``, and records the operations into a graph. Using ``torch.autograd.Function``, AOTDispatcher rewrites the forward and backward implementation of ``train``.
+* Inductor generates a function corresponding to an optimized implementation of the AOTDispatcher forward and backward.
+* Dynamo sets the optimized function to be evaluated next by Python Interpreter.
+* Python Interpreter executes the optimized function, which executes ``loss = model(x).sum()``.
+* Python Interpreter executes ``loss.backward()``, calling into the autograd engine, which routes to the Compiled Autograd engine since we set ``torch._dynamo.config.compiled_autograd = True``.
+* Compiled Autograd computes the gradients for ``model.linear.weight`` and ``model.linear.bias``, and records the operations into a graph, including any hooks it encounters. During this process, it will record the backward previously rewritten by AOTDispatcher. Compiled Autograd then generates a new function which corresponds to a fully-traced implementation of ``loss.backward()``, and executes it with ``torch.compile`` in inference mode.
+* The same steps recursively apply to the Compiled Autograd graph, but this time AOTDispatcher will not need to partition the graph.
 
 Inspecting the compiled autograd logs
 -------------------------------------
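
The context lines above lead into the "Inspecting the compiled autograd logs" section. One way to surface those logs, assuming the standard ``TORCH_LOGS`` artifact name for compiled autograd (verify against your PyTorch version):

.. code:: python

    import torch

    # Equivalent to launching the script with TORCH_LOGS="compiled_autograd" (assumed artifact name).
    torch._logging.set_logs(compiled_autograd=True)

    # With the flag set, the captured backward graph is logged when loss.backward()
    # runs under Compiled Autograd, for example inside the train(model, x) call above.
    train(model, x)
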
@@ -180,7 +181,7 @@ Or you can use the context manager, which will apply to all autograd calls withi
 
 Compiled Autograd addresses certain limitations of AOTAutograd
 --------------------------------------------------------------
-1. Graph breaks in the forward lead to graph breaks in the backward
+1. Graph breaks in the forward pass lead to graph breaks in the backward pass:
 
 .. code:: python
 
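
The example block for this item is outside the hunk; an illustrative sketch (not the tutorial's actual code) of a forward with graph breaks, using ``torch._dynamo.graph_break()``:

.. code:: python

    import torch

    torch._dynamo.config.compiled_autograd = True

    @torch.compile
    def fn(x):
        temp = x + 10
        torch._dynamo.graph_break()   # forces a graph break in the forward pass
        return (temp + 10).sum()

    x = torch.randn(10, requires_grad=True)
    loss = fn(x)
    # Compiled Autograd can still trace the whole backward into a single graph.
    torch.compile(lambda: loss.backward(), backend="eager")()
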
@@ -248,7 +249,7 @@ There should be a ``call_hook`` node in the graph, which dynamo will later inlin
 
 Common recompilation reasons for Compiled Autograd
 --------------------------------------------------
-1. Due to change in autograd structure
+1. Due to changes in the autograd structure of the loss value
 
 .. code:: python
 
@@ -258,7 +259,7 @@ Common recompilation reasons for Compiled Autograd
         loss = op(x, x).sum()
         torch.compile(lambda: loss.backward(), backend="eager")()
 
-You should see some recompile messages: **Cache miss due to new autograd node**.
+In the example above, we call a different operator on each iteration, leading to ``loss`` tracking a different autograd history each time. You should see some recompile messages: **Cache miss due to new autograd node**.
 
 .. code:: python
 
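
The visible context lines plus the added sentence ("we call a different operator on each iteration") imply a loop like the following sketch; the particular list of operators is an assumption:

.. code:: python

    import torch

    torch._dynamo.config.compiled_autograd = True
    x = torch.randn(10, requires_grad=True)

    # Any sequence of different ops works; each one gives ``loss`` a different autograd history.
    for op in [torch.add, torch.sub, torch.mul, torch.div]:
        loss = op(x, x).sum()
        torch.compile(lambda: loss.backward(), backend="eager")()
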
@@ -273,7 +274,7 @@ You should see some recompile messages: **Cache miss due to new autograd node**.
     ...
     """
 
-2. Due to dynamic shapes
+2. Due to tensors changing shapes
 
 .. code:: python
 
@@ -283,7 +284,7 @@ You should see some recompile messages: **Cache miss due to new autograd node**.
         loss = x.sum()
         torch.compile(lambda: loss.backward(), backend="eager")()
 
-You should see some recompiles messages: **Cache miss due to changed shapes**.
+In the example above, ``x`` changes shapes, and compiled autograd will mark ``x`` as a dynamic shape tensor after the first change. You should see recompiles messages: **Cache miss due to changed shapes**.
 
 .. code:: python
 
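
Similarly, the added sentence says ``x`` changes shapes between iterations; a sketch consistent with the visible lines, with the shape list assumed:

.. code:: python

    import torch

    torch._dynamo.config.compiled_autograd = True

    # The shapes are illustrative; what matters is that ``x`` changes shape across iterations.
    for i in [10, 100, 10]:
        x = torch.randn(i, i, requires_grad=True)
        loss = x.sum()
        torch.compile(lambda: loss.backward(), backend="eager")()
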
@@ -298,4 +299,4 @@ You should see some recompiles messages: **Cache miss due to changed shapes**.
 
 Conclusion
 ----------
-In this tutorial, we went over the high-level ecosystem of ``torch.compile`` with compiled autograd, the basics of compiled autograd and a few common recompilation reasons.
+In this tutorial, we went over the high-level ecosystem of ``torch.compile`` with compiled autograd, the basics of compiled autograd and a few common recompilation reasons. Stay tuned for deep dives on `dev-discuss <https://dev-discuss.pytorch.org/>`_.
