Skip to content

Commit 640703d

Browse files
mengluy0125pytorchmergebot
authored andcommitted
add torch.concat to normalization pass (pytorch#156574)
Summary: In the normalization pass, we also add torch.concat to it to normalize it as torch.cat Test Plan: ``` buck2 test 'fbcode//mode/dev-nosan' fbcode//caffe2/test/inductor:split_cat_fx_passes -- test_cat_normalization ``` Buck UI: https://www.internalfb.com/buck2/597fd4f1-0aa7-4372-8a66-5a690d9b63a4 Test UI: https://www.internalfb.com/intern/testinfra/testrun/1688850152284203 Network: Up: 84KiB Down: 34KiB (reSessionID-3916e009-7117-41ce-b6f9-089873aa50dd) Executing actions. Remaining 0/3 1.1s exec time total Command: test. Finished 2 local Time elapsed: 3:47.1s Tests finished: Pass 2. Fail 0. Fatal 0. Skip 0. Build failure 0 Rollback Plan: Differential Revision: D77125331 Pull Request resolved: pytorch#156574 Approved by: https://github.com/Mingming-Ding
1 parent 1155c53 commit 640703d

File tree

2 files changed

+29
-1
lines changed

2 files changed

+29
-1
lines changed

test/inductor/test_split_cat_fx_passes.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,33 @@ def normalize_reshape_with_dynamic_shape(x):
115115
)
116116
counters.clear()
117117

118+
@torch._inductor.config.patch(
119+
pre_grad_fusion_options={
120+
"normalization_pass": {},
121+
},
122+
post_grad_fusion_options={},
123+
)
124+
def test_cat_normalization(self):
125+
def caoncat_only(x):
126+
return torch.concat(list(torch.split(x, 2, 1)), dim=1)
127+
128+
args = [
129+
torch.randn(2, 32),
130+
]
131+
for fn, dynamic, expected_cat_norm_count in [
132+
(caoncat_only, False, 2),
133+
]:
134+
expected = fn(*args)
135+
actual = torch.compile(fn, dynamic=dynamic)(*args)
136+
137+
torch.testing.assert_close(actual, expected)
138+
self.assertEqual(
139+
counters["inductor"]["normalization_pass"],
140+
expected_cat_norm_count,
141+
msg=f"for {fn}",
142+
)
143+
counters.clear()
144+
118145
@patch
119146
def test_consecutive_split_merge(self):
120147
def multi_split(x):

torch/_inductor/fx_passes/split_cat.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ def normalize_unbind_default(match: Match, *args, **kwargs):
302302

303303

304304
@register_graph_pattern(
305-
CallFunctionVarArgs(torch.cat, users=MULTIPLE),
305+
CallFunctionVarArgs([torch.cat, torch.concat], users=MULTIPLE),
306306
pass_dict=construct_pattern_matcher_pass("normalization_pass"),
307307
)
308308
def normalize_cat_default(match: Match, *args, **kwargs):
@@ -347,6 +347,7 @@ def is_empty_tensor(x):
347347
cat_node.args == new_args
348348
and cat_node.kwargs == new_kwargs
349349
and cat_node.op == "call_function"
350+
and cat_node.target == torch.cat
350351
):
351352
return
352353

0 commit comments

Comments
 (0)