
Commit dcb378c

xmfan authored and pytorchmergebot committed
[ca] support anomaly mode nan checks with different semantics than eager (pytorch#149897)

see note in code

Pull Request resolved: pytorch#149897
Approved by: https://github.com/jansel
ghstack dependencies: pytorch#149647, pytorch#149709, pytorch#149651
1 parent 488b87c commit dcb378c
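
For context, a minimal sketch of the behavior this commit enables, adapted from the test_anomaly_mode_backward test added below (the identity "compiler" lambda gm: gm simply replays the captured graph, and the error text in the comment paraphrases the message raised by the new NaNChecker):

    import torch

    class MyFn(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            return x

        @staticmethod
        def backward(ctx, gO):
            # deliberately produce NaN gradients
            return torch.full(gO.size(), float("nan"))

    with torch.autograd.detect_anomaly():
        a = torch.randn(5, 5, requires_grad=True)
        loss = MyFn.apply(a).sum()
        with torch._dynamo.compiled_autograd._enable(lambda gm: gm):
            # Unlike eager anomaly mode, the check runs once at the end of the
            # backward call and raises:
            # RuntimeError: Compiled Autograd returned NaN gradients for parameters: ...
            loss.backward()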

4 files changed: +176, -10 lines changed

test/inductor/test_compiled_autograd.py

Lines changed: 71 additions & 4 deletions
@@ -878,7 +878,12 @@ def test_inputs_aliasing_bytecode_attr_mutations(self):
         activ = torch.ones(100) * 2
         inputs = [param, activ]
         _, proxies, _, _ = compiler.begin_capture(
-            inputs=inputs, sizes=[], scalars=[], origins=[[], [], []]
+            inputs=inputs,
+            sizes=[],
+            scalars=[],
+            origins=[[], [], []],
+            accumulate_grad=False,
+            check_nans=False,
         )
         param_proxy, activ_proxy = proxies
         buf = activ_proxy * 2
@@ -3971,6 +3976,68 @@ def fn(x, y):
         with compiled_autograd._enable(lambda gm: gm):
             loss.backward()
 
+    def test_anomaly_mode_already_nan(self):
+        def fn():
+            with torch.autograd.detect_anomaly():
+                a = torch.randn(5, 5, requires_grad=True)
+                a.grad = torch.full((5, 5), float("nan"))
+                b = torch.randn(5, 5)
+                out = torch.matmul(a, b)
+                loss = out.sum()
+                with torch._dynamo.compiled_autograd._enable(lambda gm: gm):
+                    loss.backward()
+
+        with self.assertRaisesRegex(
+            AssertionError, "already having NaN gradient. This is not supported."
+        ):
+            fn()
+
+    def test_anomaly_mode_backward(self):
+        def fn():
+            class MyFn(torch.autograd.Function):
+                @staticmethod
+                def forward(ctx, x):
+                    return x
+
+                @staticmethod
+                def backward(ctx, gO):
+                    return torch.full(gO.size(), float("nan"))
+
+            with torch.autograd.detect_anomaly():
+                a = torch.randn(5, 5, requires_grad=True)
+                out = MyFn.apply(a)
+                loss = out.sum()
+                with torch._dynamo.compiled_autograd._enable(lambda gm: gm):
+                    loss.backward()
+
+        with self.assertRaisesRegex(
+            RuntimeError, "Compiled Autograd returned NaN gradients for parameters"
+        ):
+            fn()
+
+    def test_anomaly_mode_grad(self):
+        def fn():
+            class MyFn(torch.autograd.Function):
+                @staticmethod
+                def forward(ctx, x):
+                    return x
+
+                @staticmethod
+                def backward(ctx, gO):
+                    return torch.full(gO.size(), float("nan"))
+
+            with torch.autograd.detect_anomaly():
+                a = torch.randn(5, 5, requires_grad=True)
+                out = MyFn.apply(a)
+                loss = out.sum()
+                with torch._dynamo.compiled_autograd._enable(lambda gm: gm):
+                    torch.autograd.grad(loss, inputs=a)
+
+        with self.assertRaisesRegex(
+            RuntimeError, "Compiled Autograd returned NaN gradients for output nodes"
+        ):
+            fn()
+
 
 def load_test_module(name):
     testdir = Path(__file__).absolute().parent.parent
@@ -4103,7 +4170,6 @@ def wrap_test_class(orig_cls):
     "test_reentrant_with_callbacks_depth_0",  # queue_callback
     "test_reentrant_with_callbacks_depth_1",  # queue_callback
     "test_current_graph_task_execution_order",  # nodes are already freed by the time dynamo traces the lifted hook
-    "test_anomaly_grad_warnings",  # does not support anomaly mode
     "test_autograd_inplace_views_cross_dtype",  # view_fn not supported by compiled autograd
     "test_current_node",  # TorchDispatchMode not yet implemented for compiled autograd
     "test_post_accumulate_grad_hook_ordering",  # accuracy error
@@ -4114,7 +4180,6 @@ def wrap_test_class(orig_cls):
     "test_retains_grad_inplace_multiple_outputs",  # retains_grad_hooks
     "test_accumulate_grad",  # create_graph
     "test_anomaly_assign_parent_cleanup",  # create_graph
-    "test_anomaly_mode_no_check_nan",  # anomaly mode
     "test_backward_create_graph_warns",  # create_graph
     "test_backward_with_nonleaf_inputs",  # create_graph
     "test_create_graph_and_full_backward_hook_cycle",  # create_graph
@@ -4146,7 +4211,6 @@ def wrap_test_class(orig_cls):
     "test_select_sum",  # create_graph, also needs graph breaks
     "test_will_engine_execute_node",  # retains_grad_hooks
     "test_backward_to_node",  # retains_grad_hooks NYI
-    "test_anomaly_detect_nan",  # anomaly mode
     "test_custom_autograd_no_early_free",  # create_graph
     "test_custom_function_error",  # vjp
     "test_custom_function_save_for_forward",  # vjp
@@ -4202,6 +4266,9 @@ def wrap_test_class(orig_cls):
     "test_autograd_node_isinstance",  # backward ctx is a fake cls and not directly a Node instance
     "test_backward_hook_relative_ordering",  # compiled autograd collects breadth first, and module backward hook not supported
     "test_checkpointing_without_reentrant_custom_function_works",  # ctx.saved_tensors are cached by CA
+    "test_anomaly_mode_no_check_nan",  # different error messages
+    "test_anomaly_grad_warnings",  # different error messages
+    "test_anomaly_detect_nan",  # fake tensor errors on NaN
     # Uncategorized
     "test_not_implemented_grad",  # Dynamo changes the types of exceptions
 }
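
For reference, the three anomaly tests moved into the skip list above assert on the eager engine's per-node error text, which compiled autograd intentionally does not reproduce. A rough sketch of the eager behavior those tests exercise (error message paraphrased, not taken from this diff):

    import torch

    class MyFn(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            return x

        @staticmethod
        def backward(ctx, gO):
            return gO * float("nan")

    a = torch.randn(5, 5, requires_grad=True)
    with torch.autograd.detect_anomaly():
        loss = MyFn.apply(a).sum()
        # Eager anomaly mode raises as soon as the offending node executes, with a
        # per-node message along the lines of
        # "Function 'MyFnBackward' returned nan values in its 0th output."
        loss.backward()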

torch/_dynamo/compiled_autograd.py

Lines changed: 97 additions & 1 deletion
@@ -83,6 +83,91 @@ def maybe_clone(x):
     return x
 
 
+# Note: [Anomaly Mode Semantics in Compiled Autograd]
+# In the eager autograd engine, anomaly mode is able to detect NaNs
+# after each node. This is useful, because the executed code with
+# and without anomaly mode are the same. So assuming determinism,
+# a NaN in regular mode should also happen in anomaly mode.
+#
+# With torch.compile, following eager semantics would require inserting
+# runtime asserts to check for NaNs, which could prevent some fusions.
+# This results in different code being run with and without anomaly mode.
+# So different semantics are needed, this implementation below will check
+# for NaNs at the end of the autograd call, instead of after each node
+class NaNChecker:
+    def __init__(self, accumulate_grad: bool):
+        self.accumulate_grad = accumulate_grad
+        self.params_indices: list[int] = []
+        self.params_to_check: dict[str, torch.Tensor] = {}
+        self.output_names: list[str] = []
+
+    def prep_with_graph(self, graph: torch.fx.Graph):
+        inputs_node = next(iter(graph.nodes))
+        acc_grad_nodes = graph.find_nodes(
+            op="call_function", target=torch.ops.inductor.accumulate_grad_.default
+        )
+        output_nodes = graph.find_nodes(op="output")[0].args[0]
+        assert self.accumulate_grad == bool(
+            acc_grad_nodes
+        ) and self.accumulate_grad == (not output_nodes)
+
+        for node in acc_grad_nodes:
+            param_node = node.args[0]
+            # AccumulateGrad always saves a reference to the param
+            # so Compiled Autograd will always lift the param and
+            # this should always be true
+            assert (
+                param_node.target == operator.getitem
+                and param_node.args[0] is inputs_node  # type: ignore[possibly-undefined]
+                and isinstance(param_node.args[1], int)
+            )
+            self.params_indices.append(param_node.args[1])
+
+        self.output_names = [node.name for node in output_nodes]
+
+    def prep_with_inputs(self, inputs: tuple[torch.Tensor]):
+        if not self.accumulate_grad:
+            # Using .grad, nothing to prep
+            return
+
+        # Using .backward, we must check existing grads on params if any
+        for idx in self.params_indices:
+            grad = inputs[idx].grad
+            if grad is not None:
+                assert not torch.isnan(grad).any(), (
+                    f"Compiled autograd running under anomaly mode with inputs[{idx}] already "
+                    "having NaN gradient. This is not supported."
+                )
+
+            self.params_to_check[f"inputs[{idx}]"] = inputs[idx]
+
+    def check(self, out: tuple[torch.Tensor]):
+        if self.accumulate_grad:
+            # Using .backward, graph outputs are empty
+            assert not out
+            nan_params: list[str] = []
+            for inputs_str, param in self.params_to_check.items():
+                assert param.grad is not None  # not true for autograd.grad
+                if torch.isnan(param.grad).any():
+                    nan_params.append(inputs_str)
+
+            if nan_params:
+                raise RuntimeError(
+                    f"Compiled Autograd returned NaN gradients for parameters: {','.join(nan_params)}."
+                )
+        else:
+            # Using .grad, graph outputs are grads
+            nan_grads: list[str] = []
+            for i, grad in enumerate(out):
+                if torch.isnan(grad).any():
+                    nan_grads.append(self.output_names[i])
+
+            if nan_grads:
+                raise RuntimeError(
+                    f"Compiled Autograd returned NaN gradients for output nodes: {','.join(nan_grads)}."
+                )
+
+
 # We lazily bind "functional backward" variants for PyTorch built-in autograd
 # nodes to this class. Example: torch._dynamo.compiled_autograd.ops.MulBackward0
 # Each "functional backward" is bound the first time the node's apply_with_saved
@@ -188,12 +273,15 @@ def begin_capture(
         sizes: list[int],
         scalars: list[Union[int, float]],
         origins: list[list[tuple[int, str]]],
+        accumulate_grad: bool,
+        check_nans: bool,
     ):
         counters["compiled_autograd"]["captures"] += 1
         self.id = next(COMPILE_COUNTER)
         self.aot_id_counter: dict[int, int] = defaultdict(int)
         self.compile_context = make_compile_context(self.id)
         self.compile_context.__enter__()
+        self.nan_checker = NaNChecker(accumulate_grad) if check_nans else None
         self.start_time_ns = time.time_ns()
         get_chromium_event_logger().log_event_start(
             "compiled_autograd",
@@ -830,6 +918,8 @@ def end_capture(self, outputs):
         # Proper fix is Richard's Python compiled autograd effort which will avoid calling make_fx and
         # should prevent these ops from going into the CA graph.
         self.dce()
+        if self.nan_checker:
+            self.nan_checker.prep_with_graph(self.fx_tracer.graph)
 
         graph = self.create_graph_module(f"CompiledAutograd{self.id}")
         set_locals_to_steal(graph, ["inputs"])
@@ -851,11 +941,17 @@ def runtime_wrapper(compiled_fn, inputs, sizes, scalars, hooks, packed_inputs):
             global in_compiled_autograd_region
             try:
                 in_compiled_autograd_region = True
+                if self.nan_checker:
+                    self.nan_checker.prep_with_inputs(inputs)
+
                 for i in runtime_inputs_to_move:
                     inputs[i] = inputs[i].pin_memory().cuda(non_blocking=True)
 
                 with _disable(), make_compile_context(self.id):
-                    return compiled_fn(inputs, sizes, scalars, hooks, packed_inputs)
+                    out = compiled_fn(inputs, sizes, scalars, hooks, packed_inputs)
+                    if self.nan_checker:
+                        self.nan_checker.check(out)
+                    return out
             finally:
                 in_compiled_autograd_region = False
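
Putting the pieces above together: begin_capture constructs a NaNChecker only when check_nans is True, end_capture hands it the captured FX graph via prep_with_graph (recording which lifted inputs receive accumulate_grad_ calls and the names of the graph outputs), and runtime_wrapper calls prep_with_inputs before running the compiled function and check on its outputs afterwards. A rough standalone sketch of the same end-of-call semantics for the .backward() path, using a hypothetical helper that is not part of this PR:

    import torch

    def backward_with_end_check(loss: torch.Tensor, params: list[torch.Tensor]) -> None:
        # Mirrors NaNChecker.prep_with_inputs: refuse params that already hold NaN grads.
        for i, p in enumerate(params):
            assert p.grad is None or not torch.isnan(p.grad).any(), (
                f"inputs[{i}] already has a NaN gradient"
            )
        # Run the whole backward uninstrumented; no per-node asserts are inserted,
        # so fusions in the compiled backward are unaffected.
        loss.backward()
        # Mirrors NaNChecker.check: a single sweep over the accumulated grads at the end.
        nan_params = [
            f"inputs[{i}]"
            for i, p in enumerate(params)
            if p.grad is not None and torch.isnan(p.grad).any()
        ]
        if nan_params:
            raise RuntimeError(f"NaN gradients for parameters: {','.join(nan_params)}")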

torch/csrc/autograd/engine.cpp

Lines changed: 0 additions & 3 deletions
@@ -1330,9 +1330,6 @@ auto Engine::execute(
     TORCH_CHECK(
         !create_graph, "compiled_autograd does not support create_graph");
     _thread_check.release();
-    TORCH_CHECK(
-        !AnomalyMode::is_enabled(),
-        "compiled_autograd does not support AnomalyMode")
     GraphTaskGuard guard(graph_task);
     CheckpointValidGuard cpvguard(graph_task);
     return (*compiled_autograd)(

torch/csrc/dynamo/python_compiled_autograd.cpp

Lines changed: 8 additions & 2 deletions
@@ -730,7 +730,9 @@ static TraceState call_begin_capture(
     CacheNode& cache,
     AutogradCompilerCall& compiler_call,
     size_t num_outputs,
-    std::optional<std::string>&& maybe_compile_reason) {
+    std::optional<std::string>&& maybe_compile_reason,
+    bool accumulate_grad,
+    bool check_nans) {
   static PyObject* method_name = PyUnicode_InternFromString("begin_capture");
   THPObjectPtr py_input(THPVariable_WrapList(compiler_call.tensor_args.inputs));
   THPObjectPtr py_size_input(cache.wrap_dynamic_inputs());
@@ -745,6 +747,8 @@ static TraceState call_begin_capture(
       py_size_input.get(),
       py_ivalue_args_input.get(),
       py_node_origins.get(),
+      PyBool_FromLong(accumulate_grad),
+      PyBool_FromLong(check_nans),
       nullptr)));
 
   PyObject *compile_id_str{nullptr}, *fake_inputs{nullptr},
@@ -914,7 +918,9 @@ static CacheNode* _compiled_autograd_impl(
       *cache,
       compiler_call,
       output_edges.size(),
-      std::move(compile_reason));
+      std::move(compile_reason),
+      accumulate_grad,
+      AnomalyMode::is_enabled() && AnomalyMode::should_check_nan());
   InputBuffers input_buffers;
 
   for (size_t i = 0; i < ordered_calls.size(); i++) {
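
Note that the check_nans argument is only True when anomaly mode is active and its NaN checking has not been opted out of, since AnomalyMode::should_check_nan() reflects the check_nan flag of torch.autograd.detect_anomaly. A small sketch of the opt-out path (assumes the default check_nan=True is what the new tests rely on):

    import torch

    a = torch.randn(5, 5, requires_grad=True)
    loss = (a * float("nan")).sum()

    # With check_nan=False, AnomalyMode::should_check_nan() is false, so the Python
    # side receives check_nans=False, no NaNChecker is created, and the backward
    # completes even though a.grad ends up full of NaNs.
    with torch.autograd.detect_anomaly(check_nan=False):
        with torch._dynamo.compiled_autograd._enable(lambda gm: gm):
            loss.backward()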
