@@ -264,6 +264,10 @@ def backward(ctx, grad_output):
  self.assertExpected(x_grad_desc, "x_grad_desc")
  self.assertExpected(y_grad_desc, "y_grad_desc")

+ # Avoid leaking memory
+ x.grad = None
+ y.grad = None
+
  def test_once_differentiable(self):
  class MyFunction(Function):
  @staticmethod
@@ -293,6 +297,10 @@ def backward(ctx, grad_output):
293297 "CopyBackwards(None, Error(AccumulateGrad(), None, AccumulateGrad()))",
294298 )
295299
300+ # Avoid leaking memory
301+ x.grad = None
302+ y.grad = None
303+
296304 def test_function_returns_input(self):
297305 class MyFunction(Function):
298306 @staticmethod
@@ -640,8 +648,8 @@ def fn(x):
  for g in should_not_execute:
  self.assertFalse(torch._C._will_engine_execute_node(g))

- b.register_hook(fn)
- c.register_hook(fn)
+ h1 = b.register_hook(fn)
+ h2 = c.register_hook(fn)

  # .backward(inputs=) is OK
  out = c.sum()
@@ -668,7 +676,7 @@ def fn(x):
  counter[0] += 1
  self.assertTrue(torch._C._will_engine_execute_node(b.grad_fn))

- b.register_hook(fn)
+ h3 = b.register_hook(fn)
  counter[0] = 0
  torch.autograd.grad(b.sum(), (a,))
  self.assertEqual(counter[0], 1)
@@ -680,6 +688,11 @@ def fn(x):
  with self.assertRaisesRegex(RuntimeError, "expects an grad_fn"):
  torch._C._will_engine_execute_node(out)

+ # Ensure we don't leak memory
+ h1.remove()
+ h2.remove()
+ h3.remove()
+
  def test_custom_function_vmap_defaults(self):
  class MySquare(Function):
  @staticmethod
@@ -899,6 +912,10 @@ def test_hessian_vector(self):
  self.assertEqual(x.grad, x_grad + x_hv)
  self.assertEqual(y.grad, y_grad + y_hv)

+ # Avoid leaking memory
+ x.grad = None
+ y.grad = None
+
  def test_grad(self):
  x = torch.randn(2, 2, requires_grad=True)
  y = torch.randn(2, 2, requires_grad=True)
@@ -924,6 +941,10 @@ def test_grad(self):
  self.assertEqual(x.grad, x_grad)
  self.assertEqual(y.grad, y_grad)

+ # Avoid leaking memory
+ x.grad = None
+ y.grad = None
+
  # Test that grad_outputs and outputs have the same shape
  grad_out = torch.ones(2)
  try:
@@ -1071,6 +1092,7 @@ def test_grad_fn_input_metadata(self):
  layout=torch.jagged,
  requires_grad=True,
  )
+
  nt_metadata = nt.clone().grad_fn._input_metadata[0]

  self.assertIsInstance(nt_metadata.shape[1], torch.SymInt)
@@ -2209,16 +2231,21 @@ def fn2(grad):

  b = torch.rand(3, 3, requires_grad=True)
  out1, out2 = fn(b)
- out1.register_hook(fn0)
- out2.register_hook(fn1)
+ h1 = out1.register_hook(fn0)
+ h2 = out2.register_hook(fn1)
  # node refers to two hook dicts
  # out1 no longer no longer points to its old hook dict
  out1.mul_(2)
  # fn2 is registered to out1's new hook dict
- out1.register_hook(fn2)
+ h3 = out1.register_hook(fn2)
  (out1 + out2 * 3).sum().backward()
  self.assertEqual(counts, [1, 1, 1])

+ # Avoid leaking memory
+ h1.remove()
+ h2.remove()
+ h3.remove()
+
  def test_tensor_hooks_inplace_over_view(self):
  # There might be a better UX here, but this is the way it is now
  count = [0]
@@ -2484,6 +2511,11 @@ def test_backward_with_nonleaf_inputs(self):
  )
  self.assertIsNone(z.grad)

+ # Avoid leaking memory
+ x.grad = None
+ y.grad = None
+ x_nonleaf.grad = None
+
  def test_dependent_backward(self):
  x = torch.randn(10, requires_grad=True)
  y = x**2
@@ -4445,6 +4477,7 @@ def hook(_):

  def test_current_graph_task_execution_order(self):
  predicted = [None]
