Skip to content

Commit c8c892b

Browse files
ydwu4 authored and pytorchmergebot committed
[scan] disable functionalization key in backward tracing (pytorch#154343)
Previously, we didn't disable the functionalization key when materializing the backward graph. This caused the torch.zeros_like call, in the case where grad is None, to return a functional tensor that is not tracked by the proxy tensor mode. This PR fixes it by putting the tracing code under the disable-functionalization context manager. Fixes pytorch#153437. Pull Request resolved: pytorch#154343 Approved by: https://github.com/zou3519
1 parent 5e93abe commit c8c892b

File tree

2 files changed

+37
-7
lines changed

2 files changed

+37
-7
lines changed

test/inductor/test_control_flow.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1642,6 +1642,18 @@ def accumulate_chunk(input_chunk, target_chunk):
16421642
torch.cat(grad_inputs, dim=0) / chunks,
16431643
)
16441644

1645+
class ScanWithClamp(torch.nn.Module):
1646+
def __init__(self):
1647+
super().__init__()
1648+
1649+
def forward(self, scan_op, initial, xs):
1650+
def step(h_prev, x_t):
1651+
h_next = (h_prev + x_t).clamp(min=0.1)
1652+
return h_next, h_next.clone()
1653+
1654+
final, ys = scan_op(step, initial, xs)
1655+
return final, ys
1656+
16451657

16461658
class ScanTests(TestCase):
16471659
def _run_test(
@@ -1824,6 +1836,24 @@ def test_scan_compare_chunked_ce_with_no_scan(self, device, dynamic):
18241836
device=device,
18251837
)
18261838

1839+
@requires_gpu
1840+
@parametrize("device", ["cpu", GPU_TYPE])
1841+
@parametrize("dynamic", [True, False])
1842+
@torch._dynamo.config.patch("capture_scalar_outputs", True)
1843+
def test_scan_with_clamp(self, device, dynamic):
1844+
B = 4
1845+
T = 8
1846+
H = 16
1847+
self._run_test(
1848+
model=ScanModels.ScanWithClamp(),
1849+
inputs=(
1850+
torch.randn((B, H)),
1851+
torch.randn((T, B, H), requires_grad=True),
1852+
),
1853+
device=device,
1854+
dynamic=dynamic,
1855+
)
1856+
18271857

18281858
class MapModels:
18291859
class Simple(torch.nn.Module):

torch/_higher_order_ops/utils.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1007,13 +1007,13 @@ def _materialize_as_graph_inner():
10071007
with suspend_functionalization(), disable_functional_mode():
10081008
with disable_proxy_modes_tracing():
10091009
unfunc_t = [_from_fun(arg) for arg in args]
1010-
with contextlib.ExitStack() as stack:
1011-
stack.enter_context(
1012-
torch._C._ForceDispatchKeyGuard(include_key_set, exclude_key_set),
1013-
)
1014-
if force_enable_grad:
1015-
stack.enter_context(torch.enable_grad())
1016-
return _maybe_reenter_make_fx(fn)(*unfunc_t)
1010+
with contextlib.ExitStack() as stack:
1011+
stack.enter_context(
1012+
torch._C._ForceDispatchKeyGuard(include_key_set, exclude_key_set),
1013+
)
1014+
if force_enable_grad:
1015+
stack.enter_context(torch.enable_grad())
1016+
return _maybe_reenter_make_fx(fn)(*unfunc_t)
10171017

10181018
gm = _materialize_as_graph_inner()
10191019
assert gm is not None

0 commit comments

Comments
 (0)