intermediate_source/compiled_autograd_tutorial.rst
19 additions & 18 deletions
@@ -16,6 +16,7 @@ Compiled Autograd: Capturing a larger backward graph for ``torch.compile``
 * PyTorch 2.4
 * Complete the `Introduction to torch.compile <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`_
+* Read through the TorchDynamo and AOTAutograd sections of `Get Started with PyTorch 2.x <https://pytorch.org/get-started/pytorch-2.0/>`_
 
 Overview
 --------
@@ -41,7 +42,7 @@ However, Compiled Autograd introduces its own limitations:
 
 Setup
 -----
 In this tutorial, we will base our examples on this simple neural network model.
-It takes a a 10-dimensional input vector, processes it through a single linear layer, and outputs another 10-dimensional vector.
+It takes a 10-dimensional input vector, processes it through a single linear layer, and outputs another 10-dimensional vector.
 
 .. code:: python
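For context while reading the hunk, here is a minimal sketch of a model matching that description. It is an illustration consistent with the text, not the file's exact listing, which the diff does not show.

.. code:: python

    import torch

    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # A single linear layer: 10-dimensional input to 10-dimensional output.
            self.linear = torch.nn.Linear(10, 10)

        def forward(self, x):
            return self.linear(x)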
@@ -57,7 +58,7 @@ It takes a a 10-dimensional input vector, processes it through a single linear l
 
 Basic usage
 ------------
-Before calling the torch.compile API, make sure to set ``torch._dynamo.config.compiled_autograd`` to ``True``:
+Before calling the ``torch.compile`` API, make sure to set ``torch._dynamo.config.compiled_autograd`` to ``True``:
 
 .. code:: python
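A minimal sketch of that usage, assuming the ``Model`` class sketched in the Setup section above; again this is an illustration, not the file's exact listing.

.. code:: python

    import torch

    # Enable Compiled Autograd before using the torch.compile API.
    torch._dynamo.config.compiled_autograd = True

    model = Model()        # assumes the ``Model`` class from the Setup section
    x = torch.randn(10)

    @torch.compile
    def train(model, x):
        model.zero_grad()
        loss = model(x).sum()
        loss.backward()

    train(model, x)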
@@ -72,19 +73,19 @@ Before calling the torch.compile API, make sure to set ``torch._dynamo.config.co
     train(model, x)
 
-In the code above, we create an instance of the ``Model`` class and generate a random 10-dimensional tensor ``x`` by using torch.randn(10).
+In the code above, we create an instance of the ``Model`` class and generate a random 10-dimensional tensor ``x`` by using ``torch.randn(10)``.
 We define the training loop function ``train`` and decorate it with @torch.compile to optimize its execution.
 When ``train(model, x)`` is called:
 
-* Python Interpreter calls Dynamo, since this call was decorated with ``@torch.compile``
-* Dynamo intercepts the python bytecode, simulates their execution and records the operations into a graph
-* AOTDispatcher disables hooks and calls the autograd engine to compute gradients for ``model.linear.weight`` and ``model.linear.bias``, and records the operations into a graph. Using ``torch.autograd.Function``, AOTDispatcher rewrites the forward and backward implementation of ``train``.
-* Inductor generates a function corresponding to an optimized implementation of the AOTDispatcher forward and backward
-* Dynamo sets the optimized function to be evaluated next by Python Interpreter
-* Python Interpreter executes the optimized function, which basically executes ``loss = model(x).sum()``
-* Python Interpreter executes ``loss.backward()``, calling into the autograd engine, which routes to the Compiled Autograd engine since we enabled the config: ``torch._dynamo.config.compiled_autograd = True``
-* Compiled Autograd computes the gradients for ``model.linear.weight`` and ``model.linear.bias``, and records the operations into a graph, including any hooks it encounters. During this, it will record the backward previously rewritten by AOTDispatcher. Compiled Autograd then generates a new function which corresponds to a fullytraced implementation of ``loss.backward()``, and executes it with ``torch.compile`` in inference mode
-* The same steps recursively apply to the Compiled Autograd graph, but this time AOTDispatcher does not need to partition this graph into a forward and backward
+* Python Interpreter calls Dynamo, since this call was decorated with ``@torch.compile``.
+* Dynamo intercepts the Python bytecode, simulates its execution, and records the operations into a graph.
+* ``AOTDispatcher`` disables hooks and calls the autograd engine to compute gradients for ``model.linear.weight`` and ``model.linear.bias``, and records the operations into a graph. Using ``torch.autograd.Function``, AOTDispatcher rewrites the forward and backward implementation of ``train``.
+* Inductor generates a function corresponding to an optimized implementation of the AOTDispatcher forward and backward.
+* Dynamo sets the optimized function to be evaluated next by Python Interpreter.
+* Python Interpreter executes the optimized function, which executes ``loss = model(x).sum()``.
+* Python Interpreter executes ``loss.backward()``, calling into the autograd engine, which routes to the Compiled Autograd engine since we set ``torch._dynamo.config.compiled_autograd = True``.
+* Compiled Autograd computes the gradients for ``model.linear.weight`` and ``model.linear.bias``, and records the operations into a graph, including any hooks it encounters. During this process, it will record the backward previously rewritten by AOTDispatcher. Compiled Autograd then generates a new function which corresponds to a fully-traced implementation of ``loss.backward()``, and executes it with ``torch.compile`` in inference mode.
+* The same steps recursively apply to the Compiled Autograd graph, but this time AOTDispatcher will not need to partition the graph.
 
 Inspecting the compiled autograd logs
 -------------------------------------
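The steps above note that Compiled Autograd records any hooks it encounters, while AOTDispatcher disables them. A hedged illustration of the kind of hook involved follows; the ``register_hook`` call and the gradient doubling are hypothetical additions, not part of the tutorial's code.

.. code:: python

    import torch

    torch._dynamo.config.compiled_autograd = True

    model = Model()        # assumes the ``Model`` class from the Setup section
    x = torch.randn(10)

    # Hypothetical backward hook: without Compiled Autograd it runs in the eager
    # autograd engine; with it, the hook can be recorded into the traced backward.
    model.linear.weight.register_hook(lambda grad: grad * 2)

    @torch.compile
    def train(model, x):
        loss = model(x).sum()
        loss.backward()

    train(model, x)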
@@ -180,7 +181,7 @@ Or you can use the context manager, which will apply to all autograd calls withi
 
 Compiled Autograd addresses certain limitations of AOTAutograd
 
-You should see some recompile messages: **Cache miss due to new autograd node**.
+In the example above, we call a different operator on each iteration, leading to ``loss`` tracking a different autograd history each time. You should see some recompile messages: **Cache miss due to new autograd node**.
 
 .. code:: python
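The example that sentence refers to is not shown in the diff. Below is a sketch of the general pattern, with arbitrary operators, that gives ``loss`` a different autograd history on each iteration.

.. code:: python

    import torch

    torch._dynamo.config.compiled_autograd = True

    x = torch.randn(10, requires_grad=True)

    # Each iteration builds ``loss`` through a different operator, so its
    # autograd graph contains a new node type and the cached backward misses.
    for op in (torch.sin, torch.cos, torch.tanh):
        loss = op(x).sum()
        torch.compile(lambda: loss.backward())()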
@@ -273,7 +274,7 @@ You should see some recompile messages: **Cache miss due to new autograd node**.
     ...
     """
 
-2. Due to dynamic shapes
+2. Due to tensors changing shapes
 
 .. code:: python
@@ -283,7 +284,7 @@ You should see some recompile messages: **Cache miss due to new autograd node**.
 
-You should see some recompiles messages: **Cache miss due to changed shapes**.
+In the example above, ``x`` changes shapes, and compiled autograd will mark ``x`` as a dynamic shape tensor after the first change. You should see recompile messages: **Cache miss due to changed shapes**.
 
 .. code:: python
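The corresponding example is likewise not shown in the diff. Below is a sketch of the shape-changing pattern, with arbitrary sizes, that produces this cache miss on the first change only.

.. code:: python

    import torch

    torch._dynamo.config.compiled_autograd = True

    # The first change in shape (10 to 100) triggers the cache miss; after that,
    # ``x`` is treated as a dynamic-shape tensor and later changes reuse the graph.
    for n in (10, 100, 10):
        x = torch.randn(n, requires_grad=True)
        loss = x.sum()
        torch.compile(lambda: loss.backward())()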
@@ -298,4 +299,4 @@ You should see some recompiles messages: **Cache miss due to changed shapes**.
 
 Conclusion
 ----------
-In this tutorial, we went over the high-level ecosystem of ``torch.compile`` with compiled autograd, the basics of compiled autograd and a few common recompilation reasons.
+In this tutorial, we went over the high-level ecosystem of ``torch.compile`` with compiled autograd, the basics of compiled autograd and a few common recompilation reasons. Stay tuned for deep dives on `dev-discuss <https://dev-discuss.pytorch.org/>`_.