+ all_hooks = []

  def hook(_):
  predicted[0] = torch._C._current_graph_task_execution_order()
@@ -4473,11 +4506,11 @@ def hook(t_):
  return hook

  for i, t in enumerate(tensors):
- t.register_hook(get_hook(i))
+ all_hooks.append(t.register_hook(get_hook(i)))

  # Basic example: single path
  t = torch.tensor(1.0, requires_grad=True).clone().sin().exp()
- t.register_hook(hook)
+ all_hooks.append(t.register_hook(hook))
  with torch.autograd.set_multithreading_enabled(False):
  t.backward()
  self.assertExpectedInline(
@@ -4494,7 +4527,7 @@ def hook(t_):
  d = a.cos()
  out = c * d
  register_logging_hooks(a, b, c, d, out)
- out.register_hook(hook)
+ all_hooks.append(out.register_hook(hook))
  with torch.autograd.set_multithreading_enabled(False):
  out.backward()
  self.assertEqual(predicted[0], grad_fns(*actual))
@@ -4506,7 +4539,7 @@ def hook(t_):
  c = a.cos()
  out = b * c
  register_logging_hooks(a, b, c, out)
- out.register_hook(hook)
+ all_hooks.append(out.register_hook(hook))
  with torch.autograd.set_multithreading_enabled(False):
  out.backward()
  self.assertEqual(predicted[0], grad_fns(*actual))
@@ -4519,7 +4552,7 @@ def hook(t_):
  out2 = b.cos()
  out3 = b.cos()
  register_logging_hooks(a, b, out, out2, out3)
- out3.register_hook(hook)
+ all_hooks.append(out3.register_hook(hook))
  with torch.autograd.set_multithreading_enabled(False):
  torch.autograd.grad((out, out3, out2), inputs=(a,))
  self.assertExpectedInline(
@@ -4537,7 +4570,7 @@ def hook(t_):
  b = a * 2
  out = b.sin()
  register_logging_hooks(a, b, out)
- out.register_hook(hook)
+ all_hooks.append(out.register_hook(hook))
  with torch.autograd.set_multithreading_enabled(False):
  out.backward()
  self.assertEqual(predicted[0], grad_fns(*actual))
@@ -4548,7 +4581,7 @@ def hook(t_):
  b = a * 2
  out = b.sin()
  register_logging_hooks(a, b, out)
- out.register_hook(hook)
+ all_hooks.append(out.register_hook(hook))
  with torch.autograd.set_multithreading_enabled(False):
  torch.autograd.grad((out,), inputs=(a, b))
  self.assertEqual(
@@ -4567,7 +4600,7 @@ def hook(t_):
  c = a * b
  out = c.sin()
  register_logging_hooks(a, b, c, out)
- out.register_hook(hook)
+ all_hooks.append(out.register_hook(hook))
  with torch.autograd.set_multithreading_enabled(False):
  torch.autograd.grad((out,), inputs=(a,))
  self.assertEqual(
@@ -4588,13 +4621,17 @@ def hook(t_):

  # Errors when context manager not enabled
  t = torch.tensor(1.0, requires_grad=True).clone().sin().exp()
- t.register_hook(hook)
+ all_hooks.append(t.register_hook(hook))
  with self.assertRaisesRegex(
  RuntimeError,
  "expects the current backward to be executed with multithreading disabled",
  ):
  t.backward()

+ # Avoid leaking memory
+ for h in all_hooks:
+     h.remove()
+
  @skipIfWindows(msg="node name demangling inconsistent on windows")
  def test_backward_hook_relative_ordering(self):
  order = []
@@ -12927,7 +12964,7 @@ def hook(grads):
  else:
  self.assertEqual(res, grad_is_none)

- torch.autograd.graph.register_multi_grad_hook((t1, t2, t3, t4), hook)
+ handle = torch.autograd.graph.register_multi_grad_hook((t1, t2, t3, t4), hook)

  out = (t2 * t3).sum()

@@ -12976,6 +13013,8 @@ def backward_retain_graph(out, t2, t3):
  self.assertEqual(err_count[0], 1)
  self.assertEqual(res, [False, True, True, False])

+ handle.remove()
+
  def test_multi_grad_any_hooks(self):
  # Multihooks should behave independently per execution of backward
  # Test that the hook fired the number of times we ran backward