
Implement torch compile and mxfp8 for flux #2579

Open
hlahkar wants to merge 5 commits into pytorch:main from hlahkar:flux_compile

@hlahkar hlahkar commented Mar 15, 2026

This PR implements torch.compile support and MXFP8 dtype computation for the Flux model. Note: the frozen encoders are not included in MXFP8 quantization.

Test report for the test cases implemented for torch.compile on Flux:

Running 15 items in this shard: tests/unit_tests/test_compile_flux.py::TestApplyCompileFlux::test_apply_compile_multiple_calls, tests/unit_tests/test_compile_flux.py::TestApplyCompileFlux::test_apply_compile_uses_specified_backend, tests/unit_tests/test_compile_flux.py::TestApplyCompileFlux::test_apply_compile_wraps_all_blocks, tests/unit_tests/test_compile_flux.py::TestApplyCompileFlux::test_forward_after_compile, tests/unit_tests/test_compile_flux.py::TestApplyCompileToEncoders::test_encoder_forward_after_compile, tests/unit_tests/test_compile_flux.py::TestApplyCompileToEncoders::test_no_fullgraph_for_encoders, tests/unit_tests/test_compile_flux.py::TestApplyCompileToEncoders::test_wraps_clip_layers, tests/unit_tests/test_compile_flux.py::TestApplyCompileToEncoders::test_wraps_t5_blocks, tests/unit_tests/test_compile_flux.py::TestCompileConfigGating::test_compile_disabled, tests/unit_tests/test_compile_flux.py::TestCompileConfigGating::test_compile_enabled_with_model_component, tests/unit_tests/test_compile_flux.py::TestCompileConfigGating::test_compile_enabled_without_model_component, tests/unit_tests/test_compile_flux.py::TestCompileConfigGating::test_encoder_compile_disabled, tests/unit_tests/test_compile_flux.py::TestMSELossCompile::test_mse_loss_compiled, tests/unit_tests/test_compile_flux.py::TestMSELossCompile::test_mse_loss_not_compiled_when_disabled, tests/unit_tests/test_compile_flux.py::TestMSELossCompile::test_mse_loss_not_compiled_without_loss_component

tests/unit_tests/test_compile_flux.py::TestApplyCompileFlux::test_apply_compile_multiple_calls PASSED [ 6%]
tests/unit_tests/test_compile_flux.py::TestApplyCompileFlux::test_apply_compile_uses_specified_backend PASSED [ 13%]
tests/unit_tests/test_compile_flux.py::TestApplyCompileFlux::test_apply_compile_wraps_all_blocks PASSED [ 20%]
tests/unit_tests/test_compile_flux.py::TestApplyCompileFlux::test_forward_after_compile PASSED [ 26%]
tests/unit_tests/test_compile_flux.py::TestApplyCompileToEncoders::test_encoder_forward_after_compile PASSED [ 33%]
tests/unit_tests/test_compile_flux.py::TestApplyCompileToEncoders::test_no_fullgraph_for_encoders PASSED [ 40%]
tests/unit_tests/test_compile_flux.py::TestApplyCompileToEncoders::test_wraps_clip_layers PASSED [ 46%]
tests/unit_tests/test_compile_flux.py::TestApplyCompileToEncoders::test_wraps_t5_blocks PASSED [ 53%]
tests/unit_tests/test_compile_flux.py::TestCompileConfigGating::test_compile_disabled PASSED [ 60%]
tests/unit_tests/test_compile_flux.py::TestCompileConfigGating::test_compile_enabled_with_model_component PASSED [ 66%]
tests/unit_tests/test_compile_flux.py::TestCompileConfigGating::test_compile_enabled_without_model_component PASSED [ 73%]
tests/unit_tests/test_compile_flux.py::TestCompileConfigGating::test_encoder_compile_disabled PASSED [ 80%]
tests/unit_tests/test_compile_flux.py::TestMSELossCompile::test_mse_loss_compiled PASSED [ 86%]
tests/unit_tests/test_compile_flux.py::TestMSELossCompile::test_mse_loss_not_compiled_when_disabled PASSED [ 93%]
tests/unit_tests/test_compile_flux.py::TestMSELossCompile::test_mse_loss_not_compiled_without_loss_component PASSED [100%]

========================================================================================================================== warnings summary ==========================================================================================================================
../../usr/local/lib/python3.12/dist-packages/torch/jit/_script.py:365: 14 warnings
/usr/local/lib/python3.12/dist-packages/torch/jit/_script.py:365: DeprecationWarning: torch.jit.script_method is deprecated. Please switch to torch.compile or torch.export.
warnings.warn(

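<frozen importlib._bootstrap>:488
<frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyPacked has no __module__ attribute

<frozen importlib._bootstrap>:488
<frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyObject has no __module__ attribute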

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
================================================================================================================== 15 passed, 16 warnings in 9.74s ===================================================================================================================
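The test names in the report (wrapping all blocks, using a specified backend, tolerating multiple calls) suggest that compile is applied block by block. The following is a minimal sketch of that pattern under stated assumptions, not the code in this PR: `ToyBlock`, `ToyModel`, and this `apply_compile` signature are hypothetical stand-ins, and the `_orig_mod` check relies on the wrapper attribute that `torch.compile` sets on the module it returns.

```python
import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    """Hypothetical stand-in for one transformer block (not the Flux block)."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + torch.relu(self.linear(x))


class ToyModel(nn.Module):
    """Hypothetical model: a plain stack of blocks."""

    def __init__(self, dim: int = 8, depth: int = 3) -> None:
        super().__init__()
        self.blocks = nn.ModuleList(ToyBlock(dim) for _ in range(depth))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            x = block(x)
        return x


def apply_compile(model: ToyModel, backend: str = "inductor", fullgraph: bool = True) -> None:
    """Wrap each block with torch.compile, skipping blocks that are already wrapped."""
    for name, block in list(model.blocks.named_children()):
        if hasattr(block, "_orig_mod"):
            # Already a compiled wrapper; a second call should not wrap it again.
            continue
        compiled = torch.compile(block, backend=backend, fullgraph=fullgraph)
        model.blocks.register_module(name, compiled)


model = ToyModel()
# The "eager" backend runs the captured graph without code generation,
# which keeps this sketch runnable on a CPU-only machine.
apply_compile(model, backend="eager")
out = model(torch.randn(2, 8))
print(out.shape)
```

Compiling per block rather than the whole model keeps each compiled region small and lets repeated blocks share the same traced structure; whether the PR follows exactly this layout is not confirmed by the text above.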

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Mar 15, 2026
@wwwjn wwwjn self-assigned this Mar 16, 2026
Contributor

@wwwjn wwwjn left a comment

Thanks. Can you also show the end-to-end MFU (MFU is implemented here) when compile is enabled?

When MXFP8 is enabled, can you show it's runnable as a sanity check?

For tests, an integration test with --compile enabled would be better than a unit test. Can you help add an integration test instead of a unit test? Unit tests are more for testing individual components.
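For readers unfamiliar with the metric requested above: MFU (model FLOPs utilization) is the ratio of the FLOPs per second a training run actually achieves to the hardware's peak FLOPs per second. The sketch below shows only that arithmetic. The function name and the sample numbers are hypothetical and are not taken from this repository or from any measurement in this PR.

```python
def model_flops_utilization(
    flops_per_step: float,
    step_time_s: float,
    peak_flops_per_s: float,
) -> float:
    """Return MFU as a percentage of the hardware peak."""
    if step_time_s <= 0 or peak_flops_per_s <= 0:
        raise ValueError("step_time_s and peak_flops_per_s must be positive")
    achieved_flops_per_s = flops_per_step / step_time_s
    return 100.0 * achieved_flops_per_s / peak_flops_per_s


# Hypothetical numbers for illustration only:
# 2e14 FLOPs per step, 0.5 s per step, 1e15 FLOPs/s peak.
mfu = model_flops_utilization(2.0e14, 0.5, 1.0e15)
print(f"MFU: {mfu:.1f}%")
```

Comparing this figure with compile disabled and enabled, on the same hardware and batch size, is one way to answer the request.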

    training: TrainingConfig,
    compile_config: CompileConfig,
):
    if compile_config.enable and "model" in compile_config.components:
Contributor

Should we do compile.components = ["model", "encoder"], so that compiling the encoder is optional? The encoders are frozen during training and are compiled without fullgraph=True, so can they be separated from the model?

I was wondering whether, for smaller models, users will always enable compile for the encoder.

Contributor

Good point. I think if the benefit of compiling the encoders is small, we can choose not to compile them at all.

I wouldn't recommend adding "encoder" to the components list before seeing real needs. (I mean, if we add encoder, we should at least rename model to decoder.)
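For reference, the gating discussed in this thread can be sketched as follows. This is a minimal sketch, not the repository's code: `CompileConfig` here is a hypothetical stand-in with only the two fields the diff reads, the default component list is an assumption inferred from the test names, `should_compile` is an illustrative helper, and the "encoder" entry is the proposal under discussion rather than an existing option.

```python
from dataclasses import dataclass, field


@dataclass
class CompileConfig:
    """Hypothetical stand-in for the real config class."""

    enable: bool = False
    # Assumed default; the test names mention "model" and "loss" components.
    components: list[str] = field(default_factory=lambda: ["model", "loss"])


def should_compile(cfg: CompileConfig, component: str) -> bool:
    """Compile a component only if compile is on and the component is listed."""
    return cfg.enable and component in cfg.components


# The proposed opt-in: the encoder is compiled only when the user lists it.
cfg = CompileConfig(enable=True, components=["model", "encoder"])
for name in ("model", "encoder", "loss"):
    print(name, should_compile(cfg, name))
```

With this shape, leaving "encoder" out of the list skips encoder compilation without touching the model path, which is the optional behavior the first comment asks about.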



Projects

Status: Todo


3 participants