Make MoE models non-strict tracing friendly #2612
Conversation
| x.bfloat16(), mlp1_weight.transpose(-2, -1).bfloat16(), offs=offsets |
| ) |
| b1 = torch.cat([mlp1_bias, mlp1_bias.new_zeros(1, mlp1_bias.shape[-1])]) |
This is a PyTorch native op; I'm surprised that we still need any special treatment, even with the padding kernel removed?
Yeah, you're right, this is not needed. Probably some leftover from the exploration phase.
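For context, a toy illustration of what the quoted (now removed) padding line did: it appended one all-zero row to the per-expert bias so an out-of-range group index could safely gather a zero bias. Shapes here are made up for illustration.

```python
import torch

# Hypothetical shapes: per-expert bias of shape (num_experts, dim).
mlp1_bias = torch.randn(4, 8)

# The quoted line: append a single zero row along the expert dimension.
b1 = torch.cat([mlp1_bias, mlp1_bias.new_zeros(1, mlp1_bias.shape[-1])])

assert b1.shape == (5, 8)          # one extra "padding expert" row
assert torch.all(b1[-1] == 0)      # the padding row is all zeros
```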
Force-pushed from 411b7f8 to a1b1229
Force-pushed from 6adec7b to b516c04
| FlexAttentionWrapper uses torch.compile'd flex_attention by default. |
| torch.compile inside make_fx tracing is not supported and raises: |
| "Detected that you are using FX to symbolically trace a |
| dynamo-optimized function." |
| Using the raw version lets make_fx decompose flex_attention into |
| plain aten ops (bmm, softmax, etc.) which trace correctly. |
IIUC, in eager, compiled flex_attn shows up as a HOP? We should fix make_fx to trace over the HOP? Also, in Ed's 2-tier graph setup, we are already non-strict tracing per layer, which is torch.compiled.
Ok, the real problem is that FlexAttentionWrapper uses a different torch.compile wrapper than the original flex_attention.
The graph we traced contains the flex_attention higher-order op, while eager mode runs the torch.compiled version, which produces numerical differences from running the HOP. So we run eager under use_uncompiled_flex_attention, ensuring both the graph's and eager's flex_attention are the un-compiled version; this way we get bitwise equivalence.
Summarizing the current status: SDPA is used in llama3 and qwen3; it gets decomposed with pre_dispatch=False. For deepseek_v3, llama4, and gpt_oss, flex_attention is used and the flex_attention HOP is preserved in the graph. Bitwise equivalence is achieved by running eager under use_uncompiled_flex_attention.
SherlockNoMad
left a comment
make_fx should preserve the torch.compile'd inner region.
If you can't get it working in this PR, let's do it as a follow-up.
cc @danielvegamyhre you might need this change in torchao in the future, if we want to make permutation work with non-strict tracing.
Force-pushed from 6b8562d to d76b181
Force-pushed from 5cc498f to 63153f5
- Add use_uncompiled_flex_attention context manager to disable both outer (FlexAttentionWrapper) and inner (flex_attention) torch.compile so eager and traced graph use the same unfused FlexAttentionHOP path.
- Wrap eager reference forwards with use_uncompiled_flex_attention for bitwise-identical numerics between eager and traced graph replay.
- Add test_flex_attention_annotations to verify compile_with_inductor and ac_region_id annotations survive on FlexAttentionHOP nodes.
…ed graph
Both ref and test models now run through the same traced graph via run_traced_module, ensuring bitwise-identical results regardless of attention backend. This eliminates the need for use_uncompiled_flex_attention since both sides use the same FlexAttentionHOP dispatch path.
| data_parallel(model_ref, device_mesh=fsdp_mesh, mode="fully_shard") |
| data_parallel(model_test, device_mesh=fsdp_mesh, mode="fully_shard") |
should be
model_ref = data_parallel(model_ref, ...)
model_test = data_parallel(model_test, ...)
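The point of the comment above: when a parallelize API returns a new wrapped module rather than mutating its argument, discarding the return value silently keeps the original, un-parallelized model. A minimal sketch with a hypothetical stand-in wrapper (not torchtitan's data_parallel):

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for an API that returns a NEW wrapped module
# instead of mutating its argument in place.
def wrap_for_parallel(model: nn.Module) -> nn.Module:
    class Wrapped(nn.Module):
        def __init__(self, inner: nn.Module):
            super().__init__()
            self.inner = inner

        def forward(self, x):
            return self.inner(x)

    return Wrapped(model)

model = nn.Linear(4, 4)
wrap_for_parallel(model)           # bug: returned wrapper is discarded
model = wrap_for_parallel(model)   # correct: keep the wrapped module
```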
…ugh traced graph"
This reverts commit 6870183.
Replace the use_uncompiled_flex_attention workaround with regional_inductor compilation for flex attention models. Both eager and traced paths now use the same compiled triton kernels, producing bitwise-identical results.

Key changes:
- _copy_fwd_metadata_to_bw_nodes: propagate annotations to get_attr nodes (not just call_function) so regional_inductor's partitioner includes HOP subgraph modules in the same partition
- FakeTensorMode in trace_module: add ShapeEnv with static_shapes=True so standalone_compile inside regional_inductor works correctly
- FlexAttentionWrapper.inductor_configs: extract inductor options as a class variable so annotation and torch.compile use the same configs
- annotate_flex_attention_for_regional_inductor: context manager that annotates FlexAttentionWrapper.forward for regional_inductor compilation
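The annotation-propagation idea from the first bullet can be sketched generically with torch.fx. This is not the PR's actual _copy_fwd_metadata_to_bw_nodes; it is a hedged illustration of copying a partitioning annotation from annotated call_function nodes onto the get_attr nodes that feed them, so a partitioner grouping by that key pulls the referenced attributes/submodules into the same partition. The key name "ac_region_id" is taken from the test description above.

```python
import torch
import torch.fx as fx

def propagate_annotation_to_get_attr(gm: fx.GraphModule,
                                     key: str = "ac_region_id") -> None:
    """Copy `key` from annotated call_function nodes to the get_attr
    nodes they consume, without overwriting existing annotations."""
    for node in gm.graph.nodes:
        if node.op == "call_function" and key in node.meta:
            for inp in node.all_input_nodes:
                if inp.op == "get_attr":
                    inp.meta.setdefault(key, node.meta[key])
```

Usage on a tiny traced module: annotate the compute nodes, run the pass, and the parameter-access (get_attr) nodes inherit the same region id.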
Summary
- kernels.py: wrap `fill_indices` as `custom_op` + `register_fake` for FakeTensor compatibility during `make_fx` tracing
- test_trace_module.py: add model-specific tests (llama3, qwen3, qwen3_moe, deepseek_v3, llama4, gpt_oss) and FSDP tests

Stacked on #2553

Test plan
- `pytest torchtitan/experiments/graph_trainer/tests/test_trace_module.py -k TestTraceModels`: all 6 model tests pass
- `pre-commit run --all-files` passes