
Make MoE models non-strict tracing friendly#2612

Merged
yiming0416 merged 8 commits into pytorch:main from ydwu4:pr2-moe-tracing
Mar 19, 2026

Conversation

ydwu4 (Contributor) commented Mar 16, 2026

Summary

  • kernels.py: wrap fill_indices as custom_op + register_fake for FakeTensor compatibility during make_fx tracing
  • test_trace_module.py: add model-specific tests (llama3, qwen3, qwen3_moe, deepseek_v3, llama4, gpt_oss) and FSDP tests

Stacked on #2553

Test plan

  • pytest torchtitan/experiments/graph_trainer/tests/test_trace_module.py -k TestTraceModels — all 6 model tests pass
  • pre-commit run --all-files passes

@meta-cla meta-cla bot added the CLA Signed label Mar 16, 2026
Comment on the following hunk:

```python
    x.bfloat16(), mlp1_weight.transpose(-2, -1).bfloat16(), offs=offsets
)

b1 = torch.cat([mlp1_bias, mlp1_bias.new_zeros(1, mlp1_bias.shape[-1])])
```
Contributor:

This is a PyTorch native op; I'm surprised that we still need any special treatment, even with the padding kernel removed?

Contributor Author (ydwu4):

Yeah, you're right, this is not needed. Probably some leftover from exploration.

@ydwu4 ydwu4 force-pushed the pr2-moe-tracing branch 3 times, most recently from 411b7f8 to a1b1229 Compare March 17, 2026 01:41
@ydwu4 ydwu4 force-pushed the pr2-moe-tracing branch 9 times, most recently from 6adec7b to b516c04 Compare March 18, 2026 00:18
@ydwu4 ydwu4 requested review from tianyu-l and yiming0416 March 18, 2026 00:20
Comment on lines +360 to +365 (quoted comment from the diff):

```
FlexAttentionWrapper uses torch.compile'd flex_attention by default.
torch.compile inside make_fx tracing is not supported and raises:
"Detected that you are using FX to symbolically trace a
dynamo-optimized function."
Using the raw version lets make_fx decompose flex_attention into
plain aten ops (bmm, softmax, etc.) which trace correctly.
```
Contributor:

IIUC, in eager, compiled flex_attention shows up as a HOP?

Should we instead fix make_fx to trace over the HOP?

Also, in Ed's 2-tier graph setup, we are already non-strict tracing per layer, which is torch.compiled.

ydwu4 (Contributor Author) commented Mar 18, 2026:

Ok, the real problem is that FlexAttentionWrapper uses a different torch.compile wrapper than the original flex_attention.

The graph we trace contains the flex_attention higher-order op, while eager mode runs the torch.compiled version, which produces numerical differences relative to running the HOP. So we run eager under use_uncompiled_flex_attention to ensure both the graph's and eager's flex_attention are the un-compiled version; this way we get bitwise equivalence.

ydwu4 (Contributor Author) commented Mar 18, 2026:

Summarizing the current status: SDPA is used in llama3 and qwen3, where it gets decomposed with pre_dispatch=False. For deepseek_v3, llama4, and gpt_oss, flex_attention is used and the flex_attention HOP is preserved in the graph. Bitwise equivalence is achieved by running eager under use_uncompiled_flex_attention.

SherlockNoMad (Contributor) left a comment:

make_fx should preserve the torch.compile'd inner region.

If you can't get it working in this PR, let's do it as a follow up.

tianyu-l (Contributor) left a comment:

stamp to unblock

tianyu-l (Contributor) commented:

cc @danielvegamyhre you might need this change in torchao in the future, if we want to make permutation work with non-strict tracing.

@ydwu4 ydwu4 force-pushed the pr2-moe-tracing branch 2 times, most recently from 6b8562d to d76b181 Compare March 18, 2026 05:22
@ydwu4 ydwu4 force-pushed the pr2-moe-tracing branch 4 times, most recently from 5cc498f to 63153f5 Compare March 18, 2026 06:17
ydwu4 added 2 commits March 18, 2026 11:44
- Add use_uncompiled_flex_attention context manager to disable both
  outer (FlexAttentionWrapper) and inner (flex_attention) torch.compile
  so eager and traced graph use the same unfused FlexAttentionHOP path.
- Wrap eager reference forwards with use_uncompiled_flex_attention for
  bitwise-identical numerics between eager and traced graph replay.
- Add test_flex_attention_annotations to verify compile_with_inductor
  and ac_region_id annotations survive on FlexAttentionHOP nodes.
…ed graph

Both ref and test models now run through the same traced graph via
run_traced_module, ensuring bitwise-identical results regardless of
attention backend. This eliminates the need for use_uncompiled_flex_attention
since both sides use the same FlexAttentionHOP dispatch path.
@pytorch-bot pytorch-bot bot removed the ciflow/8gpu label Mar 18, 2026
Comment on lines +646 to +647:

```python
data_parallel(model_ref, device_mesh=fsdp_mesh, mode="fully_shard")
data_parallel(model_test, device_mesh=fsdp_mesh, mode="fully_shard")
```
Contributor:

should be

```python
model_ref = data_parallel(model_ref, ...)
model_test = data_parallel(model_test, ...)
```

ydwu4 added 3 commits March 18, 2026 14:45
Replace use_uncompiled_flex_attention workaround with regional_inductor
compilation for flex attention models. Both eager and traced paths now
use the same compiled triton kernels, producing bitwise-identical results.

Key changes:
- _copy_fwd_metadata_to_bw_nodes: propagate annotations to get_attr nodes
  (not just call_function) so regional_inductor's partitioner includes HOP
  subgraph modules in the same partition
- FakeTensorMode in trace_module: add ShapeEnv with static_shapes=True so
  standalone_compile inside regional_inductor works correctly
- FlexAttentionWrapper.inductor_configs: extract inductor options as a class
  variable so annotation and torch.compile use the same configs
- annotate_flex_attention_for_regional_inductor: context manager that
  annotates FlexAttentionWrapper.forward for regional_inductor compilation
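The get_attr propagation in _copy_fwd_metadata_to_bw_nodes can be illustrated on a toy FX graph. This is a sketch, not the PR's code: the "custom"/"compile_with_inductor" annotation keys come from the PR description, while the toy module and the seeding loop are invented for the example.

```python
import torch
from torch import fx

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.randn(4, 4))

    def forward(self, x):
        return x @ self.w

gm = fx.symbolic_trace(Toy())

# Seed an annotation on every call_function node (stand-in for the
# regional-inductor annotation pass).
for node in gm.graph.nodes:
    if node.op == "call_function":
        node.meta["custom"] = {"compile_with_inductor": True}

# Propagate the annotation to get_attr inputs so a partitioner keeps
# the parameter access in the same partition as its annotated consumer.
for node in gm.graph.nodes:
    if node.op == "call_function" and "custom" in node.meta:
        for inp in node.all_input_nodes:
            if inp.op == "get_attr":
                inp.meta["custom"] = node.meta["custom"]
```

If only call_function nodes carried the annotation, the get_attr node feeding the matmul would land outside the partition and the HOP subgraph module would be split from its weights.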
@pytorch-bot pytorch-bot bot removed the ciflow/8gpu label Mar 19, 2026
@yiming0416 yiming0416 merged commit c9a68de into pytorch:main Mar 19, 2026
12 of 13 checks passed

Labels

CLA Signed (managed by the Meta Open Source bot)

4 participants