Commit 65d011d

Merge branch 'pytorch:main' into main

2 parents eec7ebf + 26066b7

6 files changed: +305 -4 lines

.ci/docker/requirements.txt

Lines changed: 2 additions & 2 deletions

@@ -36,7 +36,7 @@ datasets
 transformers
 torchmultimodal-nightly # needs to be updated to stable as soon as it's avaialable
 onnx
-onnxscript
+onnxscript>=0.2.2
 onnxruntime
 evaluate
 accelerate>=0.20.1
@@ -69,5 +69,5 @@ pycocotools
 semilearn==0.3.2
 torchao==0.5.0
 segment_anything==1.0
-torchrec==1.0.0; platform_system == "Linux"
+torchrec==1.1.0; platform_system == "Linux"
 fbgemm-gpu==1.1.0; platform_system == "Linux"
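Aside (not part of the commit): the `; platform_system == "Linux"` suffixes are PEP 508 environment markers, so these pins only apply on Linux images. A hedged sketch of how one might confirm the bumped pins resolved in a built image, using only the standard library; the distribution names are taken from the requirements file above:

    # Hypothetical sanity check -- not from the commit.
    from importlib import metadata
    import platform

    if platform.system() == "Linux":  # mirrors the platform_system markers above
        assert metadata.version("torchrec") == "1.1.0"
        assert metadata.version("fbgemm-gpu") == "1.1.0"
    print(metadata.version("onnxscript"))  # pinned to >= 0.2.2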

.jenkins/build.sh

Lines changed: 3 additions & 1 deletion

@@ -26,7 +26,9 @@ sudo apt-get install -y pandoc
 # sudo pip3 install torch==2.6.0 torchvision --no-cache-dir --index-url https://download.pytorch.org/whl/test/cu124
 # sudo pip uninstall -y fbgemm-gpu torchrec
 # sudo pip3 install fbgemm-gpu==1.1.0 torchrec==1.0.0 --no-cache-dir --index-url https://download.pytorch.org/whl/test/cu124
-
+sudo pip uninstall -y torch torchvision torchaudio torchtext torchdata torchrl tensordict
+pip3 install torch==2.7.0 torchvision torchaudio --no-cache-dir --index-url https://download.pytorch.org/whl/test/cu126
+#sudo pip uninstall -y fbgemm-gpu
 # Install two language tokenizers for Translation with TorchText tutorial
 python -m spacy download en_core_web_sm
 python -m spacy download de_core_news_sm
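Aside (not part of the commit): one way to verify that the cu126 test-channel wheels installed above actually took effect in CI is a short post-install check using only public torch attributes:

    # Hypothetical post-install check -- not from the commit.
    import torch

    print(torch.__version__)   # expected to begin with "2.7.0"
    print(torch.version.cuda)  # expected "12.6" for the cu126 index
    assert torch.__version__.startswith("2.7.0")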

.jenkins/validate_tutorials_built.py

Lines changed: 8 additions & 1 deletion

@@ -51,7 +51,14 @@
     "intermediate_source/text_to_speech_with_torchaudio",
     "intermediate_source/tensorboard_profiler_tutorial",  # reenable after 2.0 release.
     "advanced_source/semi_structured_sparse",  # reenable after 3303 is fixed.
-    "recipes_source/recipes/reasoning_about_shapes"
+    "intermediate_source/mario_rl_tutorial",  # reenable after 3302 is fixed
+    "intermediate_source/reinforcement_ppo",  # reenable after 3302 is fixed
+    "intermediate_source/pinmem_nonblock",  # reenable after 3302 is fixed
+    "intermediate_source/dqn_with_rnn_tutorial",  # reenable after 3302 is fixed
+    "advanced_source/pendulum",  # reenable after 3302 is fixed
+    "advanced_source/coding_ddpg",  # reenable after 3302 is fixed
+    "intermediate_source/torchrec_intro_tutorial",  # reenable after 3302 is fixed
+    "recipes_source/recipes/reasoning_about_shapes"  # reenable after 3326 is fixed
 ]
 
 def tutorial_source_dirs() -> List[Path]:
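Aside (not part of the commit): the entries above are repo-relative stems of tutorial sources. A hedged sketch of how such a skip list is typically consumed; the real filtering logic lives elsewhere in validate_tutorials_built.py, and the helper below is illustrative only:

    # Illustrative helper -- not the commit's code.
    from pathlib import Path

    NOT_RUN = ["intermediate_source/mario_rl_tutorial"]  # entries as in the diff above

    def should_validate(source_file: Path, repo_root: Path) -> bool:
        # Compare the tutorial's repo-relative stem against the skip list.
        stem = source_file.relative_to(repo_root).with_suffix("").as_posix()
        return stem not in NOT_RUN

    print(should_validate(Path("/repo/intermediate_source/mario_rl_tutorial.py"), Path("/repo")))  # False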

recipes_source/foreach_map.py

Lines changed: 198 additions & 0 deletions

@@ -0,0 +1,198 @@
+"""
+(beta) Explicit horizontal fusion with foreach_map and torch.compile
+====================================================================
+
+**Author:** `Michael Lazos <https://github.com/mlazos>`_
+"""
+
+#########################################################
+# Horizontal fusion is a key optimization in ML compilers. In eager,
+# this is typically expressed using the torch._foreach* ops, which parallelize
+# operations across a list of tensors. However, supporting all possible permutations
+# of arguments is quite difficult (e.g., mixtures of scalars and lists). ``foreach_map``
+# allows conversion of any pointwise op in ``torch`` to a horizontally fused foreach
+# variant. In this tutorial, we will demonstrate how to implement the Adam optimizer
+# with ``foreach_map`` to generate a fully fused kernel.
+#
+#
+# .. note::
+#
+#    This tutorial requires PyTorch 2.7.0 or later.
+
+#####################################################################
+# Model Setup
+# ~~~~~~~~~~~
+# For this example, we'll use a simple sequence of linear layers.
+# We instantiate an independent copy to compare the two optimizer implementations.
+#
+import torch
+
+# exit cleanly if we are on a device that doesn't support ``torch.compile``
+if torch.cuda.get_device_capability() < (7, 0):
+    print("Exiting because torch.compile is not supported on this device.")
+    import sys
+    sys.exit(0)
+
+# Create simple model
+model = torch.nn.Sequential(
+    *[torch.nn.Linear(1024, 1024, False, device="cuda") for _ in range(10)]
+)
+model_copy = torch.nn.Sequential(
+    *[torch.nn.Linear(1024, 1024, False, device="cuda") for _ in range(10)]
+)
+input = torch.rand(1024, device="cuda")
+
+# run forward pass
+output = model(input)
+output_copy = model_copy(input)
+
+# run backward to populate the grads for our optimizer below
+output.sum().backward()
+output_copy.sum().backward()
+
+#####################################################################
+# Helper functions for foreach_map implementation
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# In this section, we'll begin our implementation of the Adam optimizer.
+#
+from torch._higher_order_ops.foreach_map import foreach_map
+
+# Helper function to extract optimizer states from a torch.optim.Adam instance
+def get_inputs(optim):
+    steps = []
+    params = []
+    grads = []
+    exp_avgs = []
+    exp_avg_sqs = []
+    for group in optim.param_groups:
+        for p in group["params"]:
+            params.append(p)
+            grads.append(p.grad)
+            state = optim.state[p]
+            exp_avgs.append(state["exp_avg"])
+            exp_avg_sqs.append(state["exp_avg_sq"])
+            steps.append(state["step"])
+
+    return steps, params, exp_avgs, exp_avg_sqs
+
+
+# Functions to update the different optimizer states
+def update_exp_avg_sq(exp_avg_sq, grad, beta2):
+    return exp_avg_sq.mul(beta2).addcmul(grad, grad, value=1 - beta2)
+
+def update_param(param, step, exp_avg, exp_avg_sq, beta1, beta2, lr, eps):
+    bias_correction1 = 1 - torch.pow(beta1, step)
+    bias_correction2 = (1 - torch.pow(beta2, step)).sqrt()
+    step_size = (lr / bias_correction1).neg()
+    denom = (exp_avg_sq.sqrt() / (bias_correction2 * step_size)).add(eps / step_size)
+    return torch.add(param, torch.div(exp_avg, denom))
+
+# Our full Adam implementation
+def foreach_map_adam(
+    steps,
+    params,
+    exp_avgs,
+    exp_avg_sqs,
+    weight_decay=0,
+    beta1=0.9,
+    beta2=0.999,
+    lr=1e-3,
+    eps=1e-8,
+):
+    with torch.no_grad():
+        grads = [param.grad for param in params]
+        # update step
+        updated_steps = foreach_map(lambda x: x + 1, steps)
+        torch._foreach_copy_(steps, updated_steps)
+
+        if weight_decay != 0:
+            foreach_map(torch.add, (grads,), alpha=weight_decay)
+
+        # Higher-order operators (HOPs) cannot have multiple outputs at the moment,
+        # so we need to call foreach_map once for each output
+        exp_avgs_updated = foreach_map(torch.lerp, exp_avgs, grads, 1 - beta1)
+        exp_avgs_sq_updated = foreach_map(update_exp_avg_sq, exp_avg_sqs, grads, beta2)
+        params_updated = foreach_map(
+            update_param,
+            params,
+            steps,
+            exp_avgs_updated,
+            exp_avgs_sq_updated,
+            beta1,
+            beta2,
+            lr,
+            eps,
+        )
+        # Higher-order operators (HOPs) don't support input mutation today,
+        # so we manually update the states in-place
+        torch._foreach_copy_(exp_avgs, exp_avgs_updated)
+        torch._foreach_copy_(exp_avg_sqs, exp_avgs_sq_updated)
+        torch._foreach_copy_(params, params_updated)
+    return
+
+#####################################################################
+# Setting up and running the compiled kernel
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# In this section, we'll run our Adam optimizer
+# and compare the results against the eager implementation.
+#
+# .. note::
+#
+#    ``torch.compile`` is only supported on CUDA devices that have a compute capability of 7.0 or higher.
+opt_eager = torch.optim.Adam(model.parameters(), lr=torch.tensor(0.01))
+opt_eager_copy = torch.optim.Adam(model_copy.parameters(), lr=torch.tensor(0.01))
+
+# warm up the optimizer state dict
+opt_eager.step()
+opt_eager_copy.step()
+
+inputs = get_inputs(opt_eager_copy)
+compiled_adam = torch.compile(foreach_map_adam)
+
+# optionally view the output code
+torch._logging.set_logs(output_code=True)
+
+# Warmup runs to compile the function
+for _ in range(5):
+    opt_eager.step()
+    compiled_adam(*inputs)
+
+for eager_p, compile_p in zip(opt_eager.param_groups[0]["params"], opt_eager_copy.param_groups[0]["params"]):
+    torch.allclose(eager_p, compile_p)
+
+# Benchmark performance
+
+# Let's define a helpful benchmarking function:
+import torch.utils.benchmark as benchmark
+
+def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
+    t0 = benchmark.Timer(
+        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
+    )
+    return t0.blocked_autorange().mean * 1e6
+
+eager_runtime = benchmark_torch_function_in_microseconds(opt_eager.step)
+compiled_runtime = benchmark_torch_function_in_microseconds(lambda: compiled_adam(*inputs))
+
+assert eager_runtime > compiled_runtime
+
+print(f"eager runtime: {eager_runtime}us")
+print(f"compiled runtime: {compiled_runtime}us")
+
+
+
+######################################################################
+# Conclusion
+# ~~~~~~~~~~
+# In this tutorial, we implemented a custom fully-fused Adam optimizer using ``foreach_map``.
+# Combining ``foreach_map`` with ``torch.compile`` let us horizontally fuse the per-parameter
+# pointwise updates of Adam into a single kernel. The same approach applies to any optimizer
+# or update rule that is built from pointwise ops, and is a useful starting point for
+# improving model performance with horizontal fusion.
+#
+# See also:
+#
+# * `Compiled optimizer tutorial <https://pytorch.org/tutorials/recipes/compiling_optimizer.html>`__ - an intro into the compiled optimizer.
+# * `Compiling the optimizer with PT2 <https://dev-discuss.pytorch.org/t/compiling-the-optimizer-with-pt2/1669>`__ - deeper technical details on the compiled optimizer.
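Aside (not part of the commit): the core idea of the new tutorial in one hedged sketch. ``foreach_map`` lifts a pointwise op over lists of tensors, much like the hand-written ``torch._foreach_*`` ops, and under ``torch.compile`` the per-tensor operations can fuse into a single kernel. This assumes PyTorch 2.7+ and a CUDA device with compute capability 7.0 or higher:

    # Minimal sketch -- not from the tutorial itself.
    import torch
    from torch._higher_order_ops.foreach_map import foreach_map

    xs = [torch.randn(64, device="cuda") for _ in range(4)]
    ys = [torch.randn(64, device="cuda") for _ in range(4)]

    @torch.compile
    def fused_add(xs, ys):
        # one horizontally fused elementwise add across the whole list
        return foreach_map(torch.add, xs, ys)

    out = fused_add(xs, ys)
    ref = torch._foreach_add(xs, ys)  # the eager foreach equivalent
    assert all(torch.allclose(a, b) for a, b in zip(out, ref))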

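For reference (an editor's aside, not in the commit): ``update_param`` in the diff above is an algebraic rearrangement of the standard Adam step. In the usual notation, with gradient g_t:

    m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t        % torch.lerp(exp_avg, grad, 1 - beta1)
    v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2      % update_exp_avg_sq
    \hat{m}_t = m_t / (1 - \beta_1^t), \qquad \hat{v}_t = v_t / (1 - \beta_2^t)
    \theta_t = \theta_{t-1} - \mathrm{lr} \cdot \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon)

Folding the bias corrections and the negated learning rate into ``step_size`` and ``denom`` reproduces exactly this update.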
recipes_source/recipes_index.rst

Lines changed: 17 additions & 0 deletions

@@ -317,6 +317,15 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu
    :link: ../recipes/amx.html
    :tags: Model-Optimization
 
+.. (beta) Utilizing Torch Function modes with torch.compile
+
+.. customcarditem::
+   :header: (beta) Utilizing Torch Function modes with torch.compile
+   :card_description: Override torch operators with Torch Function modes and torch.compile
+   :image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png
+   :link: ../recipes/torch_compile_torch_function_modes.html
+   :tags: Model-Optimization
+
 .. (beta) Compiling the Optimizer with torch.compile
 
 .. customcarditem::
@@ -335,6 +344,14 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu
    :link: ../recipes/compiling_optimizer_lr_scheduler.html
    :tags: Model-Optimization
 
+.. (beta) Explicit horizontal fusion with foreach_map and torch.compile
+.. customcarditem::
+   :header: (beta) Explicit horizontal fusion with foreach_map and torch.compile
+   :card_description: Horizontally fuse pointwise ops with torch.compile
+   :image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png
+   :link: ../recipes/foreach_map.py
+   :tags: Model-Optimization
+
 .. Using User-Defined Triton Kernels with ``torch.compile``
 
 .. customcarditem::
recipes_source/torch_compile_torch_function_modes.py

Lines changed: 77 additions & 0 deletions

@@ -0,0 +1,77 @@
+"""
+(beta) Utilizing Torch Function modes with torch.compile
+=========================================================
+
+**Author:** `Michael Lazos <https://github.com/mlazos>`_
+"""
+
+#########################################################
+# This recipe covers how to use a key torch extensibility point,
+# torch function modes, in tandem with ``torch.compile`` to override
+# the behavior of torch operators, also known as **ops**, at trace time, with no runtime overhead.
+#
+# .. note::
+#
+#    This recipe requires PyTorch 2.7.0 or later.
+
+
+#####################################################################
+# Rewriting a torch op (torch.add -> torch.mul)
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# For this example, we'll use torch function modes to rewrite occurrences
+# of addition with multiplication instead. This type of override can be common
+# if a certain backend has a custom implementation that should be dispatched
+# for a given op.
+import torch
+
+# exit cleanly if we are on a device that doesn't support ``torch.compile``
+if torch.cuda.get_device_capability() < (7, 0):
+    print("Exiting because torch.compile is not supported on this device.")
+    import sys
+    sys.exit(0)
+
+from torch.overrides import BaseTorchFunctionMode
+
+# Define our mode; note that ``BaseTorchFunctionMode``
+# implements the actual invocation of func(..)
+class AddToMultiplyMode(BaseTorchFunctionMode):
+    def __torch_function__(self, func, types, args=(), kwargs=None):
+        if func == torch.Tensor.add:
+            func = torch.mul
+
+        return super().__torch_function__(func, types, args, kwargs)
+
+@torch.compile()
+def test_fn(x, y):
+    return x + y * x  # Note: infix operators map to torch.Tensor.* methods
+
+x = torch.rand(2, 2)
+y = torch.rand_like(x)
+
+with AddToMultiplyMode():
+    z = test_fn(x, y)
+
+assert torch.allclose(z, x * y * x)
+
+# The mode can also be used within the compiled region, like this:
+
+@torch.compile()
+def test_fn(x, y):
+    with AddToMultiplyMode():
+        return x + y * x  # Note: infix operators map to torch.Tensor.* methods
+
+x = torch.rand(2, 2)
+y = torch.rand_like(x)
+z = test_fn(x, y)
+
+assert torch.allclose(z, x * y * x)
+
+######################################################################
+# Conclusion
+# ~~~~~~~~~~
+# In this recipe we demonstrated how to override the behavior of ``torch.*`` operators
+# using torch function modes from within ``torch.compile``. This enables users to get
+# the extensibility benefits of torch function modes without the runtime overhead
+# of calling ``__torch_function__`` on every op invocation.
+#
+# * See `Extending Torch API with Modes <https://pytorch.org/docs/stable/notes/extending.html#extending-all-torch-api-with-modes>`__ for other examples and background on Torch Function modes.
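Aside (not part of the recipe): the same extension point supports observing ops rather than rewriting them. A minimal hedged sketch using only the public ``torch.overrides`` API:

    # Minimal sketch -- not from the commit. Logs each intercepted op.
    import torch
    from torch.overrides import TorchFunctionMode

    class LoggingMode(TorchFunctionMode):
        def __torch_function__(self, func, types, args=(), kwargs=None):
            kwargs = kwargs or {}
            print(f"intercepted: {func}")
            return func(*args, **kwargs)

    with LoggingMode():
        torch.add(torch.ones(2), torch.ones(2))  # prints the intercepted call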
