Commit 4b0fede

Create torch_compile_conv_bn_fuser tutorial adapted from fx_conv_bn_fuser
1 parent c1cd7ab commit 4b0fede

4 files changed: +156 -131 lines changed


.jenkins/validate_tutorials_built.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@
2323
"beginner_source/examples_autograd/polynomial_autograd",
2424
"beginner_source/examples_autograd/polynomial_custom_function",
2525
"intermediate_source/mnist_train_nas", # used by ax_multiobjective_nas_tutorial.py
26-
"intermediate_source/fx_conv_bn_fuser",
26+
"intermediate_source/torch_compile_conv_bn_fuser",
2727
"intermediate_source/_torch_export_nightly_tutorial", # does not work on release
2828
"advanced_source/usb_semisup_learn", # fails with CUDA OOM error, should try on a different worker
2929
"prototype_source/fx_graph_mode_ptq_dynamic",

index.rst

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -348,13 +348,6 @@ Welcome to PyTorch Tutorials
348348

349349
.. Code Transformations with FX
350350
351-
.. customcarditem::
352-
:header: Building a Convolution/Batch Norm fuser in FX
353-
:card_description: Build a simple FX pass that fuses batch norm into convolution to improve performance during inference.
354-
:image: _static/img/thumbnails/cropped/Deploying-PyTorch-in-Python-via-a-REST-API-with-Flask.png
355-
:link: intermediate/fx_conv_bn_fuser.html
356-
:tags: FX
357-
358351
.. customcarditem::
359352
:header: Building a Simple Performance Profiler with FX
360353
:card_description: Build a simple FX interpreter to record the runtime of op, module, and function calls and report statistics
@@ -583,6 +576,13 @@ Welcome to PyTorch Tutorials
583576
:link: intermediate/torch_compile_tutorial.html
584577
:tags: Model-Optimization
585578

579+
.. customcarditem::
580+
:header: Building a Convolution/Batch Norm fuser in torch.compile
581+
:card_description: Build a simple pattern matcher pass that fuses batch norm into convolution to improve performance during inference.
582+
:image: _static/img/thumbnails/cropped/generic-pytorch-logo.png
583+
:link: intermediate/torch_compile_conv_bn_fuser.html
584+
:tags: Model-Optimization
585+
586586
.. customcarditem::
587587
:header: Inductor CPU Backend Debugging and Profiling
588588
:card_description: Learn the usage, debugging and performance profiling for ``torch.compile`` with Inductor CPU backend.
@@ -950,7 +950,6 @@ Additional Resources
950950
:hidden:
951951
:caption: Code Transforms with FX
952952

953-
intermediate/fx_conv_bn_fuser
954953
intermediate/fx_profiling_tutorial
955954

956955
.. toctree::
@@ -1001,6 +1000,7 @@ Additional Resources
10011000
intermediate/nvfuser_intro_tutorial
10021001
intermediate/ax_multiobjective_nas_tutorial
10031002
intermediate/torch_compile_tutorial
1003+
intermediate/torch_compile_conv_bn_fuser
10041004
intermediate/compiled_autograd_tutorial
10051005
intermediate/inductor_debug_cpu
10061006
intermediate/scaled_dot_product_attention_tutorial
Lines changed: 145 additions & 120 deletions
Original file line number | Diff line number | Diff line change
@@ -1,19 +1,20 @@
11
# -*- coding: utf-8 -*-
22
"""
3-
(beta) Building a Convolution/Batch Norm fuser in FX
4-
*******************************************************
5-
**Author**: `Horace He <https://github.com/chillee>`_
3+
Building a Convolution/Batch Norm fuser with torch.compile
4+
******************************************************************
5+
**Author**: `Horace He <https://github.com/chillee>`__, `Will Feng <https://github.com/yf225>`__
66
7-
In this tutorial, we are going to use FX, a toolkit for composable function
8-
transformations of PyTorch, to do the following:
7+
In this tutorial, we are going to use torch.compile and its pattern matching
8+
capabilities to do the following:
99
1010
1) Find patterns of conv/batch norm in the data dependencies.
1111
2) For the patterns found in 1), fold the batch norm statistics into the convolution weights.
1212
13-
Note that this optimization only works for models in inference mode (i.e. `mode.eval()`)
13+
Note that this specific optimization only works for models in inference mode (i.e. `model.eval()`).
14+
But the pattern matching system in torch.compile works for both training and inference.
1415
15-
We will be building the fuser that exists here:
16-
https://github.com/pytorch/pytorch/blob/orig/release/1.8/torch/fx/experimental/fuser.py
16+
We will demonstrate how to register custom fusion patterns with torch.compile's
17+
pattern matcher to optimize model performance.
1718
1819
"""
1920

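The two steps listed in the docstring can be illustrated without PyTorch. The following is a minimal sketch of step 1 only, not the Inductor pattern matcher: the `Node` class, `build_example_graph`, `find_conv_bn_pairs`, and the op names are hypothetical stand-ins for graph nodes. The single-user check mirrors the `len(node.args[0].users) > 1` guard in the FX pass that this commit removes.

```python
from typing import List, Optional, Tuple


class Node:
    """Hypothetical stand-in for a graph node: an op name plus one input."""

    def __init__(self, name: str, op: str, inp: Optional["Node"] = None):
        self.name = name
        self.op = op
        self.inp = inp
        self.users: List["Node"] = []
        if inp is not None:
            # Record the data dependency in both directions.
            inp.users.append(self)


def build_example_graph() -> List[Node]:
    x = Node("x", "placeholder")
    conv1 = Node("conv1", "conv2d", x)
    bn1 = Node("bn1", "batch_norm", conv1)
    relu = Node("relu", "relu", bn1)
    conv2 = Node("conv2", "conv2d", relu)
    bn2 = Node("bn2", "batch_norm", conv2)
    # conv2 has a second consumer, so folding bn2 into it would change
    # what `side` observes.
    side = Node("side", "relu", conv2)
    return [x, conv1, bn1, relu, conv2, bn2, side]


def find_conv_bn_pairs(nodes: List[Node]) -> List[Tuple[str, str]]:
    """Return (conv, bn) name pairs where bn directly consumes conv."""
    pairs = []
    for node in nodes:
        if node.op != "batch_norm":
            continue
        producer = node.inp
        if producer is None or producer.op != "conv2d":
            continue
        if len(producer.users) > 1:
            # The conv output is used elsewhere; skip this candidate.
            continue
        pairs.append((producer.name, node.name))
    return pairs


pairs = find_conv_bn_pairs(build_example_graph())
print(pairs)
```

In this toy graph only the first conv/batch norm pair qualifies, because the second convolution feeds two consumers.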
@@ -24,10 +25,11 @@
2425

2526
from typing import Type, Dict, Any, Tuple, Iterable
2627
import copy
27-
import torch.fx as fx
2828
import torch
2929
import torch.nn as nn
3030

31+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32+
3133
######################################################################
3234
# For this tutorial, we are going to create a model consisting of convolutions
3335
# and batch norms. Note that this model has some tricky components - some of
@@ -61,29 +63,26 @@ def forward(self, x):
6163
x = self.wrapped(x)
6264
return x
6365

64-
model = M()
65-
66+
model = M().to(device)
6667
model.eval()
6768

6869
######################################################################
6970
# Fusing Convolution with Batch Norm
7071
# -----------------------------------------
7172
# One of the primary challenges with trying to automatically fuse convolution
7273
# and batch norm in PyTorch is that PyTorch does not provide an easy way of
73-
# accessing the computational graph. FX resolves this problem by symbolically
74-
# tracing the actual operations called, so that we can track the computations
75-
# through the `forward` call, nested within Sequential modules, or wrapped in
76-
# an user-defined module.
77-
78-
traced_model = torch.fx.symbolic_trace(model)
79-
print(traced_model.graph)
74+
# accessing the computational graph. torch.compile resolves this problem by
75+
# capturing the computational graph during compilation, allowing us to apply
76+
# pattern-based optimizations across the entire model, including operations
77+
# nested within Sequential modules or wrapped in custom modules.
78+
import torch._inductor.pattern_matcher as pm
79+
from torch._inductor.pattern_matcher import register_replacement
8080

8181
######################################################################
82-
# This gives us a graph representation of our model. Note that both the modules
83-
# hidden within the sequential as well as the wrapped Module have been inlined
84-
# into the graph. This is the default level of abstraction, but it can be
85-
# configured by the pass writer. More information can be found at the FX
86-
# overview https://pytorch.org/docs/master/fx.html#module-torch.fx
82+
# torch.compile will capture a graph representation of our model. During
83+
# compilation, modules hidden within Sequential containers and wrapped
84+
# modules are all inlined into the graph, making them available for
85+
# pattern matching and optimization.
8786

8887

8988
####################################
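The next hunk header references `fuse_conv_bn_weights`, whose body is unchanged by this commit and so does not appear in the diff. As a rough illustration of the arithmetic such a function performs, here is a scalar sketch in plain Python. It assumes the standard inference-mode batch norm formula and a single-channel 1x1 convolution (so the convolution reduces to `w * x + b`); `fold_bn_into_conv_scalar` and `conv_then_bn` are hypothetical names, and this is not the tutorial's tensor implementation.

```python
import math


def fold_bn_into_conv_scalar(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b):
    """Fold batch norm statistics into a scalar conv weight and bias.

    Argument order follows the fuse_conv_bn_weights signature shown in the
    diff; the scalar treatment is an assumption made for illustration.
    """
    scale = bn_w / math.sqrt(bn_rv + bn_eps)
    fused_w = conv_w * scale
    fused_b = (conv_b - bn_rm) * scale + bn_b
    return fused_w, fused_b


def conv_then_bn(x, conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b):
    """Unfused reference: 1x1 'convolution' followed by eval-mode batch norm."""
    y = conv_w * x + conv_b
    return bn_w * (y - bn_rm) / math.sqrt(bn_rv + bn_eps) + bn_b


params = dict(conv_w=0.5, conv_b=0.1, bn_rm=0.2, bn_rv=4.0,
              bn_eps=1e-5, bn_w=1.5, bn_b=-0.3)
fused_w, fused_b = fold_bn_into_conv_scalar(**params)

x = 2.0
reference = conv_then_bn(x, **params)
fused = fused_w * x + fused_b
print(math.isclose(reference, fused, rel_tol=1e-9, abs_tol=1e-12))
```

The folding is exact only because the running mean and variance are constants in inference mode, which is why the docstring restricts this optimization to `eval()` models.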
@@ -128,78 +127,74 @@ def fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b):
128127

129128

130129
####################################
131-
# FX Fusion Pass
132-
# ----------------------------------
133-
# Now that we have our computational graph as well as a method for fusing
134-
# convolution and batch norm, all that remains is to iterate over the FX graph
135-
# and apply the desired fusions.
136-
137-
138-
def _parent_name(target : str) -> Tuple[str, str]:
139-
"""
140-
Splits a ``qualname`` into parent path and last atom.
141-
For example, `foo.bar.baz` -> (`foo.bar`, `baz`)
142-
"""
143-
*parent, name = target.rsplit('.', 1)
144-
return parent[0] if parent else '', name
145-
146-
def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module):
147-
assert(isinstance(node.target, str))
148-
parent_name, name = _parent_name(node.target)
149-
setattr(modules[parent_name], name, new_module)
150-
151-
152-
def fuse(model: torch.nn.Module) -> torch.nn.Module:
153-
model = copy.deepcopy(model)
154-
# The first step of most FX passes is to symbolically trace our model to
155-
# obtain a `GraphModule`. This is a representation of our original model
156-
# that is functionally identical to our original model, except that we now
157-
# also have a graph representation of our forward pass.
158-
fx_model: fx.GraphModule = fx.symbolic_trace(model)
159-
modules = dict(fx_model.named_modules())
160-
161-
# The primary representation for working with FX are the `Graph` and the
162-
# `Node`. Each `GraphModule` has a `Graph` associated with it - this
163-
# `Graph` is also what generates `GraphModule.code`.
164-
# The `Graph` itself is represented as a list of `Node` objects. Thus, to
165-
# iterate through all of the operations in our graph, we iterate over each
166-
# `Node` in our `Graph`.
167-
for node in fx_model.graph.nodes:
168-
# The FX IR contains several types of nodes, which generally represent
169-
# call sites to modules, functions, or methods. The type of node is
170-
# determined by `Node.op`.
171-
if node.op != 'call_module': # If our current node isn't calling a Module then we can ignore it.
172-
continue
173-
# For call sites, `Node.target` represents the module/function/method
174-
# that's being called. Here, we check `Node.target` to see if it's a
175-
# batch norm module, and then check `Node.args[0].target` to see if the
176-
# input `Node` is a convolution.
177-
if type(modules[node.target]) is nn.BatchNorm2d and type(modules[node.args[0].target]) is nn.Conv2d:
178-
if len(node.args[0].users) > 1: # Output of conv is used by other nodes
179-
continue
180-
conv = modules[node.args[0].target]
181-
bn = modules[node.target]
182-
fused_conv = fuse_conv_bn_eval(conv, bn)
183-
replace_node_module(node.args[0], modules, fused_conv)
184-
# As we've folded the batch nor into the conv, we need to replace all uses
185-
# of the batch norm with the conv.
186-
node.replace_all_uses_with(node.args[0])
187-
# Now that all uses of the batch norm have been replaced, we can
188-
# safely remove the batch norm.
189-
fx_model.graph.erase_node(node)
190-
fx_model.graph.lint()
191-
# After we've modified our graph, we need to recompile our graph in order
192-
# to keep the generated code in sync.
193-
fx_model.recompile()
194-
return fx_model
130+
# Pattern Matching with torch.compile
131+
# ------------------------------------
132+
# Now that we have our fusion logic, we need to register a pattern that
133+
# torch.compile's pattern matcher will recognize and replace during
134+
# compilation.
135+
136+
# Define the pattern we want to match: conv2d followed by batch_norm
137+
def conv_bn_pattern(x, conv_weight, conv_bias, bn_mean, bn_var, bn_weight, bn_bias):
138+
conv_out = torch.nn.functional.conv2d(x, conv_weight, conv_bias)
139+
bn_out = torch.nn.functional.batch_norm(
140+
conv_out, bn_mean, bn_var, bn_weight, bn_bias,
141+
training=False, eps=1e-5
142+
)
143+
return bn_out
144+
145+
def conv_bn_replacement(x, conv_weight, conv_bias, bn_mean, bn_var, bn_weight, bn_bias):
146+
fused_weight, fused_bias = fuse_conv_bn_weights(
147+
conv_weight, conv_bias, bn_mean, bn_var, 1e-5, bn_weight, bn_bias
148+
)
149+
return torch.nn.functional.conv2d(x, fused_weight, fused_bias)
150+
151+
# Example inputs are needed to trace the pattern functions.
152+
# The inputs should match the function signatures of conv_bn_pattern and conv_bn_replacement.
153+
# These are used to trace the pattern functions to create the match template.
154+
# IMPORTANT: The pattern matcher is shape-agnostic! The specific shapes you use here
155+
# don't limit what shapes will be matched - any valid conv2d->batch_norm sequence
156+
# will be matched regardless of channels, kernel size, or spatial dimensions.
157+
# - x: input tensor (batch_size, channels, height, width)
158+
# - conv_weight: (out_channels, in_channels, kernel_h, kernel_w)
159+
# - conv_bias: (out_channels,)
160+
# - bn_mean, bn_var, bn_weight, bn_bias: all have shape (num_features,) matching out_channels
161+
example_inputs = [
162+
torch.randn(1, 1, 4, 4).to(device), # x: input tensor
163+
torch.randn(1, 1, 1, 1).to(device), # conv_weight: 1 output channel, 1 input channel, 1x1 kernel
164+
torch.randn(1).to(device), # conv_bias: 1 output channel
165+
torch.randn(1).to(device), # bn_mean: batch norm running mean
166+
torch.randn(1).to(device), # bn_var: batch norm running variance
167+
torch.randn(1).to(device), # bn_weight: batch norm weight (gamma)
168+
torch.randn(1).to(device), # bn_bias: batch norm bias (beta)
169+
]
170+
171+
from torch._inductor.pattern_matcher import PatternMatcherPass
172+
from torch._inductor import config
173+
174+
# Create a pattern matcher pass and register our pattern
175+
patterns = PatternMatcherPass()
176+
177+
register_replacement(
178+
conv_bn_pattern,
179+
conv_bn_replacement,
180+
example_inputs,
181+
pm.fwd_only,
182+
patterns,
183+
)
184+
185+
# Create a custom pass function that applies our patterns
186+
def conv_bn_fusion_pass(graph):
187+
return patterns.apply(graph)
188+
189+
# Set our custom pass in the config
190+
config.post_grad_custom_post_pass = conv_bn_fusion_pass
195191

196192

197193
######################################################################
198194
# .. note::
199195
# We make some simplifications here for demonstration purposes, such as only
200-
# matching 2D convolutions. View
201-
# https://github.com/pytorch/pytorch/blob/master/torch/fx/experimental/fuser.py
202-
# for a more usable pass.
196+
# matching 2D convolutions. The pattern matcher in torch.compile
197+
# can handle more complex patterns.
203198

204199
######################################################################
205200
# Testing out our Fusion Pass
@@ -208,11 +203,43 @@ def fuse(model: torch.nn.Module) -> torch.nn.Module:
208203
# results are identical. In addition, we can print out the code for our fused
209204
# model and verify that there are no more batch norms.
210205

206+
from torch._dynamo.utils import counters
207+
208+
# Clear the counters before compilation
209+
counters.clear()
210+
211+
# Ensure pattern matcher is enabled
212+
config.pattern_matcher = True
211213

212-
fused_model = fuse(model)
213-
print(fused_model.code)
214-
inp = torch.randn(5, 1, 1, 1)
215-
torch.testing.assert_allclose(fused_model(inp), model(inp))
214+
fused_model = torch.compile(model, backend="inductor")
215+
inp = torch.randn(5, 1, 1, 1).to(device)
216+
217+
# Run the model to trigger compilation and pattern matching
218+
with torch.no_grad():
219+
output = fused_model(inp)
220+
expected = model(inp)
221+
torch.testing.assert_close(output, expected)
222+
223+
# Check how many patterns were matched
224+
assert counters['inductor']['pattern_matcher_count'] == 3, "Expected 3 conv-bn patterns to be matched"
225+
226+
# Create a model with different shapes than our example_inputs
227+
test_model_diff_shape = nn.Sequential(
228+
nn.Conv2d(3, 16, 5),
229+
nn.BatchNorm2d(16),
230+
nn.ReLU(),
231+
nn.Conv2d(16, 32, 7),
232+
nn.BatchNorm2d(32),
233+
).to(device).eval()
234+
235+
counters.clear()
236+
compiled_diff_shape = torch.compile(test_model_diff_shape, backend="inductor")
237+
test_input_diff_shape = torch.randn(1, 3, 28, 28).to(device)
238+
with torch.no_grad():
239+
compiled_diff_shape(test_input_diff_shape)
240+
241+
# Check how many patterns were matched
242+
assert counters['inductor']['pattern_matcher_count'] == 2, "Expected 2 conv-bn patterns to be matched"
216243

217244

218245
######################################################################
@@ -223,40 +250,38 @@ def fuse(model: torch.nn.Module) -> torch.nn.Module:
223250
import torchvision.models as models
224251
import time
225252

226-
rn18 = models.resnet18()
253+
rn18 = models.resnet18().to(device)
227254
rn18.eval()
228255

229-
inp = torch.randn(10, 3, 224, 224)
256+
inp = torch.randn(10, 3, 224, 224).to(device)
230257
output = rn18(inp)
231258

232259
def benchmark(model, iters=20):
233-
for _ in range(10):
234-
model(inp)
235-
begin = time.time()
236-
for _ in range(iters):
237-
model(inp)
238-
return str(time.time()-begin)
239-
240-
fused_rn18 = fuse(rn18)
241-
print("Unfused time: ", benchmark(rn18))
242-
print("Fused time: ", benchmark(fused_rn18))
243-
######################################################################
244-
# As we previously saw, the output of our FX transformation is
245-
# ("torchscriptable") PyTorch code, we can easily ``jit.script`` the output to try
246-
# and increase our performance even more. In this way, our FX model
247-
# transformation composes with TorchScript with no issues.
248-
jit_rn18 = torch.jit.script(fused_rn18)
249-
print("jit time: ", benchmark(jit_rn18))
260+
with torch.no_grad():
261+
for _ in range(10):
262+
model(inp)
263+
begin = time.time()
264+
for _ in range(iters):
265+
model(inp)
266+
return str(time.time()-begin)
267+
268+
# Benchmark original model
269+
print("Original model time: ", benchmark(rn18))
270+
271+
# Compile with our custom pattern
272+
compiled_with_pattern_matching = torch.compile(rn18, backend="inductor")
273+
274+
# Benchmark compiled model
275+
print("\ntorch.compile (with conv-bn pattern matching and other fusions): ", benchmark(compiled_with_pattern_matching))
250276

251277

252278
############
253279
# Conclusion
254280
# ----------
255-
# As we can see, using FX we can easily write static graph transformations on
256-
# PyTorch code.
281+
# As we can see, torch.compile provides a powerful way to implement
282+
# graph transformations and optimizations through pattern matching.
283+
# By registering custom patterns, we can extend torch.compile's
284+
# optimization capabilities to handle domain-specific transformations.
257285
#
258-
# Since FX is still in beta, we would be happy to hear any
259-
# feedback you have about using it. Please feel free to use the
260-
# PyTorch Forums (https://discuss.pytorch.org/) and the issue tracker
261-
# (https://github.com/pytorch/pytorch/issues) to provide any feedback
262-
# you might have.
286+
# The conv-bn fusion demonstrated here is just one example of what's
287+
# possible with torch.compile's pattern matching system.
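Because this commit moves the model and inputs to `device`, the `benchmark` function may now time CUDA work. CUDA kernels are launched asynchronously, so wall-clock measurements are generally only meaningful if the device is synchronized before the clock is read. Below is a minimal sketch of that idea in plain Python. `benchmark_seconds` and its `sync` parameter are hypothetical; with PyTorch on a GPU one would likely pass `torch.cuda.synchronize` as `sync`. The warm-up iterations also absorb the one-time compilation cost of the first call to a compiled model.

```python
import time
from typing import Callable, Optional


def benchmark_seconds(
    fn: Callable[[], object],
    iters: int = 20,
    warmup: int = 10,
    sync: Optional[Callable[[], None]] = None,
) -> float:
    """Time `iters` calls of `fn`, synchronizing before each clock read.

    `sync` is a hypothetical hook for a device barrier; it is skipped on
    CPU-only runs by leaving it as None.
    """
    def _sync() -> None:
        if sync is not None:
            sync()

    for _ in range(warmup):
        fn()
    _sync()  # make sure warm-up work has finished before starting the clock
    begin = time.perf_counter()
    for _ in range(iters):
        fn()
    _sync()  # make sure timed work has finished before stopping the clock
    return time.perf_counter() - begin


# Usage with a stand-in workload; the list records the order of events.
calls = []
elapsed = benchmark_seconds(
    lambda: calls.append(1),
    iters=5,
    warmup=2,
    sync=lambda: calls.append("sync"),
)
print(f"elapsed: {elapsed:.6f}s")
```

`time.perf_counter` is used here instead of `time.time` because it is a monotonic, high-resolution clock intended for measuring intervals.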
