
Commit bd019c0

leslie-fang-intel authored and pytorchmergebot committed
[Inductor][CPP] Fix node name for wgt delete (pytorch#147056)
**Summary**
This is a regression caused by a change in the FX node name. In commit 71010bf, both the name and the target of the `get_attr` node in `V.graph.graph.nodes` were `_frozen_param2`. On the latest main, the node name has changed to `_reorder_linear_weight`, so looking the attribute up by the node's name no longer works. This PR fixes the regression by using the node's target instead of its name.

**Test Plan**
```
python -u -m pytest -s -v test/inductor/test_cpu_select_algorithm.py -k test_cpp_weight_prune
```

Pull Request resolved: pytorch#147056
Approved by: https://github.com/jgong5
1 parent 10bc8f2 · commit bd019c0
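For context, the distinction the fix relies on is that an FX `get_attr` node's `target` is the attribute path on the owning module, while its `name` is only the node's graph-local identifier and may be rewritten by later passes (as happened here, where it became `_reorder_linear_weight`). The following is a minimal sketch, not part of this PR; the module `M` and the parameter name are illustrative. It shows that `node.target` stays a valid key for `hasattr`/`getattr` on the traced module regardless of how the node itself is named:

```python
import torch
import torch.fx as fx


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Illustrative parameter; the name echoes the _frozen_param2 mentioned above.
        self._frozen_param2 = torch.nn.Parameter(torch.randn(4, 4))

    def forward(self, x):
        return x + self._frozen_param2


gm = fx.symbolic_trace(M())
for node in gm.graph.nodes:
    if node.op == "get_attr":
        # node.target is the attribute path on the owning module, so it remains a
        # valid lookup key even if a pass renames the node (changing node.name).
        assert hasattr(gm, node.target)
        print(node.name, node.target, getattr(gm, node.target).shape)
```

Accordingly, the `cpp_gemm_template.py` change below swaps `node.name` for `node.target` in the attribute lookups and deletions, and adds a `select_algorithm_weight_prune` counter so the new test can assert that the weight was actually pruned.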

File tree

2 files changed (+31, −6):

- test/inductor/test_cpu_select_algorithm.py
- torch/_inductor/codegen/cpp_gemm_template.py

test/inductor/test_cpu_select_algorithm.py

Lines changed: 23 additions & 0 deletions
```diff
@@ -2051,6 +2051,29 @@ def forward(self, x):
             self.common(mod, (v,), atol=atol, rtol=rtol)
         self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

+    @inductor_config.patch({"freezing": True})
+    @patches
+    @torch.no_grad
+    @unittest.skipIf(not TEST_MKL, "Test requires MKL")
+    def test_cpp_weight_prune(self):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(32, 128, bias=False)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        v = torch.randn(2, 32).to(torch.bfloat16)
+        mod = M().eval().to(torch.bfloat16)
+        torch._dynamo.reset()
+        torch._inductor.metrics.reset()
+        counters.clear()
+        with verify(torch.bfloat16) as (atol, rtol):
+            self.common(mod, (v,), atol=atol, rtol=rtol)
+        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
+        self.assertEqual(counters["inductor"]["select_algorithm_weight_prune"], 1)
+
     @patches
     @torch.no_grad
     @unittest.skipIf(not TEST_MKL, "Test requires MKL")
```

torch/_inductor/codegen/cpp_gemm_template.py

Lines changed: 8 additions & 6 deletions
```diff
@@ -383,9 +383,9 @@ def get_candidates(input_nodes, new_input_nodes):
             # Case may happen when the candidate tensor is used by more than 1 get_attr node
             # https://github.com/pytorch/pytorch/issues/134998
             if node.op == "get_attr" and hasattr(
-                V.graph.module, node.name
+                V.graph.module, node.target
             ):  # candidate tensor might already be deleted
-                comp_tensor = getattr(V.graph.module, node.name)
+                comp_tensor = getattr(V.graph.module, node.target)
                 if isinstance(comp_tensor, torch.Tensor) and share_storage(
                     candidate_tensor, comp_tensor
                 ):
@@ -395,13 +395,15 @@ def get_candidates(input_nodes, new_input_nodes):
             # The get_attr node has only 1 user fx node
             # The candidate tensor has been used by only 1 get_attr node
             if (
-                node.name == candidate_node.get_name()
+                node.op == "get_attr"
+                and node.target == candidate_node.get_name()
                 and len(node.users) == 1
                 and candidate_tensor_users == 1
             ):
-                del V.graph.constants[node.name]
-                delattr(V.graph.module, node.name)
-                delattr(V.graph.graph.owning_module, node.name)
+                del V.graph.constants[node.target]
+                delattr(V.graph.module, node.target)
+                delattr(V.graph.graph.owning_module, node.target)
+                counters["inductor"]["select_algorithm_weight_prune"] += 1


 def gen_2d_view_of_epilogue_buf(
```
