
Commit 789240b

angelayi authored and pytorchmergebot committed
[invoke_subgraph] Don't run the graph twice when autograd enabled (pytorch#167245)
In the [previous PR](https://github.com/pytorch/pytorch/pull/167231/files#diff-e2b74af5d8b538a7d07d18507d27010703742ddad5f819992b55f5abc6d9a502R964-R966) we found that the autograd eager impl of invoke_subgraph calls the subgraph twice. If the subgraph contains effects, then those effects will be run twice, which is bad. This PR fixes the issue by getting the output metadata from `subgraph`'s `node.meta` if it exists.

Differential Revision: [D87392740](https://our.internmc.facebook.com/intern/diff/D87392740)

Pull Request resolved: pytorch#167245
Approved by: https://github.com/anijain2305
ghstack dependencies: pytorch#167231
1 parent f49833d commit 789240b
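
To make the problem concrete, here is a toy sketch of the behavior described above (hypothetical code, not PyTorch internals): when output metadata is gathered by executing the subgraph, any effect inside it fires once for the metadata pass and again for the actual forward.

```python
import torch

effect_log = []  # stands in for a real effect, e.g. the record_memory call in the test


def subgraph(x):
    effect_log.append("ran")  # side-effectful op inside the subgraph
    return x * 2, x.sum()


def forward_with_executed_metadata(x):
    # pass 1: run the subgraph just to inspect its outputs (the old behavior)
    outs = subgraph(x)
    num_fw_outs = len(outs)
    indexes_with_no_grad = {i for i, o in enumerate(outs) if not o.requires_grad}
    # pass 2: run the subgraph again for the actual forward
    return subgraph(x), (num_fw_outs, indexes_with_no_grad)


forward_with_executed_metadata(torch.randn(3, requires_grad=True))
print(effect_log)  # ['ran', 'ran'] -- the effect ran twice; this PR removes the first run
```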

File tree

2 files changed: 59 additions & 11 deletions


test/higher_order_ops/test_with_effects.py

Lines changed: 1 addition & 5 deletions
@@ -960,11 +960,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1):
         )

         recorded_list.clear()
-        # TODO: seems like invoke_subgraph's py_autograd impl calls the subgraph
-        # eagerly twice. Once for get_output_metadata and then once for
-        # InvokeSubgraphAutogradOp. This causes record_memory to be called twice.
-        with torch.no_grad():
-            out2 = ep.module()(x)
+        out2 = ep.module()(x)
         self.assertEqual(len(recorded_list), 4)
         self.assertTrue(torch.allclose(model(x)[0], out2[0]))

torch/_higher_order_ops/invoke_subgraph.py

Lines changed: 58 additions & 6 deletions
@@ -305,6 +305,62 @@ def create_fw_bw_graph(subgraph, operands, grad_outputs=None):


 def get_output_metadata(subgraph, *operands):
+    """
+    Extract metadata about the subgraph outputs WITHOUT executing the subgraph.
+    This avoids running side-effectful operations twice (once here, once in forward).
+    We analyze the graph structure statically to extract metadata.
+    """
+    # Unwrap FunctionalizeCtxWrapper if present
+    if isinstance(subgraph, FunctionalizeCtxWrapper):
+        subgraph = subgraph.subgraph
+
+    # If not a GraphModule, fall back to execution-based metadata extraction
+    if not isinstance(subgraph, torch.fx.GraphModule):
+        return _get_output_metadata_by_execution(subgraph, *operands)
+
+    output_metadata = OutputMetadata()
+
+    # Extract output arguments from the output node
+    # The output node has args=(output_values,) where output_values is a tuple/list
+    output_node = next(reversed(subgraph.graph.find_nodes(op="output")))
+    output_metadata.num_fw_outs = len(output_node.args[0])
+
+    for idx, output_arg in enumerate(output_node.args[0]):
+        if not isinstance(output_arg, torch.fx.Node):
+            if isinstance(output_arg, int):
+                output_metadata.indexes_with_symint.add(idx)
+            output_metadata.indexes_with_no_grad.add(idx)
+            continue
+
+        # Check node metadata for type information
+        if output_arg.meta.get("val") is None:
+            # If we don't have complete metadata for all outputs, fall back to execution
+            # This is important for correctness (e.g., detecting SymInts) even though it
+            # runs side-effectful operations
+            return _get_output_metadata_by_execution(subgraph, *operands)
+
+        val = output_arg.meta["val"]
+        if isinstance(val, torch.SymInt):
+            output_metadata.indexes_with_symint.add(idx)
+            output_metadata.indexes_with_no_grad.add(idx)
+        elif isinstance(val, torch.Tensor):
+            # Check if tensor requires grad from metadata
+            if hasattr(val, "requires_grad") and not val.requires_grad:
+                output_metadata.indexes_with_no_grad.add(idx)
+        else:
+            # Non-tensor, non-symint (shouldn't happen but be safe)
+            output_metadata.indexes_with_no_grad.add(idx)
+
+    return output_metadata
+
+
+def _get_output_metadata_by_execution(subgraph, *operands):
+    """
+    Fallback: Extract metadata by executing the subgraph.
+    This should only be used when static analysis fails.
+    WARNING: This will run side-effectful operations!
+    """
+
     with suspend_functionalization(), disable_functional_mode():
         with disable_proxy_modes_tracing():
             # args are functional tensors, generate some example tensors
@@ -324,19 +380,15 @@ def get_output_metadata(subgraph, *operands):

             num_fw_outs = len(fw_outs)

-            # Collect the indexes of none in the output to check that the grad
-            # is None at the corresponding index in the backward. This check is
-            # performed in the autograd.Function - InvokeSubgraphAutogradOp.
-            # Also collect the indexes of no_grad in the output to filter out
-            # the grad_outs in the `backward` method.
             output_metadata = OutputMetadata()
-
             output_metadata.num_fw_outs = num_fw_outs
+
             for idx, fw_out in enumerate(fw_outs):
                 if isinstance(fw_out, torch.SymInt):
                     output_metadata.indexes_with_symint.add(idx)
                 elif not fw_out.requires_grad:
                     output_metadata.indexes_with_no_grad.add(idx)
+
             return output_metadata
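For readers unfamiliar with the metadata path, here is a minimal, self-contained sketch of the static-inspection idea used by the new `get_output_metadata` (illustrative code, not the PR's; it assumes a recent PyTorch where `make_fx` populates `node.meta["val"]` and `torch.fx.Graph.find_nodes` is available): the output metadata is read off the FX graph's output node instead of calling the module again.

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx


def fn(x):
    return x * 2, x.shape[0]  # one tensor output, one (symbolic) int output


# Symbolic tracing records example values on each node's meta["val"].
gm = make_fx(fn, tracing_mode="symbolic")(torch.randn(3))

# Read output metadata from the graph itself -- no second execution of fn/gm.
output_node = next(reversed(gm.graph.find_nodes(op="output")))
for idx, arg in enumerate(output_node.args[0]):
    val = arg.meta.get("val") if isinstance(arg, torch.fx.Node) else arg
    if isinstance(val, torch.SymInt):
        print(idx, "SymInt output")        # -> indexes_with_symint in the PR
    elif isinstance(val, torch.Tensor):
        print(idx, "Tensor output", tuple(val.shape))
    else:
        print(idx, "other output:", type(val).__name__)
```

The PR additionally reads `requires_grad` from the same metadata when it is present, and falls back to the execution-based path (`_get_output_metadata_by_execution`) when any output lacks a `val`.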
