@@ -305,6 +305,62 @@ def create_fw_bw_graph(subgraph, operands, grad_outputs=None):
 
 
 def get_output_metadata(subgraph, *operands):
+    """
+    Extract metadata about the subgraph outputs WITHOUT executing the subgraph.
+    This avoids running side-effectful operations twice (once here, once in forward).
+    We analyze the graph structure statically to extract the metadata.
+    """
+    # Unwrap FunctionalizeCtxWrapper if present
+    if isinstance(subgraph, FunctionalizeCtxWrapper):
+        subgraph = subgraph.subgraph
+
+    # If not a GraphModule, fall back to execution-based metadata extraction
+    if not isinstance(subgraph, torch.fx.GraphModule):
+        return _get_output_metadata_by_execution(subgraph, *operands)
+
+    output_metadata = OutputMetadata()
+
+    # Extract the output arguments from the output node.
+    # The output node has args=(output_values,), where output_values is a tuple/list.
+    output_node = next(reversed(subgraph.graph.find_nodes(op="output")))
+    output_metadata.num_fw_outs = len(output_node.args[0])
+
+    for idx, output_arg in enumerate(output_node.args[0]):
+        if not isinstance(output_arg, torch.fx.Node):
+            if isinstance(output_arg, int):
+                output_metadata.indexes_with_symint.add(idx)
+            output_metadata.indexes_with_no_grad.add(idx)
+            continue
+
+        # Check node metadata for type information
+        if output_arg.meta.get("val") is None:
+            # If we don't have complete metadata for all outputs, fall back to execution.
+            # This is important for correctness (e.g., detecting SymInts) even though it
+            # runs side-effectful operations.
+            return _get_output_metadata_by_execution(subgraph, *operands)
+
+        val = output_arg.meta["val"]
+        if isinstance(val, torch.SymInt):
+            output_metadata.indexes_with_symint.add(idx)
+            output_metadata.indexes_with_no_grad.add(idx)
+        elif isinstance(val, torch.Tensor):
+            # Check whether the tensor requires grad from its metadata
+            if hasattr(val, "requires_grad") and not val.requires_grad:
+                output_metadata.indexes_with_no_grad.add(idx)
+        else:
+            # Non-tensor, non-SymInt output (shouldn't happen, but be safe)
+            output_metadata.indexes_with_no_grad.add(idx)
+
+    return output_metadata
+
+
+def _get_output_metadata_by_execution(subgraph, *operands):
+    """
+    Fallback: extract the metadata by executing the subgraph.
+    This should only be used when static analysis fails.
+    WARNING: This will run side-effectful operations!
+    """
+
     with suspend_functionalization(), disable_functional_mode():
         with disable_proxy_modes_tracing():
             # args are functional tensors, generate some example tensors
@@ -324,19 +380,15 @@ def get_output_metadata(subgraph, *operands):
 
             num_fw_outs = len(fw_outs)
 
-            # Collect the indexes of none in the output to check that the grad
-            # is None at the corresponding index in the backward. This check is
-            # performed in the autograd.Function - InvokeSubgraphAutogradOp.
-            # Also collect the indexes of no_grad in the output to filter out
-            # the grad_outs in the `backward` method.
             output_metadata = OutputMetadata()
-
             output_metadata.num_fw_outs = num_fw_outs
+
             for idx, fw_out in enumerate(fw_outs):
                 if isinstance(fw_out, torch.SymInt):
                     output_metadata.indexes_with_symint.add(idx)
                 elif not fw_out.requires_grad:
                     output_metadata.indexes_with_no_grad.add(idx)
+
             return output_metadata
 
 
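For context, here is a minimal, standalone sketch (not part of this commit) of the kind of static inspection the new `get_output_metadata` relies on: trace a small function with `make_fx` so each node records a fake value in `node.meta["val"]`, then classify the outputs from the graph's output node without executing the graph. The function `f` and its shapes are made up for illustration.

```python
# Sketch only: assumes make_fx symbolic tracing populates node.meta["val"]
# with FakeTensor / SymInt values, as it does in current PyTorch.
import torch
from torch.fx.experimental.proxy_tensor import make_fx


def f(x):
    # One tensor output and one SymInt output (the symbolic leading dimension).
    return x * 2, x.shape[0]


gm = make_fx(f, tracing_mode="symbolic")(torch.randn(4))

# Classify each output purely from graph metadata, without running gm.
output_node = next(n for n in gm.graph.nodes if n.op == "output")
for idx, arg in enumerate(output_node.args[0]):
    val = arg.meta.get("val") if isinstance(arg, torch.fx.Node) else arg
    if isinstance(val, torch.SymInt):
        print(idx, "symint")
    elif isinstance(val, torch.Tensor):
        print(idx, "tensor", "requires_grad" if val.requires_grad else "no_grad")
    else:
        print(idx, "other:", type(val).__name__)
```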