Skip to content

Commit b705f17

Browse files
yushangdizou3519
authored and committed
[functorch] Implement a Common Subexpression Elimination (CSE) pass in AOTAutograd (pytorch/functorch#852)
Implement a Common Subexpression Elimination (CSE) pass in AOTAutograd. The test is in test/test_memory_efficient_fusion.py.
1 parent 62812dc commit b705f17

File tree

4 files changed

+354
-0
lines changed

4 files changed

+354
-0
lines changed

functorch/benchmarks/cse.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
import torch
2+
import torch.fx as fx
3+
from functorch import make_fx
4+
from torch.profiler import profile, ProfilerActivity
5+
6+
from functorch._src.cse import fx_graph_cse
7+
8+
def profile_it(f, inp):
    """Return the average profiled CUDA time of one ``f(inp)`` call.

    ``f`` is first called five times outside the profiler as a warm-up. It is then
    called five more times under ``torch.profiler.profile`` with only CUDA activity
    recorded. The ``cuda_time_total`` of every aggregated profiler event is summed
    and divided by the number of profiled calls.
    """
    warmup_calls = 5
    profiled_calls = 5

    # Warm-up calls run outside the profiler so they do not contribute to the timing.
    for _ in range(warmup_calls):
        f(inp)

    with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof:
        for _ in range(profiled_calls):
            f(inp)

    total_cuda_time = sum(event.cuda_time_total for event in prof.key_averages())
    return total_cuda_time / profiled_calls
22+
23+
def profile_function(name, f, inp):
    """Benchmark ``f`` with and without CSE and print one comma-separated result row.

    ``f`` is traced with ``make_fx`` on ``inp``. ``fx_graph_cse`` is applied to the
    traced graph and the result is wrapped in a new ``fx.GraphModule``. Both modules
    are timed with ``profile_it``.

    The printed row contains, in order: ``name``, the average CUDA time of the traced
    module, the average CUDA time of the CSE'd module, the number of nodes removed by
    CSE, and the node count of the traced graph.
    """
    traced = make_fx(f)(inp)
    cse_module = fx.GraphModule(traced, fx_graph_cse(traced.graph))

    # do not benchmark against the scripted version because script already does some CSE
    # script_f = torch.jit.script(fx_g)
    # script_g = torch.jit.script(new_g)
    # avg_cuda_time_f = profile_it(script_f, inp)
    # avg_cuda_time_g = profile_it(script_g, inp)
    time_before = profile_it(traced, inp)
    time_after = profile_it(cse_module, inp)

    nodes_removed = len(traced.graph.nodes) - len(cse_module.graph.nodes)

    print(f"{name}, {time_before}, {time_after}, {nodes_removed}, {len(traced.graph.nodes)}")
38+
39+
g_gpu = torch.Generator(device='cuda')
40+
g_gpu.manual_seed(2147483647)
41+
inp = torch.randn(2**20, device='cuda', generator=g_gpu)
42+
43+
def f1(x):
    """Benchmark case: cosine applied twice in a chain."""
    inner = x.cos()
    return inner.cos()
45+
46+
profile_function("f1", f1, inp)
47+
48+
def fsum(x):
    """Benchmark case: four identical ``x.sum()`` calls added together.

    The repeated calls are intentional; they are the redundancy the benchmark measures.
    """
    first = x.sum()
    second = x.sum()
    third = x.sum()
    fourth = x.sum()
    # Accumulate left to right, matching ((first + second) + third) + fourth.
    total = first + second
    total = total + third
    return total + fourth
54+
55+
profile_function("fsum", fsum, inp)
56+
57+
def fconcat(x):
    """Benchmark case: two identical ``torch.cat((x, x))`` calls added together.

    The duplicated concatenation is intentional; it is the redundancy being measured.
    """
    left = torch.cat((x, x))
    right = torch.cat((x, x))
    result = left + right
    return result
61+
profile_function("fconcat", fconcat, inp)
62+
63+
def fsum2(x):
    """Benchmark case: ``x.sum()`` plus thirty further ``x.sum()`` terms.

    Every term is recomputed on purpose so the traced function contains repeated sums.
    """
    running = x.sum()
    extra_terms = 30
    for _ in range(extra_terms):
        # Out-of-place add: keeps the same operations as the plain ``a = a + x.sum()`` form.
        running = running + x.sum()
    return running
68+
69+
profile_function("fsum2", fsum2, inp)
70+
71+
def fsummulti(x):
    """Benchmark case: three rounds of "add ``x.sum()``, then multiply by ``x.sum()``".

    The accumulator starts as the Python integer 0. ``x.sum()`` is recomputed for every
    use on purpose.
    """
    acc = 0
    rounds = 3
    for _ in range(rounds):
        acc = acc + x.sum()
        acc = acc * x.sum()
    return acc
77+
78+
profile_function("fsummulti", fsummulti, inp)
79+
80+
def fsummulti2(x):
    """Benchmark case: thirty rounds of "add ``x.sum()``, then multiply by ``x.sum()``".

    Same shape as the three-round variant but with a longer chain. The accumulator
    starts as the Python integer 0 and ``x.sum()`` is recomputed for every use.
    """
    acc = 0
    rounds = 30
    for _ in range(rounds):
        acc = acc + x.sum()
        acc = acc * x.sum()
    return acc
86+
87+
profile_function("fsummulti2", fsummulti2, inp)
88+
89+
def fcos(x):
    """Benchmark case: add ``x.cos()`` to an accumulator three times.

    The accumulator starts as the Python integer 0. ``x.cos()`` is recomputed each time
    on purpose.
    """
    acc = 0
    repeats = 3
    for _ in range(repeats):
        acc = acc + x.cos()
    return acc
94+
95+
profile_function("fcos", fcos, inp)
96+
97+
def fcos2(x):
    """Benchmark case: add ``x.cos()`` to an accumulator thirty times.

    Same shape as the three-repeat variant but with a longer chain. The accumulator
    starts as the Python integer 0 and ``x.cos()`` is recomputed each time.
    """
    acc = 0
    repeats = 30
    for _ in range(repeats):
        acc = acc + x.cos()
    return acc
102+
103+
profile_function("fcos2", fcos2, inp)

functorch/functorch/_src/cse.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import torch
2+
import torch.fx as fx
3+
from torch.utils._pytree import tree_flatten
4+
5+
aten = torch.ops.aten
# Operators that the CSE pass must never merge ("do not CSE away random operations").
# Entries are overload packets, e.g. ``aten.rand_like``.
rand_ops = [aten.dropout, aten._fused_dropout, aten._standard_gamma,
            aten.bernoulli, aten.multinomial, aten.native_dropout,
            aten.normal, aten.poisson, aten.binomial, aten.rrelu,
            aten.rand_like, aten.rand, aten.randint, aten.randn, aten.randperm]


def _get_aten_target(node):
    """Return ``node.target``, normalized to its overload packet when it has one.

    ``rand_ops`` lists overload packets, while a node may carry a specific overload
    (e.g. ``aten.rand_like.default``) that exposes its packet via an ``overloadpacket``
    attribute. Targets without that attribute are returned unchanged.
    """
    target = node.target
    return getattr(target, "overloadpacket", target)


def fx_graph_cse(fx_g: torch.fx.graph.Graph):
    """Return a new ``torch.fx`` graph with common subexpression elimination applied.

    The input graph is not modified. Nodes are visited in graph order and copied into
    a fresh graph. A node is dropped, and its users are redirected to an earlier copy,
    when an earlier node has the same target and the same (already substituted) args
    and kwargs.

    Placeholder, output and ``get_attr`` nodes are always copied unchanged. So are
    nodes whose target is one of ``rand_ops``, whether the target is the overload
    packet itself or one of its overloads.

    Args:
        fx_g: the graph to process.

    Returns:
        A new ``fx.Graph`` containing the deduplicated nodes.
    """
    new_graph = fx.Graph()
    env = {}  # map from node in the old graph to node in the new graph
    hash_env = {}  # map from hash to a node in the new graph
    token_map = {}  # map from hash to token

    def substitute(arg_list):
        # Replace every old-graph node in args/kwargs by its mapping in env, if it has one.
        # The returned spec can be used to reconstruct nested lists/dictionaries.
        flat_args, spec = tree_flatten(arg_list)
        flat_args = [
            env[v] if isinstance(v, torch.fx.node.Node) and v in env else v
            for v in flat_args
        ]
        return tuple(flat_args), spec

    for n in fx_g.nodes:
        # The placeholder, output, and get_attr nodes are copied to the new graph without change.
        # Random operations are never eliminated; compare on the overload packet so that
        # specific overloads such as ``aten.rand_like.default`` are recognised too.
        if n.op in ('placeholder', 'output', 'get_attr') or _get_aten_target(n) in rand_ops:
            env[n] = new_graph.node_copy(n, lambda x: env[x])
            continue

        # n.op == 'call_function', should never see n.op == 'call_module' or 'call_method'
        args, args_spec = substitute(n.args)
        kwargs, kwargs_spec = substitute(n.kwargs)

        # each token corresponds to a unique node
        # nodes with the same token can be substituted
        token = {"target": n.target, "args": args, "args_spec": args_spec,
                 "kwargs": kwargs, "kwargs_spec": kwargs_spec}

        # hash substituted args to a number, do not hash specs because specs are not hashable
        hash_arg = hash((args, kwargs))
        hash_val = (n.target, hash_arg)

        # check if a node has a substitute and can be eliminated; the token comparison
        # guards against two different nodes that merely share a hash value
        hash_val_in_hash_env = hash_val in hash_env
        if hash_val_in_hash_env and token_map[hash_val] == token:
            env[n] = hash_env[hash_val]
            continue

        new_node = new_graph.node_copy(n, lambda x: env[x])
        env[n] = new_node
        # On a hash collision the first node keeps the slot; the colliding node is
        # still copied, it just cannot serve as a substitute later.
        if not hash_val_in_hash_env:
            hash_env[hash_val] = new_node
            token_map[hash_val] = token

    return new_graph

functorch/functorch/_src/partitioners.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import os
88
from torch.fx.passes import graph_drawer
99
from typing import Tuple
10+
from .cse import fx_graph_cse
1011

1112

1213
class InvalidNodeBase(object):
@@ -226,6 +227,10 @@ def min_cut_rematerialization_partition(
226227
except ImportError:
227228
raise RuntimeError("Need networkx installed to perform smart recomputation heuristics")
228229

230+
# add the CSE pass
231+
fx_g = joint_module.graph
232+
cse_graph = fx_graph_cse(fx_g)
233+
joint_module.graph = cse_graph
229234
full_bw_graph = joint_module.graph
230235

231236
name_to_node = {}

functorch/test/test_memory_efficient_fusion.py

Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
import torch
22
import torch.nn as nn
3+
import torch.fx as fx
4+
from functorch import make_fx
35
from torch.nn import functional as F
46
from functorch.compile import memory_efficient_fusion
7+
from functorch._src.cse import fx_graph_cse
58
from torch.testing._internal.common_utils import TestCase, run_tests
69
import inspect
10+
import random
711
from typing import Callable
812
import unittest
913

@@ -179,5 +183,189 @@ def forward(self, hidden_states):
179183
# run_and_compare_activation(hard_mish, 1024)
180184

181185

186+
# Check that the CSE-modified graph of f has `delta` fewer nodes, and that a second pass does not reduce the number of nodes further.
187+
# delta is an integer >= -1. If delta = -1, only check if the new graph
188+
# has less or equal number of nodes
189+
def check(f, t, delta, check_val=True, graph_input=False):
    """Assert that ``fx_graph_cse`` behaves as expected on ``f``.

    Args:
        f: a callable to be traced with ``make_fx`` on ``t``, or an already traced
            graph module when ``graph_input`` is true.
        t: the input passed to the tracer and to both modules when results are compared.
        delta: expected number of nodes removed by CSE. ``-1`` only requires that the
            node count does not increase.
        check_val: when true, also run the original and the CSE'd module on ``t`` and
            require equal results (or ``None`` from both).
        graph_input: when true, ``f`` is used as the traced module directly.

    Raises:
        AssertionError: if the node count, second-pass stability or result check fails.
    """
    original = f if graph_input else make_fx(f)(t)
    cse_graph = fx_graph_cse(original.graph)
    cse_module = fx.GraphModule(original, cse_graph)

    # The node count must drop by exactly `delta`, or merely not grow when delta == -1.
    nodes_before = len(original.graph.nodes)
    nodes_after = len(cse_graph.nodes)
    if delta != -1:
        assert nodes_before == nodes_after + delta, (
            f"number of nodes not the same {nodes_before - delta}, {nodes_after}\n {original.graph} \n {cse_graph}")
    else:
        assert nodes_before >= nodes_after, (
            f"number of nodes increased {nodes_before}, {nodes_after}")

    # Running CSE again on its own output must not remove anything further.
    second_pass_graph = fx_graph_cse(cse_graph)
    nodes_second_pass = len(second_pass_graph.nodes)
    assert nodes_second_pass == nodes_after, (
        f"second pass graph has less node {nodes_second_pass}, {nodes_after}\n {cse_graph} \n {second_pass_graph}")

    if not check_val:
        return

    # Compare the values produced by the original and the CSE'd module.
    expected = original(t)
    actual = cse_module(t)
    if expected is None:
        assert actual is None, f"true result is None, CSE result is {actual}"
        return
    assert torch.all(expected == actual), (
        f"results are different {expected}, {actual}")
222+
223+
224+
class NoChangeTestCase(TestCase):
    """Cases in which CSE is expected to leave the node count unchanged (delta 0)."""

    def test_nochange(self):
        """Rebinding the Python name ``a`` must not lead to any node being removed."""
        def f(x):
            a = x + 1
            lhs = x + a
            a = x
            rhs = x + a
            return lhs + rhs
        check(f, torch.randn(2, 2), 0)

    def test_empty(self):
        """A function that returns nothing has nothing to eliminate."""
        def f(x):
            return None
        check(f, torch.randn(2, 2), 0)

    def test_rand_like(self):
        """Two ``torch.rand_like`` calls must both be kept; values are not compared."""
        def f(x):
            first = torch.rand_like(x)
            second = torch.rand_like(x)
            return first + second
        check(f, torch.randn(2, 2), 0, check_val=False)

    def test_rand_n(self):
        """Two ``torch.randn`` draws from one seeded generator must both be kept."""
        def f(x):
            g_cpu = torch.Generator()
            g_cpu.manual_seed(2147483647)
            first = torch.randn(4, generator=g_cpu)
            second = torch.randn(4, generator=g_cpu)
            return first + second
        check(f, torch.randn(2, 2), 0)
259+
260+
261+
class ReduceTestCase(TestCase):
    """Cases in which CSE is expected to remove a specific number of nodes."""

    def test_immutable_list_type(self):
        """Duplicated ``sum(dim=1)`` and duplicated ``sum()``: expect 2 nodes removed."""
        def f(x):
            row_a = x.sum(dim=1)
            row_b = x.sum(dim=1)
            all_a = x.sum()
            all_b = x.sum()
            return row_a + row_b + all_a + all_b
        check(f, torch.randn(2, 2), 2)

    def test_immutable_list_multiple_entries(self):
        """Duplicated ``sum(dim=[0, 1])`` and duplicated ``sum(dim=1)``: expect 2 nodes removed."""
        def f(x):
            both_a = x.sum(dim=[0, 1])
            both_b = x.sum(dim=[0, 1])
            row_a = x.sum(dim=1)
            row_b = x.sum(dim=1)
            return both_a + both_b + row_a + row_b
        check(f, torch.randn(2, 2), 2)

    def test_simple(self):
        """Duplicated ``cos`` feeding duplicated self-additions: expect 2 nodes removed."""
        def f(x):
            cos_a = x.cos()
            cos_b = x.cos()
            twice_a = cos_a + cos_a
            twice_b = cos_b + cos_b
            return twice_a + twice_b
        check(f, torch.randn(2, 2), 2)

    def test_simple_2(self):
        """Duplicated ``cos().sin()`` chains feeding self-additions: expect 3 nodes removed."""
        def f(x):
            chain_a = x.cos().sin()
            chain_b = x.cos().sin()
            twice_a = chain_a + chain_a
            twice_b = chain_b + chain_b
            return twice_a + twice_b
        check(f, torch.randn(1), 3)

    def test_two_args_default(self):
        """``sum(dim=1)`` with and without an explicit ``keepdim=False``: expect 3 nodes removed."""
        def f(x):
            implicit_a = x.sum(dim=1)
            explicit_a = x.sum(dim=1, keepdim=False)
            explicit_b = x.sum(dim=1, keepdim=False)
            implicit_b = x.sum(dim=1)
            return implicit_a + explicit_a + explicit_b + implicit_b
        check(f, torch.randn(2, 2), 3)

    def test_two_args(self):
        """``sum(dim=1)`` mixed with ``sum(dim=1, keepdim=True)``: expect 2 nodes removed."""
        def f(x):
            flat_a = x.sum(dim=1)
            kept_a = x.sum(dim=1, keepdim=True)
            kept_b = x.sum(dim=1, keepdim=True)
            flat_b = x.sum(dim=1)
            return flat_a + kept_a + kept_b + flat_b
        check(f, torch.randn(2, 2), 2)

    def test_simple_multiple_same_ops(self):
        """Four identical ``sum()`` calls: expect 3 nodes removed."""
        def f(x):
            first = x.sum()
            second = x.sum()
            third = x.sum()
            fourth = x.sum()
            return first + second + third + fourth
        check(f, torch.randn(2, 2), 3)

    def test_nested_immutable_list_type(self):
        """Two identical ``torch.cat((x, x))`` calls: expect 1 node removed."""
        def f(x):
            left = torch.cat((x, x))
            right = torch.cat((x, x))
            return left + right
        check(f, torch.randn(2, 2), 1)

    def test_kwarg(self):
        """Two identical ``torch.ones_like(x)`` calls: expect 1 node removed."""
        def f(x):
            ones_a = torch.ones_like(x)
            ones_b = torch.ones_like(x)
            return ones_a + ones_b
        check(f, torch.randn(2, 2), 1)
348+
349+
350+
class RandomOpTestCase(TestCase):
    """Runs the CSE checks on a graph built from randomly chosen unary ops."""

    def test_random(self):
        """CSE must not increase the node count of a randomly generated graph.

        The function applies 100 randomly chosen ops to randomly chosen earlier values
        and returns the last one. It is traced with ``fx.symbolic_trace`` and dead
        code is removed before the checks run.
        """
        def f(x):
            vals = [x]
            ops = [torch.clone, torch.cos, torch.tanh, torch.nn.functional.gelu]
            for _ in range(100):
                # Pick the op first, then its operand, so the random draws keep their order.
                op = random.choice(ops)
                operand = random.choice(vals)
                vals.append(op(operand))
            return vals[-1]

        traced = fx.symbolic_trace(f)
        traced.graph.eliminate_dead_code()
        traced.recompile()
        sample = torch.randn(2, 2)

        # NOTE(review): the graph is traced once above, so all 30 iterations check the
        # same graph -- confirm whether a fresh random graph per iteration was intended.
        for _ in range(30):
            check(traced, sample, -1, graph_input=True)
367+
368+
369+
182370
if __name__ == "__main__":
183371
run_tests()

0 commit comments

Comments
 (0)