Commit dd9ff9f

xmfan authored and pytorchmergebot committed
[compiled autograd] add tests for bwd hooks relative firing order (pytorch#139004)
Pull Request resolved: pytorch#139004
Approved by: https://github.com/yf225
ghstack dependencies: pytorch#139003
1 parent fac7468 commit dd9ff9f
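
For context, the hook APIs whose relative firing order these tests pin down are the autograd-graph node hooks (Node.register_prehook / Node.register_hook) and the tensor hooks (Tensor.register_hook / Tensor.register_post_accumulate_grad_hook). Below is a minimal eager-mode sketch, not part of this commit, with illustrative names; the ordering shown in the final comment follows the expected_order asserted by the new test in test_autograd.py.

import torch

order = []
x = torch.randn(3, requires_grad=True)
loss = (x * 2).sum()

# Node-level hooks on the graph root: the pre-hook fires just before the node
# executes, the post-hook right after it has produced its input gradients.
loss.grad_fn.register_prehook(lambda grad_outputs: order.append("pre_hook_SumBackward0"))
loss.grad_fn.register_hook(lambda grad_inputs, grad_outputs: order.append("post_hook_SumBackward0"))

# Tensor-level hooks on the leaf: register_hook fires when x's gradient is
# computed, register_post_accumulate_grad_hook after it is accumulated into x.grad.
x.register_hook(lambda grad: order.append("tensor_pre_hook_x"))
x.register_post_accumulate_grad_hook(lambda t: order.append("post_acc_grad_hook_x"))

loss.backward()
print(order)
# ['pre_hook_SumBackward0', 'post_hook_SumBackward0', 'tensor_pre_hook_x', 'post_acc_grad_hook_x']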

2 files changed: +149 -0 lines changed


test/inductor/test_compiled_autograd.py

Lines changed: 58 additions & 0 deletions
@@ -2802,6 +2802,63 @@ def fn(x):
         with torch._dynamo.compiled_autograd.enable(torch.compile):
             out.backward()
 
+    @skipIfWindows(msg="node name demangling inconsistent on windows")
+    def test_backward_hook_relative_ordering_partial(self):
+        # test backward hooks for cases that CA matches eager
+
+        def fn():
+            order = []
+
+            class MyModule(nn.Module):
+                def __init__(self):
+                    super().__init__()
+                    self.linear = torch.nn.Linear(10, 10, bias=False)
+
+                def forward(self, x):
+                    return self.linear(x)
+
+            x = torch.randn(10, 10)
+            module = MyModule()
+
+            def make_pre_hook(id):
+                return lambda _: order.append(f"pre_hook_{id}")
+
+            def make_post_hook(id):
+                return lambda _1, _2: order.append(f"post_hook_{id}")
+
+            count = 0
+
+            def register_hooks_on_all_nodes(nodes):
+                nonlocal count
+                for node, _ in nodes:
+                    if node is None:
+                        continue
+                    count += 1
+                    id = f"{node.name()}_{count}"
+                    node.register_prehook(make_pre_hook(id))
+                    node.register_hook(make_post_hook(id))
+                    register_hooks_on_all_nodes(node.next_functions)
+
+            loss = module(x).sum()
+            register_hooks_on_all_nodes(((loss.grad_fn, None),))
+
+            def make_tensor_pre_hook(id):
+                return lambda _: order.append(f"tensor_pre_hook_{id}")
+
+            def make_post_acc_grad_hook(id):
+                return lambda _: order.append(f"post_acc_grad_hook_{id}")
+
+            module.linear.weight.register_hook(make_tensor_pre_hook("weight"))
+
+            module.linear.weight.register_post_accumulate_grad_hook(
+                make_post_acc_grad_hook("weight")
+            )
+
+            loss.backward()
+            yield tuple(order)
+
+        self.check_output_and_recompiles(fn)
+
 
 def load_test_module(name):
     testdir = Path(__file__).absolute().parent.parent
@@ -2993,6 +3050,7 @@ def wrap_test_class(orig_cls):
     # Category: Divergence from eager
     "test_invalid_gradients",  # can't give autograd error due to inaccurate output metadata of lifted backward
     "test_autograd_node_isinstance",  # backward ctx is a fake cls and not directly a Node instance
+    "test_backward_hook_relative_ordering",  # compiled autograd collects breadth first, and module backward hook not supported
     # Uncategorized
 }
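
The compiled-autograd side of these hooks can also be exercised outside the test harness with the same context manager used in the surrounding tests. A hypothetical stand-alone sketch (model and hook names are illustrative); the eager-matching ordering is exactly what test_backward_hook_relative_ordering_partial checks via check_output_and_recompiles:

import torch
import torch.nn as nn

order = []
module = nn.Linear(10, 10, bias=False)
x = torch.randn(10, 10)
loss = module(x).sum()

module.weight.register_hook(lambda grad: order.append("tensor_pre_hook_weight"))
module.weight.register_post_accumulate_grad_hook(
    lambda t: order.append("post_acc_grad_hook_weight")
)

# Run backward under compiled autograd instead of the eager autograd engine.
with torch._dynamo.compiled_autograd.enable(torch.compile):
    loss.backward()

print(order)  # expected to match eager: tensor pre-hook first, then the post-accumulate hook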

test/test_autograd.py

Lines changed: 91 additions & 0 deletions
@@ -74,6 +74,7 @@
     skipIfMps,
     skipIfNoLapack,
     skipIfTorchDynamo,
+    skipIfWindows,
     slowTest,
     TestCase,
     xfailIfTorchDynamo,
@@ -4592,6 +4593,96 @@ def hook(t_):
         ):
             t.backward()
 
+    @skipIfWindows(msg="node name demangling inconsistent on windows")
+    def test_backward_hook_relative_ordering(self):
+        order = []
+
+        class MyModule(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(10, 10)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        x = torch.randn(10, 10, requires_grad=True)
+        module = MyModule()
+        module.register_full_backward_hook(
+            lambda _1, _2, _3: order.append(
+                "module_full_backward_hook_BackwardHookFunctionBackward0"
+            )
+        )
+
+        def make_pre_hook(id):
+            return lambda _: order.append(f"pre_hook_{id}")
+
+        def make_post_hook(id):
+            return lambda _1, _2: order.append(f"post_hook_{id}")
+
+        count = 0
+
+        def register_hooks_on_all_nodes(nodes):
+            nonlocal count
+            for node, _ in nodes:
+                count += 1
+                id = f"{node.name()}_{count}"
+                node.register_prehook(make_pre_hook(id))
+                node.register_hook(make_post_hook(id))
+                register_hooks_on_all_nodes(node.next_functions)
+
+        loss = module(x).sum()
+        register_hooks_on_all_nodes(((loss.grad_fn, None),))
+
+        def make_tensor_pre_hook(id):
+            return lambda _: order.append(f"tensor_pre_hook_{id}")
+
+        def make_post_acc_grad_hook(id):
+            return lambda _: order.append(f"post_acc_grad_hook_{id}")
+
+        x.register_hook(make_tensor_pre_hook("x"))
+        module.linear.weight.register_hook(make_tensor_pre_hook("weight"))
+        module.linear.bias.register_hook(make_tensor_pre_hook("bias"))
+
+        x.register_post_accumulate_grad_hook(make_post_acc_grad_hook("x"))
+        module.linear.weight.register_post_accumulate_grad_hook(
+            make_post_acc_grad_hook("weight")
+        )
+        module.linear.bias.register_post_accumulate_grad_hook(
+            make_post_acc_grad_hook("bias")
+        )
+
+        loss.backward()
+
+        expected_order = [
+            "pre_hook_SumBackward0_1",
+            "post_hook_SumBackward0_1",
+            "pre_hook_BackwardHookFunctionBackward_2",
+            "post_hook_BackwardHookFunctionBackward_2",
+            "pre_hook_AddmmBackward0_3",
+            "post_hook_AddmmBackward0_3",
+            "tensor_pre_hook_bias",
+            "pre_hook_torch::autograd::AccumulateGrad_4",
+            "post_acc_grad_hook_bias",
+            "post_hook_torch::autograd::AccumulateGrad_4",
+            "pre_hook_TBackward0_7",
+            "post_hook_TBackward0_7",
+            "tensor_pre_hook_weight",
+            "pre_hook_torch::autograd::AccumulateGrad_8",
+            "post_acc_grad_hook_weight",
+            "post_hook_torch::autograd::AccumulateGrad_8",
+            "pre_hook_BackwardHookFunctionBackward_5",
+            "module_full_backward_hook_BackwardHookFunctionBackward0",
+            "post_hook_BackwardHookFunctionBackward_5",
+            "tensor_pre_hook_x",
+            "pre_hook_torch::autograd::AccumulateGrad_6",
+            "post_acc_grad_hook_x",
+            "post_hook_torch::autograd::AccumulateGrad_6",
+        ]
+
+        self.assertEqual(len(expected_order), len(order))
+        for expected, actual in zip(expected_order, order):
+            self.assertEqual(expected, actual)
+
     def test_view_replay_enabled(self):
         def f(x):
             out = x.clone().view(-1)
