
Commit d95c148

JKSenthil authored and facebook-github-bot committed
add async nan check utils (#965)
Summary: Pull Request resolved: #965

Reviewed By: galrotem

Differential Revision: D68530393

fbshipit-source-id: 8afdee2dc74a28b19c0a16eebf0b585f333fdc12
1 parent 06e6207 commit d95c148

File tree: 3 files changed, +159 -0 lines


tests/utils/test_nan.py

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import unittest

import torch

from torchtnt.utils.nan import check_for_nan_or_inf, register_nan_hooks_on_whole_graph


class NaNFunction(torch.autograd.Function):
    @staticmethod
    # pyre-ignore overrides method defined in `torch.autograd.function._SingleLevelFunction` inconsistently
    def forward(ctx, input):
        return input.clone()

    @staticmethod
    # pyre-ignore overrides method defined in `torch.autograd.function._SingleLevelFunction` inconsistently
    def backward(ctx, grad_output):
        return torch.tensor([float("nan")], device="cpu")


class NanHookTest(unittest.TestCase):
    def test_register_nan_hooks_on_whole_graph(self) -> None:
        x = torch.tensor([1.0], device="cpu", requires_grad=True)
        out = NaNFunction.apply(x)

        # no error is thrown
        out.backward()

        _ = register_nan_hooks_on_whole_graph([out])
        with self.assertRaisesRegex(RuntimeError, "Detected NaN"):
            out.backward()

    def test_check_for_nan_or_inf(self) -> None:
        tensor = torch.tensor([float("nan")], device="cpu")

        with self.assertRaisesRegex(RuntimeError, "Detected NaN or Inf in tensor"):
            check_for_nan_or_inf(tensor)

        tensor = torch.tensor([float("inf")], device="cpu")
        with self.assertRaisesRegex(RuntimeError, "Detected NaN or Inf in tensor"):
            check_for_nan_or_inf(tensor)
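
The tests above exercise only the error path. As a minimal sketch (not part of this commit), the callable returned by register_nan_hooks_on_whole_graph can also be used to remove the hooks once they are no longer wanted:

import torch

from torchtnt.utils.nan import register_nan_hooks_on_whole_graph

x = torch.tensor([1.0], requires_grad=True)
out = (x * 2).sum()

unregister = register_nan_hooks_on_whole_graph([out])
out.backward(retain_graph=True)  # gradients are finite, so the hooks stay silent

unregister()  # removes every hook that was registered on the graph
out.backward()  # backward now runs without the NaN assertions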

torchtnt/utils/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -51,6 +51,7 @@
     ModuleSummary,
     prune_module_summary,
 )
+from .nan import check_for_nan_or_inf, register_nan_hooks_on_whole_graph
 from .oom import (
     attach_oom_observer,
     is_out_of_cpu_memory,
@@ -89,6 +90,8 @@
 )

 __all__ = [
+    "check_for_nan_or_inf",
+    "register_nan_hooks_on_whole_graph",
     "IsNaNEvaluator",
     "ThresholdEvaluator",
     "CheckpointPath",
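
With the re-export above, both helpers become importable from the top-level utils package as well (a small illustrative snippet, not part of the diff):

from torchtnt.utils import check_for_nan_or_inf, register_nan_hooks_on_whole_graph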

torchtnt/utils/nan.py

Lines changed: 107 additions & 0 deletions
@@ -0,0 +1,107 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from collections import deque
from typing import Callable, Iterator, List, Optional, Sequence, Union

import torch
from pyre_extensions import none_throws
from torch.autograd.graph import GradientEdge, Node
from torch.utils.hooks import RemovableHandle


def _get_grad_fn_or_grad_acc(t: Union[torch.Tensor, GradientEdge]) -> Node:
    if isinstance(t, torch.Tensor):
        return none_throws(t.grad_fn)
    else:
        # pyre-ignore Undefined attribute [16]: `GradientEdge` has no attribute `function`.
        return t.function if t is not None else None


def register_nan_hooks_on_whole_graph(  # noqa: C901
    t_outputs: Sequence[Union[torch.Tensor, GradientEdge]]
) -> Callable[[], None]:
    """
    Registers a NaN hook on the whole autograd graph of the given tensors. The hook raises an error if a NaN is detected.

    This is useful if you want training to halt when a NaN is detected during the autograd process (i.e. the loss is inf or NaN).

    Usage:

    >>> class NaNFunction(torch.autograd.Function):
            @staticmethod
            def forward(ctx, input):
                return input.clone()

            @staticmethod
            def backward(ctx, grad_output):
                return torch.tensor([float("nan")], device="cpu")
    >>> x = torch.tensor([1.0], device="cpu", requires_grad=True)
    >>> out = NaNFunction.apply(x)
    >>> _ = register_nan_hooks_on_whole_graph([out])
    >>> out.backward()
    RuntimeError: Detected NaN in 'grad_inputs[0]' after executing Node

    """

    grad_fns = list(map(_get_grad_fn_or_grad_acc, t_outputs))

    def iter_graph(roots: List[torch.autograd.graph.Node]) -> Iterator[Node]:
        if not roots:
            return
        seen = set()
        q = deque()
        for node in roots:
            if node is not None and node not in seen:
                seen.add(node)
                q.append(node)
        while q:
            node = q.popleft()
            for fn, _ in node.next_functions:
                if fn is None or fn in seen:
                    continue
                seen.add(fn)
                q.append(fn)
            yield node

    def _assert_no_nan_tensor(t: Optional[torch.Tensor], msg: str) -> None:
        if t is not None:
            torch._assert_async(torch.logical_not(torch.any(torch.isnan(t))), msg)

    def posthook(
        grad_inputs: Sequence[Optional[torch.Tensor]],
        grad_outputs: Sequence[Optional[torch.Tensor]],
    ) -> None:
        node = torch._C._current_autograd_node()
        for i, g_in in enumerate(grad_inputs):
            _assert_no_nan_tensor(
                g_in, f"Detected NaN in 'grad_inputs[{i}]' after executing Node: {node}"
            )

    handles: List[RemovableHandle] = []
    for node in iter_graph(grad_fns):
        posthandle = node.register_hook(posthook)
        handles.append(posthandle)

    def unregister_hooks() -> None:
        for handle in handles:
            handle.remove()

    return unregister_hooks


def check_for_nan_or_inf(
    tensor: torch.Tensor, msg: str = "Detected NaN or Inf in tensor"
) -> None:
    """
    Asynchronously assert that the tensor is neither NaN nor infinite. If the tensor
    is on GPU, a failure surfaces as a CUDA device-side assert error.
    """

    torch._assert_async(
        torch.logical_not(torch.any(torch.isnan(tensor) | torch.isinf(tensor))),
        msg,
    )
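
Putting the two utilities together, here is a rough sketch of how they might be wired into a training loop. The loop itself (model, optimizer, data) is hypothetical and not part of this commit; the imports and calls follow the APIs defined above:

import torch

from torchtnt.utils.nan import check_for_nan_or_inf, register_nan_hooks_on_whole_graph

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.MSELoss()

for _ in range(3):
    inputs = torch.randn(8, 4)
    targets = torch.randn(8, 1)

    loss = loss_fn(model(inputs), targets)
    check_for_nan_or_inf(loss)  # raises immediately on CPU; device-side assert on GPU

    # guard the backward pass: any NaN grad_input raises "Detected NaN ..."
    unregister = register_nan_hooks_on_whole_graph([loss])
    loss.backward()
    unregister()

    optimizer.step()
    optimizer.zero_grad()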
