
Commit 7bd34b8

Arm backend: Fuse duplicate user ops (pytorch#15218)
Adds a pass which checks if a node has multiple users performing equivalent operations on its output. If that is the case, it fuses these ops into one.

Signed-off-by: Adrian Lundell <[email protected]>
1 parent d77f045 commit 7bd34b8
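The idea is easiest to see on a plain torch.fx graph. The sketch below is an illustration only, not the backend pass from this commit: the TwoAdds module and the naive (target, args, kwargs) signature key are assumptions made for the example.

import torch
from torch import fx


class TwoAdds(torch.nn.Module):
    def forward(self, x):
        y = torch.relu(x)
        # Two users of y performing the exact same operation.
        return torch.add(y, 1.0) + torch.add(y, 1.0)


gm = fx.symbolic_trace(TwoAdds())

# For each producer, fuse call_function users that share target, args and kwargs.
for producer in list(gm.graph.nodes):
    seen = {}
    for user in list(producer.users):
        if user.op != "call_function":
            continue
        key = (user.target, user.args, tuple(sorted(user.kwargs.items())))
        if key in seen:
            user.replace_all_uses_with(seen[key])  # reroute consumers to the survivor
            gm.graph.erase_node(user)
        else:
            seen[key] = user

gm.recompile()
print(gm.code)  # only one torch.add of (relu, 1.0) remains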

5 files changed: +263 −4 lines changed


backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -76,6 +76,7 @@
 )
 from .fuse_batchnorm2d_pass import FuseBatchnorm2DPass  # noqa
 from .fuse_constant_ops_pass import ComputeConstantOpsAOT, FuseConstantArgsPass  # noqa
+from .fuse_duplicate_users_pass import FuseDuplicateUsersPass  # noqa
 from .fuse_equal_placeholders_pass import FuseEqualPlaceholdersPass  # noqa
 from .fuse_quantized_activation_pass import FuseQuantizedActivationPass  # noqa
 from .insert_int32_casts_after_int64_placeholders import (  # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 4 additions & 0 deletions
@@ -77,6 +77,7 @@
     FoldAndAnnotateQParamsPass,
     FuseBatchnorm2DPass,
     FuseConstantArgsPass,
+    FuseDuplicateUsersPass,
     FuseEqualPlaceholdersPass,
     FuseQuantizedActivationPass,
     InsertInt32CastsAfterInt64PlaceholdersPass,
@@ -176,6 +177,7 @@ def _tosa_INT_pipeline(
         self.add_pass(QuantizeOperatorArguments())
         self.add_pass(ConvertELUParamsPass())
         self.add_pass(FoldAndAnnotateQParamsPass(exported_program))  # type: ignore[call-arg]
+        self.add_pass(FuseDuplicateUsersPass())
         self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
         self.add_pass(MatchArgRanksPass(exported_program))
         if self.tosa_spec.is_U55_subset:
@@ -210,6 +212,7 @@ def _tosa_INT_pipeline(
         self.add_pass(RewriteMatmulPass())
         self.add_pass(RewriteUpsamplePass())
         self.add_pass(FuseEqualPlaceholdersPass(exported_program))
+
         self.add_pass(InsertRescaleInt32Pass())
         self.add_pass(DecomposeSumPass())
         self.add_pass(ToTosaMemoryFormatPass(exported_program))
@@ -223,6 +226,7 @@ def _tosa_FP_pipeline(
         self, exported_program: ExportedProgram, graph_module: GraphModule
     ) -> GraphModule:
         self.add_pass(AnnotateOutputDimOrderPass())
+        self.add_pass(FuseDuplicateUsersPass())
         self.add_pass(DecomposeExpm1Pass())
         self.add_pass(DecomposeLogitPass())
         self.add_pass(DecomposeMaskedFill())
backends/arm/_passes/fuse_duplicate_users_pass.py

Lines changed: 165 additions & 0 deletions
@@ -0,0 +1,165 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from collections import deque
from typing import Any, Deque, Dict, Hashable, List, Set, Tuple, Type

import torch
from executorch.backends.arm._passes.arm_pass import ArmPass
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.pass_base import ExportPass, PassResult
from torch._ops import OpOverload
from torch.fx import GraphModule, Node
from torch.fx.node import Argument, map_arg


class FuseDuplicateUsersPass(ArmPass):
    """Fuse identical users of a producer node into a single operation.

    Example:

        y = producer(x)
        z0 = torch.add(y, bias)
        z1 = torch.add(y, bias)

    becomes a single ``torch.add`` that feeds both consumers.
    """

    _passes_required_after: Set[Type[ExportPass]] = set()

    def call(self, graph_module: GraphModule) -> PassResult:
        graph = graph_module.graph
        modified = False

        producers: Deque[Node] = deque(node for node in graph.nodes)

        while producers:
            producer = producers.popleft()

            if producer.graph is None:
                # Node was deleted by a previous rewrite while still queued.
                continue

            # Only meaningful if a value is consumed by multiple users.
            user_nodes = list(producer.users)
            if len(user_nodes) < 2:
                continue

            candidate_groups = self._get_candidate_groups(user_nodes)

            signature_to_user: Dict[Tuple[Hashable, ...], Node] = {}
            for group in candidate_groups:
                for user in group:
                    signature = self._build_user_signature(user)
                    if signature is None:
                        continue

                    # Check if we already encountered an identical node that we can fuse with.
                    representative = signature_to_user.get(signature)
                    if representative is None:
                        signature_to_user[signature] = user
                        continue

                    if user is representative:
                        # The queue can enqueue the surviving node again after rewrites.
                        continue

                    user.replace_all_uses_with(representative)
                    graph.erase_node(user)
                    modified = True

                    # Revisit the current producer and the surviving user so that
                    # newly formed duplicate chains can be fused in later
                    # iterations.
                    producers.append(producer)
                    producers.append(representative)

        if modified:
            graph_module.recompile()
            graph_module.graph.lint()
            graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, modified)

    def _get_candidate_groups(self, user_nodes):
        users_by_target: Dict[Tuple[str, Hashable], List[Node]] = {}
        for user in user_nodes:
            if user.graph is None:
                # User might already have been removed by a prior rewrite.
                continue

            if user.op != "call_function":
                continue

            target_key = self._get_target_key(user.target)
            target_signature = (user.op, target_key)
            users_by_target.setdefault(target_signature, []).append(user)

        candidate_groups = [
            group for group in users_by_target.values() if len(group) > 1
        ]

        return candidate_groups

    def _build_user_signature(self, node: Node) -> Tuple[Hashable, ...] | None:
        try:
            normalized_args = self._to_hashable(
                map_arg(node.args, self._map_leaf_to_key)
            )
            normalized_kwargs = self._to_hashable(
                {k: map_arg(v, self._map_leaf_to_key) for k, v in node.kwargs.items()}
            )
        except TypeError:
            return None

        target_key = self._get_target_key(node.target)

        return (node.op, target_key, normalized_args, normalized_kwargs)

    def _map_leaf_to_key(self, node: Node) -> Argument:
        return node.name

    def _to_hashable(self, value: Any) -> Hashable:
        """Convert arbitrarily nested structures into hashable tuples."""

        if isinstance(value, (list, tuple)):
            return tuple(self._to_hashable(v) for v in value)
        if isinstance(value, dict):
            normalized_items = [(k, self._to_hashable(v)) for k, v in value.items()]
            return tuple(sorted(normalized_items, key=lambda item: repr(item[0])))
        if isinstance(value, set):
            hashable_values: List[Hashable] = [self._to_hashable(v) for v in value]
            return tuple(sorted(hashable_values, key=repr))
        if isinstance(value, slice):
            return (
                "slice",
                self._to_hashable(value.start),
                self._to_hashable(value.stop),
                self._to_hashable(value.step),
            )
        if isinstance(value, range):
            return ("range", value.start, value.stop, value.step)
        if isinstance(value, torch.Size):
            return ("size", tuple(value))
        if isinstance(value, torch.dtype):
            return ("dtype", str(value))
        if isinstance(value, torch.device):
            return ("device", str(value))
        if isinstance(value, torch.memory_format):
            return ("memory_format", str(value))
        if isinstance(value, torch.Tensor):
            return (
                "tensor",
                str(value.dtype),
                tuple(value.size()),
                value.device.type,
                value.requires_grad,
            )
        return value

    def _get_target_key(self, target: Any) -> Hashable:
        if isinstance(target, (EdgeOpOverload, OpOverload)):
            return str(target)
        return target
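For reference, a hypothetical sketch of running the new pass on its own, outside the ArmPassManager pipelines registered above. Only FuseDuplicateUsersPass comes from this commit; the DoubleSigmoid model and the export/to_edge plumbing are assumptions for illustration, and whether a standalone invocation behaves exactly like the pipelined one depends on the state the graph is in at that point.

import torch
from executorch.backends.arm._passes import FuseDuplicateUsersPass
from executorch.exir import to_edge


class DoubleSigmoid(torch.nn.Module):
    def forward(self, x):
        y = x.relu()
        return y.sigmoid() + y.sigmoid()  # duplicate users of y


exported = torch.export.export(DoubleSigmoid(), (torch.randn(1, 4),))
edge = to_edge(exported)
graph_module = edge.exported_program().graph_module

# PassBase.__call__ runs FuseDuplicateUsersPass.call and wraps the result.
result = FuseDuplicateUsersPass()(graph_module)
print(result.modified)  # True if duplicate users were fused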

backends/arm/test/ops/test_matmul.py

Lines changed: 28 additions & 4 deletions
@@ -134,7 +134,13 @@ def test_matmul_u55_INT(test_data: input_t1):
     pipeline.run()


-@common.parametrize("test_data", MatMulSingleInput.test_data_generators)
+@common.parametrize(
+    "test_data",
+    MatMulSingleInput.test_data_generators,
+    xfails={
+        "rand_4d": "MLBEDSW-11228: Matmul output diff between 1 input vs 2 identical inputs"
+    },
+)
 @common.XfailIfNoCorstone300
 def test_matmul_single_input_u55_INT(test_data: input_t1):
     pipeline = EthosU55PipelineINT[input_t1](
@@ -147,7 +153,13 @@ def test_matmul_single_input_u55_INT(test_data: input_t1):
     pipeline.run()


-@common.parametrize("test_data", MatMulCombo.test_data_generators)
+@common.parametrize(
+    "test_data",
+    MatMulCombo.test_data_generators,
+    xfails={
+        "rand_rand_rand_4d": "MLBEDSW-11228: Matmul output diff between 1 input vs 2 identical inputs"
+    },
+)
 @common.XfailIfNoCorstone300
 def test_matmul_combo_u55_INT(test_data: input_t1):
     pipeline = EthosU55PipelineINT[input_t1](
@@ -173,7 +185,13 @@ def test_matmul_u85_INT(test_data: input_t1):
     pipeline.run()


-@common.parametrize("test_data", MatMulSingleInput.test_data_generators)
+@common.parametrize(
+    "test_data",
+    MatMulSingleInput.test_data_generators,
+    xfails={
+        "rand_4d": "MLBEDSW-11228: Matmul output diff between 1 input vs 2 identical inputs"
+    },
+)
 @common.XfailIfNoCorstone320
 def test_matmul_single_input_u85_INT(test_data: input_t1):
     pipeline = EthosU85PipelineINT[input_t1](
@@ -186,7 +204,13 @@ def test_matmul_single_input_u85_INT(test_data: input_t1):
     pipeline.run()


-@common.parametrize("test_data", MatMulCombo.test_data_generators)
+@common.parametrize(
+    "test_data",
+    MatMulCombo.test_data_generators,
+    xfails={
+        "rand_rand_rand_4d": "MLBEDSW-11228: Matmul output diff between 1 input vs 2 identical inputs"
+    },
+)
 @common.XfailIfNoCorstone320
 def test_matmul_combo_u85_INT(test_data: input_t1):
     pipeline = EthosU85PipelineINT[input_t1](
New test for FuseDuplicateUsersPass

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import torch
from executorch.backends.arm._passes import FuseDuplicateUsersPass
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline

input_t = Tuple[torch.Tensor]  # Input x


class FuseaAvgPool(torch.nn.Module):
    ops_before_pass = {
        "executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 3,
    }
    ops_after_pass = {"executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 1}

    def __init__(self):
        super().__init__()
        self.avg = torch.nn.AvgPool2d(1)

    def forward(self, x):
        return self.avg(x) + self.avg(x) + self.avg(x)


class FuseAvgPoolChain(torch.nn.Module):
    ops_before_pass = {
        "executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 6,
    }
    ops_after_pass = {"executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 2}

    def __init__(self):
        super().__init__()
        self.avg = torch.nn.AvgPool2d(1)

    def forward(self, x):
        first = self.avg(self.avg(x))
        second = self.avg(self.avg(x))
        third = self.avg(self.avg(x))
        return first + second + third


modules = {
    "fuse_avg_pool": FuseaAvgPool(),
    "fuse_avg_pool_chain": FuseAvgPoolChain(),
}


@common.parametrize("module", modules)
def test_fuse_duplicate_ops_FP(module: torch.nn.Module):
    pipeline = PassPipeline[input_t](
        module=module,
        test_data=(torch.ones(1, 1, 1, 1),),
        quantize=False,
        ops_before_pass=module.ops_before_pass,
        ops_after_pass=module.ops_after_pass,
        pass_list=[
            FuseDuplicateUsersPass,
        ],
    )
    pipeline.run()
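The chain case above also shows why the pass re-enqueues both the producer and the surviving user after each fusion: in FuseAvgPoolChain the three inner self.avg(x) calls are duplicate users of the input and collapse into a single node first; only then do the three outer avg-pool calls become identical users of that surviving node and get fused when it is revisited, taking the graph from six avg_pool2d nodes down to the expected two.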
