Implement torch compile and mxfp8 for flux#2579
Implement torch compile and mxfp8 for flux#2579hlahkar wants to merge 5 commits intopytorch:mainfrom
Conversation
wwwjn
left a comment
There was a problem hiding this comment.
Thanks, can you also show the end to end MFU (MFU is implemented here) when compile is enabled?
When MXFP8 is enabled, can you show it's runnable as a sanity check?
For tests, an integration test with --compile enabled would be better than unit test. Can you help add a integration test instead of a unit test? Unit test is more for testing components.
| training: TrainingConfig, | ||
| compile_config: CompileConfig, | ||
| ): | ||
| if compile_config.enable and "model" in compile_config.components: |
There was a problem hiding this comment.
Should we do compile.compononts = ["model", "encoder"], so encoder can be optional. And encoders are frozen during training, without fullgraph=True, these encoders can be separated from model?
I was wondering for smaller size model, will user always enable compile for encoder?
There was a problem hiding this comment.
Good point. I think if the benefit of compile encoders is small, then we can choose to not compile them at all.
I wouldn't recommend adding "encoder" to the components list before seeing real needs. (I mean, if we add encoder, we should at least rename model to decoder.)
This PR implements torch compile and mxfp8 dtype computation for Flux model. Note: frozen encoder are not included in for mxfp8 quantization.
Test Report fo the test cases implementd for torch compile for Flux:
Running 15 items in this shard: tests/unit_tests/test_compile_flux.py::TestApplyCompileFlux::test_apply_compile_multiple_calls, tests/unit_tests/test_compile_flux.py::TestApplyCompileFlux::test_apply_compile_uses_specified_backend, tests/unit_tests/test_compile_flux.py::TestApplyCompileFlux::test_apply_compile_wraps_all_blocks, tests/unit_tests/test_compile_flux.py::TestApplyCompileFlux::test_forward_after_compile, tests/unit_tests/test_compile_flux.py::TestApplyCompileToEncoders::test_encoder_forward_after_compile, tests/unit_tests/test_compile_flux.py::TestApplyCompileToEncoders::test_no_fullgraph_for_encoders, tests/unit_tests/test_compile_flux.py::TestApplyCompileToEncoders::test_wraps_clip_layers, tests/unit_tests/test_compile_flux.py::TestApplyCompileToEncoders::test_wraps_t5_blocks, tests/unit_tests/test_compile_flux.py::TestCompileConfigGating::test_compile_disabled, tests/unit_tests/test_compile_flux.py::TestCompileConfigGating::test_compile_enabled_with_model_component, tests/unit_tests/test_compile_flux.py::TestCompileConfigGating::test_compile_enabled_without_model_component, tests/unit_tests/test_compile_flux.py::TestCompileConfigGating::test_encoder_compile_disabled, tests/unit_tests/test_compile_flux.py::TestMSELossCompile::test_mse_loss_compiled, tests/unit_tests/test_compile_flux.py::TestMSELossCompile::test_mse_loss_not_compiled_when_disabled, tests/unit_tests/test_compile_flux.py::TestMSELossCompile::test_mse_loss_not_compiled_without_loss_component
tests/unit_tests/test_compile_flux.py::TestApplyCompileFlux::test_apply_compile_multiple_calls PASSED [ 6%]
tests/unit_tests/test_compile_flux.py::TestApplyCompileFlux::test_apply_compile_uses_specified_backend PASSED [ 13%]
tests/unit_tests/test_compile_flux.py::TestApplyCompileFlux::test_apply_compile_wraps_all_blocks PASSED [ 20%]
tests/unit_tests/test_compile_flux.py::TestApplyCompileFlux::test_forward_after_compile PASSED [ 26%]
tests/unit_tests/test_compile_flux.py::TestApplyCompileToEncoders::test_encoder_forward_after_compile PASSED [ 33%]
tests/unit_tests/test_compile_flux.py::TestApplyCompileToEncoders::test_no_fullgraph_for_encoders PASSED [ 40%]
tests/unit_tests/test_compile_flux.py::TestApplyCompileToEncoders::test_wraps_clip_layers PASSED [ 46%]
tests/unit_tests/test_compile_flux.py::TestApplyCompileToEncoders::test_wraps_t5_blocks PASSED [ 53%]
tests/unit_tests/test_compile_flux.py::TestCompileConfigGating::test_compile_disabled PASSED [ 60%]
tests/unit_tests/test_compile_flux.py::TestCompileConfigGating::test_compile_enabled_with_model_component PASSED [ 66%]
tests/unit_tests/test_compile_flux.py::TestCompileConfigGating::test_compile_enabled_without_model_component PASSED [ 73%]
tests/unit_tests/test_compile_flux.py::TestCompileConfigGating::test_encoder_compile_disabled PASSED [ 80%]
tests/unit_tests/test_compile_flux.py::TestMSELossCompile::test_mse_loss_compiled PASSED [ 86%]
tests/unit_tests/test_compile_flux.py::TestMSELossCompile::test_mse_loss_not_compiled_when_disabled PASSED [ 93%]
tests/unit_tests/test_compile_flux.py::TestMSELossCompile::test_mse_loss_not_compiled_without_loss_component PASSED [100%]
========================================================================================================================== warnings summary ==========================================================================================================================
../../usr/local/lib/python3.12/dist-packages/torch/jit/_script.py:365: 14 warnings
/usr/local/lib/python3.12/dist-packages/torch/jit/_script.py:365: DeprecationWarning:
torch.jit.script_methodis deprecated. Please switch totorch.compileortorch.export.warnings.warn(
:488
:488: DeprecationWarning: builtin type SwigPyPacked has no module attribute
:488
:488: DeprecationWarning: builtin type SwigPyObject has no module attribute
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
================================================================================================================== 15 passed, 16 warnings in 9.74s ===================================================================================================================