
[RL] Changes to enable compilation for trainer#2568

Open
Lucaskabela wants to merge 2 commits into pytorch:main from Lucaskabela:lucaskabela/enable_trainer_compile_03_10

Conversation


@Lucaskabela Lucaskabela commented Mar 13, 2026

Summary

In this PR, we enable naive, JIT-style torch.compile for the RL policy trainer. This is the first step towards speeding up the trainer model. Changes are:

  1. Wiring through compilation config:
  • Added a TrainerCompileConfig dataclass with enable (bool) and backend (str, default "eager") fields
  • Added a compile field to PolicyTrainer.Config, with compile enabled and the aot_eager backend
  • Added an apply_compile() method that calls .compile(backend=..., fullgraph=True) on each transformer
    layer -> This is critical, as torch.compile() on the whole model changes weight names, which breaks the weight transfer
  2. config_registry.py — Enable compile by default in configs
  • Both rl_grpo_qwen3_0_6b and rl_grpo_qwen3_debug configs now set
    compile=TrainerCompileConfig(enable=True, backend="aot_eager")
  3. vllm_compat/models/attention.py — Make the attention operation compile-compatible
  • Moved the FlashAttnWithBackward autograd function out of the forward() method (nested classes
    can't be traced by the compiler) into a module-level FlashAttnVarlenFunction
  • Registered the flash-attention forward as a torch.library.custom_op (rl::flash_attn_varlen_fwd)
    with a fake implementation, so AOT Autograd can trace through it with FakeTensors
  • Simplified the call site in VLLMCompatibleFlashAttention.forward() to use the new function

Test Plan

python torchtitan/experiments/rl/simple_grpo_sum_digits.py --module rl --config rl_grpo_qwen3_0_6b --hf_assets_path=torchtitan/experiments/rl/example_checkpoint/Qwen3-0.6B

This results in the same losses as on main; the timing comparison is below:

Main

[actor=<root>] Step  9 | Loss: +0.0031 | Reward: +0.450 | Correct: 29/40 | Avg tokens: 100 | Logprob diff: mean=-1.1882e-01, max=1.0228e+01 | Time: 21.4s
[actor=<root>]   Step Timing | Generator: 2.3s | Trainer: 15.3s | WeightSync: 3.8s
[actor=<root>] Cumulative Timing | Generator: 21.3s | Trainer: 150.2s | WeightSync: 33.1s | Total: 204.6s
[actor=<root>] RL Training complete
[actor=<root>] Evaluating post-training performance...
[actor=<root>] Eval: Accuracy=55% (11/20) Format=95% (19/20)
[actor=<root>] ================================================================================
[actor=<root>] Pre-training:  Accuracy=40% (8/20) Format=50% (10/20)
[actor=<root>] Post-training: Accuracy=55% (11/20) Format=95% (19/20)
[actor=<root>] ================================================================================

This PR

[actor=<root>] Step  9 | Loss: +0.0031 | Reward: +0.450 | Correct: 29/40 | Avg tokens: 100 | Logprob diff: mean=-1.1882e-01, max=1.0228e+01 | Time: 13.9s
[actor=<root>]   Step Timing | Generator: 2.4s | Trainer: 6.5s | WeightSync: 5.0s
[actor=<root>] Cumulative Timing | Generator: 21.7s | Trainer: 88.5s | WeightSync: 30.6s | Total: 140.9s
[actor=<root>] RL Training complete
[actor=<root>] Evaluating post-training performance...
[actor=<root>] Eval: Accuracy=55% (11/20) Format=95% (19/20)
[actor=<root>] ================================================================================
[actor=<root>] Pre-training:  Accuracy=40% (8/20) Format=50% (10/20)
[actor=<root>] Post-training: Accuracy=55% (11/20) Format=95% (19/20)
[actor=<root>] ================================================================================

So we save ~60s of end-to-end runtime via compilation in this manner, without affecting our logits/accuracy.

@meta-cla bot added the CLA Signed label Mar 13, 2026
@Lucaskabela Lucaskabela marked this pull request as ready for review March 13, 2026 19:19
Contributor


If we go with pytorch varlen, do we still need to worry about this file? cc @wwwjn

Contributor Author


This is post-rebase; yes, we do need these changes, for the following reasons:

  1. Moving the torch.autograd.Function out of the inner context - this is so compile can trace it
  2. The _flash_attn_varlen_fwd custom op - this is so torch can trace it; without the custom op, we don't know how shapes will propagate through this call

Contributor


If you're going with pytorch varlen, I think you can directly call varlen_attn (which is already a custom op); pytorch's varlen calls the upstream flash-attention impl instead of vllm's flash attention.

Contributor Author


Sorry, I misunderstood this question (I thought it was related to the deletion of the vllm_attention directory).

We will need to move the autograd.Function out, but yes, we won't need the custom op in that case.

@Lucaskabela Lucaskabela requested a review from wwwjn March 13, 2026 20:39
@Lucaskabela Lucaskabela force-pushed the lucaskabela/enable_trainer_compile_03_10 branch from 8a62589 to be75f04 Compare March 13, 2026 22:14
@Lucaskabela Lucaskabela force-pushed the lucaskabela/enable_trainer_compile_03_10 branch from be75f04 to 520d314 Compare March 13, 2026 23:50
@Lucaskabela Lucaskabela requested a review from tianyu-l March 14, 2026 00:28
@Lucaskabela Lucaskabela marked this pull request as draft March 17, 2026 17:34
@Lucaskabela Lucaskabela force-pushed the lucaskabela/enable_trainer_compile_03_10 branch 4 times, most recently from b0e8401 to d032ea8 Compare March 19, 2026 21:02
@Lucaskabela Lucaskabela marked this pull request as ready for review March 19, 2026 21:06
@Lucaskabela Lucaskabela force-pushed the lucaskabela/enable_trainer_compile_03_10 branch from d032ea8 to 76c42cc Compare March 19, 2026 21:42
from vllm.v1.attention.backends.fa_utils import flash_attn_varlen_func


@torch.library.custom_op("rl::flash_attn_varlen_fwd", mutates_args=())
Contributor


Let's not merge this; let's use pytorch varlen attention in #2364.

cc @zhxchen17 to unblock

logger = logging.getLogger(__name__)


def parallelize_qwen3(
Contributor


@wwwjn could you work with @sanketpurandare to land PP + DTensor so that we can remove this ad hoc function?

Contributor


Sure, let me follow up with @sanketpurandare

return model


def apply_compile(model: nn.Module, compile_config: CompileConfig):
Contributor


We switched to a central impl in #2615, but I understand that this is doing model.compile(), not torch.compile(model).

@fegin can we universally switch to model.compile style (and it should still work with DCP)?


Labels

ciflow/8gpu · CLA Signed (managed by the Meta Open Source bot)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[RL][Feature Request] Turn on torch.compile + cudagraphs for trainer definition

4 participants