Skip to content

Commit 0a06f34

Browse files
committed
Merge remote-tracking branch 'origin/main' into moe_rewrite
2 parents 0e86cc9 + 34f705b commit 0a06f34

File tree

5 files changed

+593
-89
lines changed

5 files changed

+593
-89
lines changed

torchtitan/experiments/graph_trainer/make_fx_tracer.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,23 @@
2323
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
2424

2525

26+
@contextmanager
27+
def _skip_nested_compile() -> Generator[None, None, None]:
28+
"""Tell dynamo to skip torch.compile calls encountered during make_fx tracing.
29+
30+
make_fx cannot trace through torch.compile'd functions (e.g. compiled
31+
flex_attention in FlexAttentionWrapper). Setting error_on_nested_fx_trace
32+
to False makes dynamo silently inline the wrapped function instead of
33+
raising, so make_fx traces the underlying ops normally.
34+
"""
35+
prev = torch._dynamo.config.error_on_nested_fx_trace
36+
torch._dynamo.config.error_on_nested_fx_trace = False
37+
try:
38+
yield
39+
finally:
40+
torch._dynamo.config.error_on_nested_fx_trace = prev
41+
42+
2643
@dataclass
2744
class SubclassMeta:
2845
cls: type
@@ -219,7 +236,7 @@ def _copy_fwd_metadata_to_bw_nodes(fx_g: torch.fx.GraphModule) -> None:
219236
seq_nr_to_fwd_node: dict[int, torch.fx.Node] = {}
220237

221238
for node in fx_g.graph.nodes:
222-
if node.op != "call_function" or "seq_nr" not in node.meta:
239+
if node.op not in ("call_function", "get_attr") or "seq_nr" not in node.meta:
223240
continue
224241
seq_nr = node.meta["seq_nr"]
225242
if seq_nr not in seq_nr_to_fwd_node:
@@ -275,11 +292,14 @@ def functional_call(*all_args):
275292
unwrapped_args.append(arg)
276293
input_layouts.append(SubclassLayout(1, None))
277294

278-
fake_mode = FakeTensorMode(allow_non_fake_inputs=True)
295+
fake_mode = FakeTensorMode(
296+
allow_non_fake_inputs=True,
297+
shape_env=torch.fx.experimental.symbolic_shapes.ShapeEnv(),
298+
)
279299

280300
def to_fake(t):
281301
if isinstance(t, torch.Tensor):
282-
return fake_mode.from_tensor(t)
302+
return fake_mode.from_tensor(t, static_shapes=True)
283303
return t
284304

285305
fake_args = tuple(to_fake(a) for a in unwrapped_args)
@@ -315,7 +335,7 @@ def fn_with_subclass_handling(*plain_args):
315335
return unwrapped_outputs
316336

317337
# preserve_node_meta propagates fx.traceback.annotate metadata to traced nodes
318-
with fake_mode, preserve_node_meta():
338+
with fake_mode, preserve_node_meta(), _skip_nested_compile():
319339
traced = make_fx(
320340
fn_with_subclass_handling,
321341
record_stack_traces=True,

0 commit comments

Comments
 (0)