|
23 | 23 | from torch.utils._python_dispatch import is_traceable_wrapper_subclass |
24 | 24 |
|
25 | 25 |
|
| 26 | +@contextmanager |
| 27 | +def _skip_nested_compile() -> Generator[None, None, None]: |
| 28 | + """Tell dynamo to skip torch.compile calls encountered during make_fx tracing. |
| 29 | +
|
| 30 | + make_fx cannot trace through torch.compile'd functions (e.g. compiled |
| 31 | + flex_attention in FlexAttentionWrapper). Setting error_on_nested_fx_trace |
| 32 | + to False makes dynamo silently inline the wrapped function instead of |
| 33 | + raising, so make_fx traces the underlying ops normally. |
| 34 | + """ |
| 35 | + prev = torch._dynamo.config.error_on_nested_fx_trace |
| 36 | + torch._dynamo.config.error_on_nested_fx_trace = False |
| 37 | + try: |
| 38 | + yield |
| 39 | + finally: |
| 40 | + torch._dynamo.config.error_on_nested_fx_trace = prev |
| 41 | + |
| 42 | + |
@contextmanager
def use_uncompiled_flex_attention() -> Generator[None, None, None]:
    """Run flex_attention with every torch.compile layer turned off.

    Two compile layers are involved: FlexAttentionWrapper carries a
    pre-compiled flex_attention (outer compile), and raw flex_attention
    itself invokes torch.compile internally (inner compile). While this
    context is active both are disabled, so flex_attention dispatches
    straight through the FlexAttentionHOP with no fusion.

    Wrap both tracing and eager execution in this context to get
    bitwise-identical numerics: the traced graph keeps the
    FlexAttentionHOP, and eager follows the same unfused HOP path.
    """
    from torch.nn.attention import flex_attention as flex_attn_mod
    from torch.nn.attention.flex_attention import flex_attention as raw_flex_attention

    from torchtitan.models.common.attention import FlexAttentionWrapper

    # Snapshot both pieces of global state before patching anything.
    saved_compiled = FlexAttentionWrapper._compiled_flex_attn
    saved_debug_flag = flex_attn_mod._FLEX_ATTENTION_DISABLE_COMPILE_DEBUG
    # Outer layer: swap the wrapper's compiled callable for the raw function.
    FlexAttentionWrapper._compiled_flex_attn = staticmethod(raw_flex_attention)
    # Inner layer: the debug flag suppresses flex_attention's internal compile.
    flex_attn_mod._FLEX_ATTENTION_DISABLE_COMPILE_DEBUG = True
    try:
        yield
    finally:
        flex_attn_mod._FLEX_ATTENTION_DISABLE_COMPILE_DEBUG = saved_debug_flag
        FlexAttentionWrapper._compiled_flex_attn = saved_compiled
| 70 | + |
| 71 | + |
26 | 72 | @dataclass |
27 | 73 | class SubclassMeta: |
28 | 74 | cls: type |
@@ -315,7 +361,7 @@ def fn_with_subclass_handling(*plain_args): |
315 | 361 | return unwrapped_outputs |
316 | 362 |
|
317 | 363 | # preserve_node_meta propagates fx.traceback.annotate metadata to traced nodes |
318 | | - with fake_mode, preserve_node_meta(): |
| 364 | + with fake_mode, preserve_node_meta(), _skip_nested_compile(), use_uncompiled_flex_attention(): |
319 | 365 | traced = make_fx( |
320 | 366 | fn_with_subclass_handling, |
321 | 367 | record_stack_traces=True, |
|
0 commit comments