[RL] Changes to enable compilation for trainer #2568
Lucaskabela wants to merge 2 commits into pytorch:main
Conversation
if we go with pytorch varlen, do we still need to worry about this file? cc @wwwjn
This is post-rebase - yes, we do need these changes, for the following reasons:
- Moving the `torch.autograd.Function` out of the inner context - this is so compile can trace it
- Making `_flash_attn_varlen_fwd` a custom op - this is so torch can trace it; without the custom op, we don't know how shapes will propagate through it
if ur going with pytorch varlen i think u can directly call varlen_attn (which is already a custom op). pytorch's varlen calls the upstream flash attention impl instead of vllm's flash attention
Sorry, I misunderstood this question (I thought it was related to the deletion of the vllm_attention directory).
We will still need to move the autograd.Function out, but yes, we won't need the custom op in that case
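The "move the autograd.Function out" point can be illustrated with a minimal sketch (names here are hypothetical, not the PR's code): the `Function` subclass is defined once at module level rather than inside the calling function, which is what lets the compiler trace the call site.

```python
import torch

# Module-level definition: created once at import time, so torch.compile can
# trace calls to it. Defining this class inside the calling function on every
# invocation is what breaks tracing.
class ScaleFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, factor):
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_out):
        # Gradient w.r.t. x is grad_out * factor; no gradient for factor.
        return grad_out * ctx.factor, None

def scale(x, factor=3.0):
    return ScaleFn.apply(x, factor)

x = torch.ones(2, requires_grad=True)
y = scale(x).sum()
y.backward()
```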
Force-pushed: 8a62589 to be75f04
Force-pushed: be75f04 to 520d314
Force-pushed: b0e8401 to d032ea8
Force-pushed: d032ea8 to 76c42cc
Diff context:
`from vllm.v1.attention.backends.fa_utils import flash_attn_varlen_func`
…
`@torch.library.custom_op("rl::flash_attn_varlen_fwd", mutates_args=())`
let's not merge this, let's use pytorch varlen attention in #2364
cc @zhxchen17 to unblock
Diff context:
`logger = logging.getLogger(__name__)`
…
`def parallelize_qwen3(`
@wwwjn could you work with @sanketpurandare to land PP + DTensor so that we can remove this ad hoc function?
Diff context:
`return model`
…
`def apply_compile(model: nn.Module, compile_config: CompileConfig):`
Summary
In this PR, we enable naive, JIT-style `torch.compile` for the RL policy trainer. This is the first step towards speeding up the trainer model. The changes are:
- Compile per layer -> this is critical, as `torch.compile()` on the full model results in a weight-name change which breaks the weight transfer
- Add a `compile=TrainerCompileConfig(enable=True, backend="aot_eager")` config option
- Move the `torch.autograd.Function` (which can't be traced by the compiler) into a module-level FlashAttnVarlenFunction with a fake implementation, so AOT Autograd can trace through it with FakeTensors
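The config knob mentioned above might be shaped roughly like this sketch; the real `TrainerCompileConfig` lives in the trainer codebase and its fields beyond the two shown in the PR are assumptions.

```python
from dataclasses import dataclass

# Hypothetical reconstruction of the config mentioned in the summary; only
# the `enable` and `backend` fields appear in the PR text.
@dataclass
class TrainerCompileConfig:
    enable: bool = False
    backend: str = "aot_eager"

compile_config = TrainerCompileConfig(enable=True, backend="aot_eager")
```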
Test Plan
Results in the same losses as on main; the timing is now:
Main
Changes
So we save ~60s of end-to-end runtime via compilation, without affecting our logits/accuracy.