
Commit 811211f

Initial commit
1 parent 3469d47 commit 811211f

File tree

2 files changed: +233 -0 lines changed

index.rst

Lines changed: 7 additions & 0 deletions
@@ -429,6 +429,13 @@ Welcome to PyTorch Tutorials
     :link: intermediate/custom_function_double_backward_tutorial.html
     :tags: Extending-PyTorch,Frontend-APIs

+.. customcarditem::
+   :header: Control Flow Operator Tutorial
+   :card_description: Native control flow with torch.compile, making complex models first-class citizens in PyTorch.
+   :image: _static/img/thumbnails/cropped/generic-pytorch-logo.png
+   :link: intermediate/torch_controlflow_tutorial.html
+   :tags: Model-Optimization
+
 .. customcarditem::
    :header: Custom Function Tutorial: Fusing Convolution and Batch Norm
    :card_description: Learn how to create a custom autograd Function that fuses batch norm into a convolution to improve memory usage.
intermediate_source/torch_controlflow_tutorial.py

Lines changed: 226 additions & 0 deletions
@@ -0,0 +1,226 @@
# -*- coding: utf-8 -*-
"""
Tutorial on control flow operators
==================================
**Authors:** Yidi Wu, Thomas Ortner, Richard Zou, Edward Yang, Adnan Akhundov, Horace He and Yanan Cao

This tutorial introduces the PyTorch control flow operators: ``cond``, ``while_loop``,
``scan``, ``associative_scan``, and ``map``. These operators enable data-dependent
control flow to be expressed in a functional, differentiable, and exportable
manner. The tutorial is split into two parts:

Part 1: Inference Examples
--------------------------
Demonstrates basic usage of each control flow operator, following the examples
from the paper.

Part 2: Autograd and Differentiation
------------------------------------
Shows how PyTorch's autograd integrates with the control flow operators and how
to compute gradients through them.

References:

- Control flow operator paper (for semantics and detailed implementation notes)
- Template for documentation structure (torch.export tutorial)

Note: The control flow operators are experimental as of PyTorch 2.9 and are
subject to change.
"""

import torch
from torch.export import export

# Prefer the experimental functorch entry point; fall back to the public
# ``torch.cond`` if that import is unavailable.
try:
    from functorch.experimental.control_flow import cond
except Exception:
    cond = getattr(torch, "cond", None)

from torch._higher_order_ops.map import map as torch_map
from torch._higher_order_ops.scan import scan
from torch._higher_order_ops.associative_scan import associative_scan
from torch._higher_order_ops.while_loop import while_loop

################################################################################
# Part 1: Inference Examples
# ==========================
#
# This section demonstrates the use of control flow operators for inference.
# Each example corresponds to an operator introduced in the paper.
################################################################################

######################################################################
# cond — data-dependent branching
# -------------------------------
#
# The ``cond`` operator performs a data-dependent branch that can be traced and
# exported. Both branches must have the same input and output structure.
######################################################################


class CondExample(torch.nn.Module):
    def forward(self, x: torch.Tensor):
        pred = (x.sum() > 0).unsqueeze(0)

        def true_fn(t: torch.Tensor):
            return (t.cos(),)

        def false_fn(t: torch.Tensor):
            return (t.sin(),)

        out = cond(pred, true_fn, false_fn, (x,))
        return out[0]


x = torch.randn(3, 3)
model = CondExample()
print("cond result:\n", model(x))

exported = export(model, (x,))
print("Exported graph for cond:\n", exported.graph)
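
######################################################################
# The same module runs under ``torch.compile`` as well. The snippet below is a
# small added sketch, not part of the original example: it assumes the default
# backend is available and simply reuses ``model`` and ``x`` from above.
######################################################################

compiled_model = torch.compile(model)  # name introduced for this sketch
print("compiled cond result:\n", compiled_model(x))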

######################################################################
# while_loop — iterative computation with a stopping condition
# ------------------------------------------------------------
#
# The ``while_loop`` operator executes a body function repeatedly while a condition
# is met. Both condition and body must preserve the structure of the carry, so the
# body returns a tuple that matches the carried inputs.
######################################################################


class CountdownExample(torch.nn.Module):
    def forward(self, n: torch.Tensor):
        def cond_fn(i):
            return i > 0

        def body_fn(i):
            # The body returns a tuple with the same structure as the carry.
            return (i - 1,)

        (res,) = while_loop(cond_fn, body_fn, (n,))
        return res


n = torch.tensor(5)
countdown = CountdownExample()
print("while_loop result:\n", countdown(n))
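
######################################################################
# The carry can also bundle several tensors. The example below is an added
# sketch (``sum_cond_fn``, ``sum_body_fn``, ``i0``, and ``acc0`` are names we
# introduce here): it counts down ``i`` while accumulating a running total.
######################################################################


def sum_cond_fn(i, acc):
    return i > 0


def sum_body_fn(i, acc):
    # Return a tuple matching the carry: the decremented counter and the total.
    return i - 1, acc + i


i0 = torch.tensor(5)
acc0 = torch.tensor(0)
final_i, total = while_loop(sum_cond_fn, sum_body_fn, (i0, acc0))
print("while_loop running total (5 + 4 + 3 + 2 + 1):\n", total)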

######################################################################
# scan — sequential accumulation
# ------------------------------
#
# The ``scan`` operator performs a for-loop style computation and returns both the
# final carry and stacked outputs per iteration.
######################################################################


def combine(carry, x):
    new_carry = carry + x
    out = new_carry
    return new_carry, out


xs = torch.tensor([1.0, 2.0, 3.0, 4.0])
init = torch.tensor(0.0)
carry, outs = scan(combine, init, xs, dim=0)
print("scan cumulative result:\n", outs)
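
######################################################################
# Because this combine function simply adds, the stacked outputs are a
# cumulative sum. The comparison below is an added sanity check against
# ``torch.cumsum``, not part of the original example.
######################################################################

print("matches torch.cumsum:", torch.allclose(outs, torch.cumsum(xs, dim=0)))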

######################################################################
# associative_scan — parallel prefix computation
# ----------------------------------------------
#
# The ``associative_scan`` operator performs an associative accumulation such as a
# prefix product in a parallelizable way.
######################################################################


def mul(a, b):
    return a * b


vals = torch.arange(1.0, 6.0)
res = associative_scan(mul, vals, dim=0, combine_mode="pointwise")
print("associative_scan cumulative products:\n", res)
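
######################################################################
# The result should equal the sequential cumulative product. The comparison
# below against ``torch.cumprod`` is an added sanity check, not part of the
# original example.
######################################################################

print("matches torch.cumprod:", torch.allclose(res, torch.cumprod(vals, dim=0)))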

######################################################################
# map — functional iteration over a leading dimension
# ---------------------------------------------------
#
# The ``map`` operator applies a function to slices of its input along the leading
# dimension, stacking the results.
######################################################################


def body_fn(x, y):
    return x + y


xs = torch.ones(4, 3)
y = torch.tensor(5.0)
result = torch_map(body_fn, xs, y)
print("map result:\n", result)
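
######################################################################
# Conceptually, ``map`` slices ``xs`` along dimension 0, applies the body to
# each slice, and stacks the results. The eager comparison below is an added
# illustration (``reference`` is a name we introduce here).
######################################################################

reference = torch.stack([body_fn(x_i, y) for x_i in xs])
print("matches manual loop:", torch.allclose(result, reference))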

################################################################################
# Part 2: Autograd and Differentiation
# ====================================
#
# This section shows how control flow operators integrate with PyTorch’s autograd.
# The same operators can be used in differentiable computations.
################################################################################

######################################################################
# Gradients through map
# ---------------------
#
# Operators such as ``map``, ``scan``, and ``cond`` are differentiable when the
# operations inside them are. Here we compute gradients through a ``map`` call.
######################################################################


def differentiable_body(x, y):
    return x.sin() * y.cos()


xs = torch.randn(3, 4, requires_grad=True)
y = torch.randn(4, requires_grad=True)

out = torch_map(differentiable_body, xs, y)
loss = out.sum()
loss.backward()

print("Gradient wrt xs:\n", xs.grad)
print("Gradient wrt y:\n", y.grad)
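
######################################################################
# Because the body is elementwise, the gradients have a closed form:
# d(loss)/d(xs) = cos(xs) * cos(y) and d(loss)/d(y) = -sin(y) * sum_i sin(xs[i]).
# The check below is an added sanity check (``expected_xs_grad`` and
# ``expected_y_grad`` are names we introduce here).
######################################################################

expected_xs_grad = xs.cos() * y.cos()
expected_y_grad = -y.sin() * xs.sin().sum(dim=0)
print("xs.grad matches closed form:", torch.allclose(xs.grad, expected_xs_grad))
print("y.grad matches closed form:", torch.allclose(y.grad, expected_y_grad))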

######################################################################
# Differentiable scan (RNN-style)
# -------------------------------
#
# Gradients can also propagate through a ``scan`` operation where the carry
# represents a hidden state.
######################################################################


def rnn_combine(carry, x):
    h = torch.tanh(carry + x)
    return h, h


xs = torch.randn(4, 3, requires_grad=True)
init = torch.zeros(3, requires_grad=True)
carry, outs = scan(rnn_combine, init, xs, dim=0)
loss = outs.sum()
loss.backward()
print("Gradient wrt xs:\n", xs.grad)
print("Gradient wrt init:\n", init.grad)
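
######################################################################
# As a forward-pass sanity check (added here for illustration; ``manual`` is a
# name we introduce), the same recurrence can be written as a plain Python
# loop, and the stacked hidden states should match the ``scan`` outputs.
######################################################################

with torch.no_grad():
    h = init
    manual = []
    for t in range(xs.shape[0]):
        h = torch.tanh(h + xs[t])
        manual.append(h)
    manual = torch.stack(manual)
print("scan matches manual loop:", torch.allclose(outs, manual))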

################################################################################
# Conclusion
# ----------
#
# The PyTorch control flow operators enable flexible, differentiable, and
# exportable control flow directly in Python. The main takeaways from the paper
# are:
#
# 1. **Unified semantics**: Each operator has clearly defined input/output rules
#    and pytree invariants that ensure compatibility with ``torch.export``.
# 2. **Differentiability**: Operators like ``map``, ``scan``, and ``cond`` support
#    full autograd propagation, allowing seamless integration with gradient-based
#    methods.
# 3. **Exportability**: Because they are implemented as functional ops, control
#    flow constructs can be traced, serialized, and optimized like standard ops.
# 4. **Efficiency and parallelism**: Operators such as ``associative_scan`` allow
#    parallel prefix computation, unlocking performance gains.
# 5. **Structured control flow**: ``cond`` and ``while_loop`` generalize
#    conditional and iterative logic while preserving graph structure and
#    analyzability.
#
# These operators bridge the gap between dynamic Python control flow and static
# computation graphs, providing a powerful foundation for defining models with
# complex or data-dependent behaviors in PyTorch.
################################################################################
