@@ -83,6 +83,91 @@ def maybe_clone(x):
     return x


+# Note: [Anomaly Mode Semantics in Compiled Autograd]
+# In the eager autograd engine, anomaly mode is able to detect NaNs
+# after each node. This is useful because the code executed with and
+# without anomaly mode is the same, so assuming determinism, a NaN
+# that occurs in regular mode should also occur in anomaly mode.
+#
+# With torch.compile, following eager semantics would require inserting
+# runtime asserts to check for NaNs, which could prevent some fusions.
+# This results in different code being run with and without anomaly mode.
+# So different semantics are needed: the implementation below checks
+# for NaNs at the end of the autograd call, instead of after each node.
+class NaNChecker:
+    def __init__(self, accumulate_grad: bool):
+        self.accumulate_grad = accumulate_grad
+        self.params_indices: list[int] = []
+        self.params_to_check: dict[str, torch.Tensor] = {}
+        self.output_names: list[str] = []
+
+    def prep_with_graph(self, graph: torch.fx.Graph):
+        inputs_node = next(iter(graph.nodes))
+        acc_grad_nodes = graph.find_nodes(
+            op="call_function", target=torch.ops.inductor.accumulate_grad_.default
+        )
+        output_nodes = graph.find_nodes(op="output")[0].args[0]
+        assert self.accumulate_grad == bool(
+            acc_grad_nodes
+        ) and self.accumulate_grad == (not output_nodes)
+
+        for node in acc_grad_nodes:
+            param_node = node.args[0]
+            # AccumulateGrad always saves a reference to the param
+            # so Compiled Autograd will always lift the param and
+            # this should always be true
+            assert (
+                param_node.target == operator.getitem
+                and param_node.args[0] is inputs_node  # type: ignore[possibly-undefined]
+                and isinstance(param_node.args[1], int)
+            )
+            self.params_indices.append(param_node.args[1])
+
+        self.output_names = [node.name for node in output_nodes]
+
+    def prep_with_inputs(self, inputs: tuple[torch.Tensor]):
+        if not self.accumulate_grad:
+            # Using .grad, nothing to prep
+            return
+
+        # Using .backward, we must check existing grads on params if any
+        for idx in self.params_indices:
+            grad = inputs[idx].grad
+            if grad is not None:
+                assert not torch.isnan(grad).any(), (
+                    f"Compiled autograd running under anomaly mode with inputs[{idx}] already "
+                    "having NaN gradient. This is not supported."
+                )
+
+            self.params_to_check[f"inputs[{idx}]"] = inputs[idx]
+
+    def check(self, out: tuple[torch.Tensor]):
+        if self.accumulate_grad:
+            # Using .backward, graph outputs are empty
+            assert not out
+            nan_params: list[str] = []
+            for inputs_str, param in self.params_to_check.items():
+                assert param.grad is not None  # not true for autograd.grad
+                if torch.isnan(param.grad).any():
+                    nan_params.append(inputs_str)
+
+            if nan_params:
+                raise RuntimeError(
+                    f"Compiled Autograd returned NaN gradients for parameters: {','.join(nan_params)}."
+                )
+        else:
+            # Using .grad, graph outputs are grads
+            nan_grads: list[str] = []
+            for i, grad in enumerate(out):
+                if torch.isnan(grad).any():
+                    nan_grads.append(self.output_names[i])
+
+            if nan_grads:
+                raise RuntimeError(
+                    f"Compiled Autograd returned NaN gradients for output nodes: {','.join(nan_grads)}."
+                )
+
+
 # We lazily bind "functional backward" variants for PyTorch built-in autograd
 # nodes to this class. Example: torch._dynamo.compiled_autograd.ops.MulBackward0
 # Each "functional backward" is bound the first time the node's apply_with_saved
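
For contrast with the end-of-call check above, here is a minimal eager-mode sketch (not part of the diff) of the per-node detection that the note describes. It assumes a recent PyTorch where `torch.autograd.detect_anomaly(check_nan=True)` is available, and the quoted error text is paraphrased.

```python
import torch

# Eager autograd: anomaly mode checks the output of every backward node,
# so the error names the offending Function instead of waiting until the
# whole backward call has finished.
x = torch.tensor([0.0], requires_grad=True)
with torch.autograd.detect_anomaly(check_nan=True):
    y = torch.log(x) * 0.0  # backward of log at x=0 computes 0 * inf = NaN
    try:
        y.backward()
    except RuntimeError as e:
        # Roughly: "Function 'LogBackward0' returned nan values in its 0th output."
        print(e)
```
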
@@ -188,12 +273,15 @@ def begin_capture(
         sizes: list[int],
         scalars: list[Union[int, float]],
         origins: list[list[tuple[int, str]]],
+        accumulate_grad: bool,
+        check_nans: bool,
     ):
         counters["compiled_autograd"]["captures"] += 1
         self.id = next(COMPILE_COUNTER)
         self.aot_id_counter: dict[int, int] = defaultdict(int)
         self.compile_context = make_compile_context(self.id)
         self.compile_context.__enter__()
+        self.nan_checker = NaNChecker(accumulate_grad) if check_nans else None
         self.start_time_ns = time.time_ns()
         get_chromium_event_logger().log_event_start(
             "compiled_autograd",
@@ -830,6 +918,8 @@ def end_capture(self, outputs):
         # Proper fix is Richard's Python compiled autograd effort which will avoid calling make_fx and
         # should prevent these ops from going into the CA graph.
         self.dce()
+        if self.nan_checker:
+            self.nan_checker.prep_with_graph(self.fx_tracer.graph)

         graph = self.create_graph_module(f"CompiledAutograd{self.id}")
         set_locals_to_steal(graph, ["inputs"])
@@ -851,11 +941,17 @@ def runtime_wrapper(compiled_fn, inputs, sizes, scalars, hooks, packed_inputs):
             global in_compiled_autograd_region
             try:
                 in_compiled_autograd_region = True
+                if self.nan_checker:
+                    self.nan_checker.prep_with_inputs(inputs)
+
                 for i in runtime_inputs_to_move:
                     inputs[i] = inputs[i].pin_memory().cuda(non_blocking=True)

                 with _disable(), make_compile_context(self.id):
-                    return compiled_fn(inputs, sizes, scalars, hooks, packed_inputs)
+                    out = compiled_fn(inputs, sizes, scalars, hooks, packed_inputs)
+                    if self.nan_checker:
+                        self.nan_checker.check(out)
+                    return out
             finally:
                 in_compiled_autograd_region = False
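
A hedged end-to-end sketch of how the new semantics would surface to a user. It assumes compiled autograd is enabled via `torch._dynamo.config.compiled_autograd = True` (as in the public compiled autograd tutorial; the exact flag and supported flow may differ across versions) and paraphrases the error text raised by `NaNChecker.check`.

```python
import torch

# Assumption: this config flag is the tutorial's way of enabling compiled
# autograd; it may differ between PyTorch versions.
torch._dynamo.config.compiled_autograd = True

model = torch.nn.Linear(4, 4)
x = torch.randn(2, 4)

@torch.compile
def train_step(model, x):
    # Multiplying by NaN makes every parameter gradient NaN in the backward pass.
    loss = (model(x) * float("nan")).sum()
    loss.backward()

with torch.autograd.detect_anomaly(check_nan=True):
    try:
        train_step(model, x)
    except RuntimeError as e:
        # Under compiled autograd the NaN is reported only after the whole
        # compiled backward has run, e.g.
        # "Compiled Autograd returned NaN gradients for parameters: ..."
        print(e)
```
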