1
1
# -*- coding: utf-8 -*-
2
2
"""
3
- (beta) Building a Convolution/Batch Norm fuser in FX
4
- *******************************************************
5
- **Author**: `Horace He <https://github.com/chillee>`_
3
+ Building a Convolution/Batch Norm fuser with torch.compile
4
+ ******************************************************************
5
+ **Author**: `Horace He <https://github.com/chillee>`__, `Will Feng <https://github.com/yf225>`__
6
6
7
- In this tutorial, we are going to use FX, a toolkit for composable function
8
- transformations of PyTorch, to do the following:
7
+ In this tutorial, we are going to use torch.compile and its pattern matching
8
+ capabilities to do the following:
9
9
10
10
1) Find patterns of conv/batch norm in the data dependencies.
11
11
2) For the patterns found in 1), fold the batch norm statistics into the convolution weights.
12
12
13
- Note that this optimization only works for models in inference mode (i.e. `mode.eval()`)
13
+ Note that this specific optimization only works for models in inference mode (i.e. `model.eval()`).
14
+ But the pattern matching system in torch.compile works for both training and inference.
14
15
15
- We will be building the fuser that exists here:
16
- https://github.com/pytorch/pytorch/blob/orig/release/1.8/torch/fx/experimental/fuser.py
16
+ We will demonstrate how to register custom fusion patterns with torch.compile's
17
+ pattern matcher to optimize model performance.
17
18
18
19
"""
19
20
24
25
25
26
from typing import Type , Dict , Any , Tuple , Iterable
26
27
import copy
27
- import torch .fx as fx
28
28
import torch
29
29
import torch .nn as nn
30
30
31
+ device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
32
+
31
33
######################################################################
32
34
# For this tutorial, we are going to create a model consisting of convolutions
33
35
# and batch norms. Note that this model has some tricky components - some of
@@ -61,29 +63,26 @@ def forward(self, x):
61
63
x = self .wrapped (x )
62
64
return x
63
65
64
- model = M ()
65
-
66
+ model = M ().to (device )
66
67
model .eval ()
67
68
68
69
######################################################################
69
70
# Fusing Convolution with Batch Norm
70
71
# -----------------------------------------
71
72
# One of the primary challenges with trying to automatically fuse convolution
72
73
# and batch norm in PyTorch is that PyTorch does not provide an easy way of
73
- # accessing the computational graph. FX resolves this problem by symbolically
74
- # tracing the actual operations called, so that we can track the computations
75
- # through the `forward` call, nested within Sequential modules, or wrapped in
76
- # an user-defined module.
77
-
78
- traced_model = torch .fx .symbolic_trace (model )
79
- print (traced_model .graph )
74
+ # accessing the computational graph. torch.compile resolves this problem by
75
+ # capturing the computational graph during compilation, allowing us to apply
76
+ # pattern-based optimizations across the entire model, including operations
77
+ # nested within Sequential modules or wrapped in custom modules.
78
+ import torch ._inductor .pattern_matcher as pm
79
+ from torch ._inductor .pattern_matcher import register_replacement
80
80
81
81
######################################################################
82
- # This gives us a graph representation of our model. Note that both the modules
83
- # hidden within the sequential as well as the wrapped Module have been inlined
84
- # into the graph. This is the default level of abstraction, but it can be
85
- # configured by the pass writer. More information can be found at the FX
86
- # overview https://pytorch.org/docs/master/fx.html#module-torch.fx
82
+ # torch.compile will capture a graph representation of our model. During
83
+ # compilation, modules hidden within Sequential containers and wrapped
84
+ # modules are all inlined into the graph, making them available for
85
+ # pattern matching and optimization.
87
86
88
87
89
88
####################################
@@ -128,78 +127,74 @@ def fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b):
128
127
129
128
130
129
####################################
131
- # FX Fusion Pass
132
- # ----------------------------------
133
- # Now that we have our computational graph as well as a method for fusing
134
- # convolution and batch norm, all that remains is to iterate over the FX graph
135
- # and apply the desired fusions.
136
-
137
-
138
- def _parent_name (target : str ) -> Tuple [str , str ]:
139
- """
140
- Splits a ``qualname`` into parent path and last atom.
141
- For example, `foo.bar.baz` -> (`foo.bar`, `baz`)
142
- """
143
- * parent , name = target .rsplit ('.' , 1 )
144
- return parent [0 ] if parent else '' , name
145
-
146
- def replace_node_module (node : fx .Node , modules : Dict [str , Any ], new_module : torch .nn .Module ):
147
- assert (isinstance (node .target , str ))
148
- parent_name , name = _parent_name (node .target )
149
- setattr (modules [parent_name ], name , new_module )
150
-
151
-
152
- def fuse (model : torch .nn .Module ) -> torch .nn .Module :
153
- model = copy .deepcopy (model )
154
- # The first step of most FX passes is to symbolically trace our model to
155
- # obtain a `GraphModule`. This is a representation of our original model
156
- # that is functionally identical to our original model, except that we now
157
- # also have a graph representation of our forward pass.
158
- fx_model : fx .GraphModule = fx .symbolic_trace (model )
159
- modules = dict (fx_model .named_modules ())
160
-
161
- # The primary representation for working with FX are the `Graph` and the
162
- # `Node`. Each `GraphModule` has a `Graph` associated with it - this
163
- # `Graph` is also what generates `GraphModule.code`.
164
- # The `Graph` itself is represented as a list of `Node` objects. Thus, to
165
- # iterate through all of the operations in our graph, we iterate over each
166
- # `Node` in our `Graph`.
167
- for node in fx_model .graph .nodes :
168
- # The FX IR contains several types of nodes, which generally represent
169
- # call sites to modules, functions, or methods. The type of node is
170
- # determined by `Node.op`.
171
- if node .op != 'call_module' : # If our current node isn't calling a Module then we can ignore it.
172
- continue
173
- # For call sites, `Node.target` represents the module/function/method
174
- # that's being called. Here, we check `Node.target` to see if it's a
175
- # batch norm module, and then check `Node.args[0].target` to see if the
176
- # input `Node` is a convolution.
177
- if type (modules [node .target ]) is nn .BatchNorm2d and type (modules [node .args [0 ].target ]) is nn .Conv2d :
178
- if len (node .args [0 ].users ) > 1 : # Output of conv is used by other nodes
179
- continue
180
- conv = modules [node .args [0 ].target ]
181
- bn = modules [node .target ]
182
- fused_conv = fuse_conv_bn_eval (conv , bn )
183
- replace_node_module (node .args [0 ], modules , fused_conv )
184
- # As we've folded the batch nor into the conv, we need to replace all uses
185
- # of the batch norm with the conv.
186
- node .replace_all_uses_with (node .args [0 ])
187
- # Now that all uses of the batch norm have been replaced, we can
188
- # safely remove the batch norm.
189
- fx_model .graph .erase_node (node )
190
- fx_model .graph .lint ()
191
- # After we've modified our graph, we need to recompile our graph in order
192
- # to keep the generated code in sync.
193
- fx_model .recompile ()
194
- return fx_model
130
+ # Pattern Matching with torch.compile
131
+ # ------------------------------------
132
+ # Now that we have our fusion logic, we need to register a pattern that
133
+ # torch.compile's pattern matcher will recognize and replace during
134
+ # compilation.
135
+
136
+ # Define the pattern we want to match: conv2d followed by batch_norm
137
+ def conv_bn_pattern (x , conv_weight , conv_bias , bn_mean , bn_var , bn_weight , bn_bias ):
138
+ conv_out = torch .nn .functional .conv2d (x , conv_weight , conv_bias )
139
+ bn_out = torch .nn .functional .batch_norm (
140
+ conv_out , bn_mean , bn_var , bn_weight , bn_bias ,
141
+ training = False , eps = 1e-5
142
+ )
143
+ return bn_out
144
+
145
+ def conv_bn_replacement (x , conv_weight , conv_bias , bn_mean , bn_var , bn_weight , bn_bias ):
146
+ fused_weight , fused_bias = fuse_conv_bn_weights (
147
+ conv_weight , conv_bias , bn_mean , bn_var , 1e-5 , bn_weight , bn_bias
148
+ )
149
+ return torch .nn .functional .conv2d (x , fused_weight , fused_bias )
150
+
151
+ # Example inputs are needed to trace the pattern functions and create the match
152
+ # template. They must be listed in the same order as the parameters of
153
+ # conv_bn_pattern and conv_bn_replacement.
154
+ # IMPORTANT: The pattern matcher is shape-agnostic! The specific shapes you use here
155
+ # don't limit what shapes will be matched - any valid conv2d->batch_norm sequence
156
+ # will be matched regardless of channels, kernel size, or spatial dimensions.
157
+ # - x: input tensor (batch_size, channels, height, width)
158
+ # - conv_weight: (out_channels, in_channels, kernel_h, kernel_w)
159
+ # - conv_bias: (out_channels,)
160
+ # - bn_mean, bn_var, bn_weight, bn_bias: all have shape (num_features,) matching out_channels
161
+ example_inputs = [
162
+ torch .randn (1 , 1 , 4 , 4 ).to (device ), # x: input tensor
163
+ torch .randn (1 , 1 , 1 , 1 ).to (device ), # conv_weight: 1 output channel, 1 input channel, 1x1 kernel
164
+ torch .randn (1 ).to (device ), # conv_bias: 1 output channel
165
+ torch .randn (1 ).to (device ), # bn_mean: batch norm running mean
166
+ torch .randn (1 ).to (device ), # bn_var: batch norm running variance
167
+ torch .randn (1 ).to (device ), # bn_weight: batch norm weight (gamma)
168
+ torch .randn (1 ).to (device ), # bn_bias: batch norm bias (beta)
169
+ ]
170
+
171
+ from torch ._inductor .pattern_matcher import PatternMatcherPass
172
+ from torch ._inductor import config
173
+
174
+ # Create a pattern matcher pass and register our pattern
175
+ patterns = PatternMatcherPass ()
176
+
177
+ register_replacement (
178
+ conv_bn_pattern ,
179
+ conv_bn_replacement ,
180
+ example_inputs ,
181
+ pm .fwd_only ,
182
+ patterns ,
183
+ )
184
+
185
+ # Create a custom pass function that applies our patterns
186
+ def conv_bn_fusion_pass (graph ):
187
+ return patterns .apply (graph )
188
+
189
+ # Set our custom pass in the config
190
+ config .post_grad_custom_post_pass = conv_bn_fusion_pass
195
191
196
192
197
193
######################################################################
198
194
# .. note::
199
195
# We make some simplifications here for demonstration purposes, such as only
200
- # matching 2D convolutions. View
201
- # https://github.com/pytorch/pytorch/blob/master/torch/fx/experimental/fuser.py
202
- # for a more usable pass.
196
+ # matching 2D convolutions. The pattern matcher in torch.compile
197
+ # can handle more complex patterns.
203
198
204
199
######################################################################
205
200
# Testing out our Fusion Pass
@@ -208,11 +203,43 @@ def fuse(model: torch.nn.Module) -> torch.nn.Module:
208
203
# results are identical. In addition, we can inspect the pattern matcher's
209
204
# counters and verify that the expected conv/batch norm patterns were matched.
210
205
206
+ from torch ._dynamo .utils import counters
207
+
208
+ # Clear the counters before compilation
209
+ counters .clear ()
210
+
211
+ # Ensure pattern matcher is enabled
212
+ config .pattern_matcher = True
211
213
212
- fused_model = fuse (model )
213
- print (fused_model .code )
214
- inp = torch .randn (5 , 1 , 1 , 1 )
215
- torch .testing .assert_allclose (fused_model (inp ), model (inp ))
214
+ fused_model = torch .compile (model , backend = "inductor" )
215
+ inp = torch .randn (5 , 1 , 1 , 1 ).to (device )
216
+
217
+ # Run the model to trigger compilation and pattern matching
218
+ with torch .no_grad ():
219
+ output = fused_model (inp )
220
+ expected = model (inp )
221
+ torch .testing .assert_close (output , expected )
222
+
223
+ # Check how many patterns were matched
224
+ assert counters ['inductor' ]['pattern_matcher_count' ] == 3 , "Expected 3 conv-bn patterns to be matched"
225
+
226
+ # Create a model with different shapes than our example_inputs
227
+ test_model_diff_shape = nn .Sequential (
228
+ nn .Conv2d (3 , 16 , 5 ),
229
+ nn .BatchNorm2d (16 ),
230
+ nn .ReLU (),
231
+ nn .Conv2d (16 , 32 , 7 ),
232
+ nn .BatchNorm2d (32 ),
233
+ ).to (device ).eval ()
234
+
235
+ counters .clear ()
236
+ compiled_diff_shape = torch .compile (test_model_diff_shape , backend = "inductor" )
237
+ test_input_diff_shape = torch .randn (1 , 3 , 28 , 28 ).to (device )
238
+ with torch .no_grad ():
239
+ compiled_diff_shape (test_input_diff_shape )
240
+
241
+ # Check how many patterns were matched
242
+ assert counters ['inductor' ]['pattern_matcher_count' ] == 2 , "Expected 2 conv-bn patterns to be matched"
216
243
217
244
218
245
######################################################################
@@ -223,40 +250,38 @@ def fuse(model: torch.nn.Module) -> torch.nn.Module:
223
250
import torchvision .models as models
224
251
import time
225
252
226
- rn18 = models .resnet18 ()
253
+ rn18 = models .resnet18 (). to ( device )
227
254
rn18 .eval ()
228
255
229
- inp = torch .randn (10 , 3 , 224 , 224 )
256
+ inp = torch .randn (10 , 3 , 224 , 224 ). to ( device )
230
257
output = rn18 (inp )
231
258
232
259
def benchmark (model , iters = 20 ):
233
- for _ in range (10 ):
234
- model (inp )
235
- begin = time .time ()
236
- for _ in range (iters ):
237
- model (inp )
238
- return str (time .time ()- begin )
239
-
240
- fused_rn18 = fuse (rn18 )
241
- print ("Unfused time: " , benchmark (rn18 ))
242
- print ("Fused time: " , benchmark (fused_rn18 ))
243
- ######################################################################
244
- # As we previously saw, the output of our FX transformation is
245
- # ("torchscriptable") PyTorch code, we can easily ``jit.script`` the output to try
246
- # and increase our performance even more. In this way, our FX model
247
- # transformation composes with TorchScript with no issues.
248
- jit_rn18 = torch .jit .script (fused_rn18 )
249
- print ("jit time: " , benchmark (jit_rn18 ))
260
+ with torch .no_grad ():
261
+ for _ in range (10 ):
262
+ model (inp )
263
+ begin = time .time ()
264
+ for _ in range (iters ):
265
+ model (inp )
266
+ return str (time .time ()- begin )
267
+
268
+ # Benchmark original model
269
+ print ("Original model time: " , benchmark (rn18 ))
270
+
271
+ # Compile with our custom pattern
272
+ compiled_with_pattern_matching = torch .compile (rn18 , backend = "inductor" )
273
+
274
+ # Benchmark compiled model
275
+ print ("\n torch.compile (with conv-bn pattern matching and other fusions): " , benchmark (compiled_with_pattern_matching ))
250
276
251
277
252
278
############
253
279
# Conclusion
254
280
# ----------
255
- # As we can see, using FX we can easily write static graph transformations on
256
- # PyTorch code.
281
+ # As we can see, torch.compile provides a powerful way to implement
282
+ # graph transformations and optimizations through pattern matching.
283
+ # By registering custom patterns, we can extend torch.compile's
284
+ # optimization capabilities to handle domain-specific transformations.
257
285
#
258
- # Since FX is still in beta, we would be happy to hear any
259
- # feedback you have about using it. Please feel free to use the
260
- # PyTorch Forums (https://discuss.pytorch.org/) and the issue tracker
261
- # (https://github.com/pytorch/pytorch/issues) to provide any feedback
262
- # you might have.
286
+ # The conv-bn fusion demonstrated here is just one example of what's
287
+ # possible with torch.compile's pattern matching system.
0 commit comments