# -*- coding: utf-8 -*-
"""
Tutorial on control flow operators
========================================
**Authors:** Yidi Wu, Thomas Ortner, Richard Zou, Edward Yang, Adnan Akhundov, Horace He and Yanan Cao

This tutorial introduces the PyTorch control flow operators: ``cond``, ``while_loop``,
``scan``, ``associative_scan``, and ``map``. These operators enable data-dependent
control flow to be expressed in a functional, differentiable, and exportable
manner. The tutorial is split into two parts:

Part 1: Inference Examples
--------------------------
Demonstrates basic usage of each control flow operator, following the examples
from the paper.

Part 2: Autograd and Differentiation
------------------------------------
Shows how PyTorch's autograd integrates with the control flow operators and how
to compute gradients through them.

References:

- Control flow operator paper (for semantics and detailed implementation notes)
- Template for documentation structure (torch.export tutorial)

Note: The control flow operators are experimental as of PyTorch 2.9 and are
subject to change.
"""

import torch
from torch.export import export

try:
    # ``torch.cond`` is the public entry point in recent PyTorch releases.
    cond = torch.cond
except AttributeError:
    # Fall back to the legacy location for older versions.
    from functorch.experimental.control_flow import cond

# The remaining operators live under the private ``torch._higher_order_ops``
# namespace and may move to public locations in future releases.
from torch._higher_order_ops.map import map as torch_map
from torch._higher_order_ops.scan import scan
from torch._higher_order_ops.associative_scan import associative_scan
from torch._higher_order_ops.while_loop import while_loop

################################################################################
# Part 1: Inference Examples
# ==========================
#
# This section demonstrates the use of control flow operators for inference.
# Each example corresponds to an operator introduced in the paper.
################################################################################

######################################################################
# cond — data-dependent branching
# -------------------------------
#
# The ``cond`` operator performs a data-dependent branch that can be traced and
# exported. Both branches must accept the same inputs and return outputs with
# the same structure, shapes, and dtypes.
######################################################################

class CondExample(torch.nn.Module):
    def forward(self, x: torch.Tensor):
        # The predicate must be a boolean tensor with a single element.
        pred = (x.sum() > 0).unsqueeze(0)

        def true_fn(t: torch.Tensor):
            return (t.cos(),)

        def false_fn(t: torch.Tensor):
            return (t.sin(),)

        # Both branches receive the operands tuple and return matching outputs.
        out = cond(pred, true_fn, false_fn, (x,))
        return out[0]


x = torch.randn(3, 3)
model = CondExample()
print("cond result:\n", model(x))

exported = export(model, (x,))
print("Exported graph for cond:\n", exported.graph)

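######################################################################
# ``cond`` is also supported by ``torch.compile``. The call below is a minimal
# sketch assuming the default backend handles higher-order operators in this
# PyTorch version; the compiled module selects the branch from the runtime
# predicate.
######################################################################

compiled_model = torch.compile(model)
print("compiled cond result:\n", compiled_model(x))
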
######################################################################
# while_loop — iterative computation with a stopping condition
# ------------------------------------------------------------
#
# The ``while_loop`` operator executes a body function repeatedly while a
# condition holds. ``cond_fn`` must return a boolean scalar tensor, and
# ``body_fn`` must return a tuple with the same structure, shapes, and dtypes
# as the carry.
######################################################################

class CountdownExample(torch.nn.Module):
    def forward(self, n: torch.Tensor):
        def cond_fn(i):
            return i > 0

        def body_fn(i):
            # The body must return a tuple matching the structure of the carry.
            return (i - 1,)

        (res,) = while_loop(cond_fn, body_fn, (n,))
        return res


n = torch.tensor(5)
countdown = CountdownExample()
print("while_loop result:\n", countdown(n))

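######################################################################
# The carry can also be a tuple of several tensors, as long as ``body_fn``
# returns a tuple with the same structure, shapes, and dtypes. The sketch below
# (the helper names are illustrative, not part of the API) counts down while
# accumulating the counter values, so the loop returns both the counter and the
# running total.
######################################################################

def sum_cond_fn(i, total):
    return i > 0

def sum_body_fn(i, total):
    return i - 1, total + i

i, total = while_loop(sum_cond_fn, sum_body_fn, (torch.tensor(5), torch.tensor(0)))
print("while_loop with a tuple carry, total:\n", total)
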
######################################################################
# scan — sequential accumulation
# ------------------------------
#
# The ``scan`` operator performs a for-loop-style computation and returns both
# the final carry and the per-iteration outputs, stacked into a single tensor.
######################################################################

def combine(carry, x):
    # Return the next carry and the value to record for this step.
    new_carry = carry + x
    out = new_carry
    return new_carry, out

xs = torch.tensor([1.0, 2.0, 3.0, 4.0])
init = torch.tensor(0.0)
carry, outs = scan(combine, init, xs, dim=0)
print("scan cumulative result:\n", outs)

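######################################################################
# Because the carry here is a running sum, the stacked outputs are exactly a
# cumulative sum over ``xs``. The comparison below is only a sanity check of
# the semantics, not part of the operator's API.
######################################################################

print("scan matches torch.cumsum:", torch.allclose(outs, torch.cumsum(xs, dim=0)))
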
######################################################################
# associative_scan — parallel prefix computation
# ----------------------------------------------
#
# The ``associative_scan`` operator computes a prefix reduction (for example, a
# prefix product) with an associative combine function, which allows the work
# to be parallelized.
######################################################################

def mul(a, b):
    return a * b

vals = torch.arange(1.0, 6.0)
res = associative_scan(mul, vals, dim=0, combine_mode="pointwise")
print("associative_scan cumulative products:\n", res)

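######################################################################
# With multiplication as the combine function, the result is a prefix product,
# so it should agree with ``torch.cumprod``. Again, this comparison is only a
# sanity check of the semantics.
######################################################################

print("associative_scan matches torch.cumprod:",
      torch.allclose(res, torch.cumprod(vals, dim=0)))
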
######################################################################
# map — functional iteration over a leading dimension
# ---------------------------------------------------
#
# The ``map`` operator applies a function to slices of its input along the
# leading dimension and stacks the results. Any additional arguments are passed
# to every invocation unchanged.
######################################################################

def body_fn(x, y):
    return x + y

xs = torch.ones(4, 3)
y = torch.tensor(5.0)
result = torch_map(body_fn, xs, y)
print("map result:\n", result)

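######################################################################
# Semantically, ``map`` is equivalent to slicing ``xs`` along dimension 0,
# applying the body to each slice, and stacking the results. The eager loop
# below is a reference sketch of that behavior, not how ``map`` is implemented.
######################################################################

reference = torch.stack([body_fn(x_slice, y) for x_slice in xs])
print("map matches eager loop:", torch.allclose(result, reference))
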
################################################################################
# Part 2: Autograd and Differentiation
# ====================================
#
# This section shows how control flow operators integrate with PyTorch's
# autograd. The same operators can be used in differentiable computations.
################################################################################

######################################################################
# Gradients through map
# ---------------------
#
# Control flow operators are differentiable when the operations inside them
# are. Here we compute gradients through a ``map`` call.
######################################################################

def differentiable_body(x, y):
    return x.sin() * y.cos()

xs = torch.randn(3, 4, requires_grad=True)
y = torch.randn(4, requires_grad=True)

out = torch_map(differentiable_body, xs, y)
loss = out.sum()
loss.backward()

print("Gradient wrt xs:\n", xs.grad)
print("Gradient wrt y:\n", y.grad)

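######################################################################
# The gradients computed through ``map`` agree with those of an equivalent
# eager computation. The reference below uses fresh leaf tensors (the ``_ref``
# names are illustrative) and compares their gradients with the ones above;
# it is a sanity sketch rather than part of the API.
######################################################################

xs_ref = xs.detach().clone().requires_grad_(True)
y_ref = y.detach().clone().requires_grad_(True)
ref_out = torch.stack([differentiable_body(row, y_ref) for row in xs_ref])
ref_out.sum().backward()
print("xs gradients match:", torch.allclose(xs.grad, xs_ref.grad))
print("y gradients match:", torch.allclose(y.grad, y_ref.grad))
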
######################################################################
# Differentiable scan (RNN-style)
# -------------------------------
#
# Gradients can also propagate through a ``scan`` operation where the carry
# represents a hidden state.
######################################################################

def rnn_combine(carry, x):
    # The carry is the hidden state; it is also recorded as the per-step output.
    h = torch.tanh(carry + x)
    return h, h

xs = torch.randn(4, 3, requires_grad=True)
init = torch.zeros(3, requires_grad=True)
carry, outs = scan(rnn_combine, init, xs, dim=0)
loss = outs.sum()
loss.backward()
print("Gradient wrt xs:\n", xs.grad)
print("Gradient wrt init:\n", init.grad)

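######################################################################
# Gradients through cond
# ----------------------
#
# ``cond`` participates in autograd as well: gradients flow through the branch
# that was actually taken, while the boolean predicate itself is not
# differentiated. The sketch below assumes autograd support for ``cond`` in
# this PyTorch version.
######################################################################

def grad_true_fn(t):
    return (t.cos(),)

def grad_false_fn(t):
    return (t.sin(),)

x = torch.randn(3, requires_grad=True)
(branch_out,) = cond(x.sum() > 0, grad_true_fn, grad_false_fn, (x,))
branch_out.sum().backward()
print("Gradient wrt x through cond:\n", x.grad)
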
################################################################################
# Conclusion
# ----------
#
# The PyTorch control flow operators enable flexible, differentiable, and
# exportable control flow directly in Python. The main takeaways from the paper
# are:
#
# 1. **Unified semantics**: Each operator has clearly defined input/output rules
#    and pytree invariants that ensure compatibility with ``torch.export``.
# 2. **Differentiability**: Operators like ``map``, ``scan``, and ``cond`` support
#    full autograd propagation, allowing seamless integration with gradient-based
#    methods.
# 3. **Exportability**: Because they are implemented as functional ops, control
#    flow constructs can be traced, serialized, and optimized like standard ops.
# 4. **Efficiency and parallelism**: Operators such as ``associative_scan`` allow
#    parallel prefix computation, unlocking performance gains.
# 5. **Structured control flow**: ``cond`` and ``while_loop`` generalize
#    conditional and iterative logic while preserving graph structure and
#    analyzability.
#
# These operators bridge the gap between dynamic Python control flow and static
# computation graphs, providing a powerful foundation for defining models with
# complex or data-dependent behaviors in PyTorch.
################################################################################