Skip to content

Commit 5cc498f

Browse files
committed
Make MoE models non-strict tracing friendly
1 parent 0187d5f commit 5cc498f

File tree

3 files changed

+440
-4
lines changed

3 files changed

+440
-4
lines changed

torchtitan/experiments/graph_trainer/make_fx_tracer.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,52 @@
2323
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
2424

2525

26+
@contextmanager
27+
def _skip_nested_compile() -> Generator[None, None, None]:
28+
"""Tell dynamo to skip torch.compile calls encountered during make_fx tracing.
29+
30+
make_fx cannot trace through torch.compile'd functions (e.g. compiled
31+
flex_attention in FlexAttentionWrapper). Setting error_on_nested_fx_trace
32+
to False makes dynamo silently inline the wrapped function instead of
33+
raising, so make_fx traces the underlying ops normally.
34+
"""
35+
prev = torch._dynamo.config.error_on_nested_fx_trace
36+
torch._dynamo.config.error_on_nested_fx_trace = False
37+
try:
38+
yield
39+
finally:
40+
torch._dynamo.config.error_on_nested_fx_trace = prev
41+
42+
43+
@contextmanager
def use_uncompiled_flex_attention() -> Generator[None, None, None]:
    """Run flex_attention without any torch.compile wrapping.

    Two layers of compilation are switched off here: the pre-compiled
    flex_attention held by FlexAttentionWrapper (outer compile), and the
    torch.compile call that raw flex_attention performs internally (inner
    compile, suppressed via ``_FLEX_ATTENTION_DISABLE_COMPILE_DEBUG``).
    With both disabled, flex_attention dispatches straight through the
    FlexAttentionHOP with no fusion.

    Wrap both tracing and eager execution in this context to get bitwise-
    identical numerics: the traced graph keeps the FlexAttentionHOP, and
    eager goes down the same unfused HOP path. Both patched values are
    restored on exit, even on error.
    """
    from torch.nn.attention import flex_attention as flex_attn_mod
    from torch.nn.attention.flex_attention import flex_attention as raw_flex_attention

    from torchtitan.models.common.attention import FlexAttentionWrapper

    saved_attn = FlexAttentionWrapper._compiled_flex_attn
    saved_debug_flag = flex_attn_mod._FLEX_ATTENTION_DISABLE_COMPILE_DEBUG
    # NOTE(review): the saved value is read through the class (so a
    # staticmethod wrapper would already be unwrapped) but restored as a
    # plain attribute; if callers access _compiled_flex_attn via an
    # instance, binding semantics could differ after restore — confirm
    # against FlexAttentionWrapper's definition.
    FlexAttentionWrapper._compiled_flex_attn = staticmethod(raw_flex_attention)
    flex_attn_mod._FLEX_ATTENTION_DISABLE_COMPILE_DEBUG = True
    try:
        yield
    finally:
        FlexAttentionWrapper._compiled_flex_attn = saved_attn
        flex_attn_mod._FLEX_ATTENTION_DISABLE_COMPILE_DEBUG = saved_debug_flag
70+
71+
2672
@dataclass
2773
class SubclassMeta:
2874
cls: type
@@ -315,7 +361,7 @@ def fn_with_subclass_handling(*plain_args):
315361
return unwrapped_outputs
316362

317363
# preserve_node_meta propagates fx.traceback.annotate metadata to traced nodes
318-
with fake_mode, preserve_node_meta():
364+
with fake_mode, preserve_node_meta(), _skip_nested_compile(), use_uncompiled_flex_attention():
319365
traced = make_fx(
320366
fn_with_subclass_handling,
321367
record_stack_traces=True,

0 commit comments

Comments
 (0)