Skip to content

Commit 05dd638

Browse files
1 parent a3098a7 commit 05dd638

File tree

2 files changed

+0
-49
lines changed

2 files changed

+0
-49
lines changed

test/inductor/test_torchinductor.py

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -13550,40 +13550,6 @@ def test_special_polygamma(self):
1355013550
self.common(fn, (1, x))
1355113551
self.common(fn, (2, x))
1355213552

13553-
@config.patch({"freezing": True})
13554-
def test_dont_constant_fold(self):
13555-
from torch._inductor.constant_folding import (
13556-
add_dont_constant_fold,
13557-
clear_dont_constant_fold,
13558-
)
13559-
13560-
m = 5
13561-
13562-
class M(torch.nn.Module):
13563-
def __init__(self):
13564-
super().__init__()
13565-
self.w = torch.randn(m)
13566-
self.s = torch.randn(m)
13567-
13568-
def forward(self, x):
13569-
return self.w * self.s + x
13570-
13571-
x = torch.rand(m)
13572-
mod = M()
13573-
for dont_constant_fold in [True, False]:
13574-
clear_dont_constant_fold()
13575-
if dont_constant_fold:
13576-
add_dont_constant_fold(torch.ops.aten.mul.Tensor)
13577-
with torch.no_grad():
13578-
refe_out = mod(x)
13579-
mod = torch.compile(mod)
13580-
test_out, (code,) = run_and_get_code(mod, x)
13581-
if dont_constant_fold:
13582-
FileCheck().check("cpp_fused_add_mul").run(code)
13583-
else:
13584-
FileCheck().check("cpp_fused_add_0").run(code)
13585-
self.assertEqual(refe_out, test_out)
13586-
1358713553

1358813554
@dataclasses.dataclass
1358913555
class TestFailure:

torch/_inductor/constant_folding.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,6 @@
1616
MODULE_TAG = "_MAIN_MODULE"
1717
CONST_MODULE_TAG = "_CONST_MODULE"
1818

19-
_dont_constant_fold: list[torch.fx.node.Target] = []
20-
21-
22-
def add_dont_constant_fold(op: torch.fx.node.Target) -> None:
23-
global _dont_constant_fold
24-
_dont_constant_fold.append(op)
25-
26-
27-
def clear_dont_constant_fold() -> None:
28-
global _dont_constant_fold
29-
_dont_constant_fold.clear()
30-
3119

3220
def replace_node_with_constant(
3321
gm: torch.fx.GraphModule,
@@ -158,9 +146,6 @@ def is_woq_int8_pattern(node: torch.fx.node.Node) -> bool:
158146
# We only folding fp32_weight -> q
159147
# int8_weight and leave dq in graph to be fused
160148
return True
161-
162-
if node.target in _dont_constant_fold:
163-
return True
164149
return False
165150

166151
def node_to_last_non_output_use(self) -> dict[torch.fx.Node, list[torch.fx.Node]]:

0 commit comments

Comments
 (0)