Skip to content
28 changes: 24 additions & 4 deletions torchtitan/experiments/graph_trainer/make_fx_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,23 @@
from torch.utils._python_dispatch import is_traceable_wrapper_subclass


@contextmanager
def _skip_nested_compile() -> Generator[None, None, None]:
    """Temporarily turn off dynamo's nested-FX-trace error for make_fx tracing.

    Functions wrapped in torch.compile (for example the compiled
    flex_attention inside FlexAttentionWrapper) cannot be traced by make_fx.
    While this context is active, ``error_on_nested_fx_trace`` is False, so
    dynamo inlines the wrapped function rather than raising and make_fx
    records the underlying ops as usual.

    The flag's previous value is written back on exit, including when the
    body of the ``with`` statement raises.
    """
    dynamo_config = torch._dynamo.config
    # Remember the caller's setting so nesting and early exits stay transparent.
    saved_flag = dynamo_config.error_on_nested_fx_trace
    dynamo_config.error_on_nested_fx_trace = False
    try:
        yield
    finally:
        dynamo_config.error_on_nested_fx_trace = saved_flag


@dataclass
class SubclassMeta:
cls: type
Expand Down Expand Up @@ -219,7 +236,7 @@ def _copy_fwd_metadata_to_bw_nodes(fx_g: torch.fx.GraphModule) -> None:
seq_nr_to_fwd_node: dict[int, torch.fx.Node] = {}

for node in fx_g.graph.nodes:
if node.op != "call_function" or "seq_nr" not in node.meta:
if node.op not in ("call_function", "get_attr") or "seq_nr" not in node.meta:
continue
seq_nr = node.meta["seq_nr"]
if seq_nr not in seq_nr_to_fwd_node:
Expand Down Expand Up @@ -275,11 +292,14 @@ def functional_call(*all_args):
unwrapped_args.append(arg)
input_layouts.append(SubclassLayout(1, None))

fake_mode = FakeTensorMode(allow_non_fake_inputs=True)
fake_mode = FakeTensorMode(
allow_non_fake_inputs=True,
shape_env=torch.fx.experimental.symbolic_shapes.ShapeEnv(),
)

def to_fake(t):
if isinstance(t, torch.Tensor):
return fake_mode.from_tensor(t)
return fake_mode.from_tensor(t, static_shapes=True)
return t

fake_args = tuple(to_fake(a) for a in unwrapped_args)
Expand Down Expand Up @@ -315,7 +335,7 @@ def fn_with_subclass_handling(*plain_args):
return unwrapped_outputs

# preserve_node_meta propagates fx.traceback.annotate metadata to traced nodes
with fake_mode, preserve_node_meta():
with fake_mode, preserve_node_meta(), _skip_nested_compile():
traced = make_fx(
fn_with_subclass_handling,
record_stack_traces=True,
Expand Down
Loading
Loading