Skip to content

Commit b0e8401

Browse files
committed
Incorporate feedback on reusing config
1 parent 9143615 commit b0e8401

File tree

6 files changed

+61
-327
lines changed

6 files changed

+61
-327
lines changed

torchtitan/experiments/rl/actors/generator.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -318,14 +318,12 @@ async def pull_model_state_dict(self, version: int) -> None:
318318
Args:
319319
version: New policy version number.
320320
"""
321-
from monarch.rdma import is_rdma_available
322-
323321
model_sd = self._get_model().model.state_dict()
324322
await ts.get_state_dict(
325323
"model_state_dict",
326324
user_state_dict=model_sd,
327325
strict=False,
328-
direct_rdma=is_rdma_available(),
326+
direct_rdma=False,
329327
)
330328
self.policy_version = version
331329
logger.debug(

torchtitan/experiments/rl/actors/trainer.py

Lines changed: 9 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from torchtitan.components.lr_scheduler import LRSchedulersContainer
2222
from torchtitan.components.optimizer import OptimizersContainer
2323
from torchtitan.config import CommConfig, Configurable, TORCH_DTYPE_MAP
24-
from torchtitan.config.configs import ParallelismConfig, TrainingConfig
24+
from torchtitan.config.configs import CompileConfig, ParallelismConfig, TrainingConfig
2525
from torchtitan.distributed import ParallelDims, utils as dist_utils
2626
from torchtitan.experiments.rl.actors.utils import (
2727
compute_policy_gradient_loss,
@@ -38,16 +38,6 @@
3838
logger = logging.getLogger(__name__)
3939

4040

41-
@dataclass(kw_only=True, slots=True)
42-
class TrainerCompileConfig:
43-
"""Compilation settings for the PolicyTrainer."""
44-
45-
enable: bool = False
46-
"""Enable per-layer torch.compile on the training model."""
47-
backend: str = "eager"
48-
"""torch.compile backend (e.g. 'eager', 'aot_eager', 'inductor')."""
49-
50-
5141
class PolicyTrainer(Actor, Configurable):
5242
"""
5343
Updates policy based on collected Episode using TorchTitan components.
@@ -74,7 +64,7 @@ class Config(Configurable.Config):
7464
parallelism: ParallelismConfig = field(default_factory=ParallelismConfig)
7565
comm: CommConfig = field(default_factory=CommConfig)
7666
"""Communication configuration for distributed initialization."""
77-
compile: TrainerCompileConfig = field(default_factory=TrainerCompileConfig)
67+
compile: CompileConfig = field(default_factory=CompileConfig)
7868

7969
def __init__(
8070
self,
@@ -85,6 +75,11 @@ def __init__(
8575
hf_assets_path: str = "",
8676
transfer_dtype: str = "",
8777
):
78+
# Silence noisy torchstore per-tensor transport logs in actor subprocess
79+
logging.getLogger("torchstore.transport.shared_memory").setLevel(
80+
logging.WARNING
81+
)
82+
8883
self.config = config
8984
self.model_spec = model_spec
9085
# Only cast if transfer dtype differs from training dtype, otherwise
@@ -126,8 +121,6 @@ def __init__(
126121
model_spec, config, device_type, batch_invariant_mode, hf_assets_path
127122
)
128123
model.train()
129-
if config.compile.enable:
130-
model = self._compile_model(model, config.compile.backend)
131124
self.model = model
132125
self.model_parts = [model]
133126

@@ -231,6 +224,7 @@ def _build_model(
231224
model,
232225
parallel_dims=self.parallel_dims,
233226
parallelism=config.parallelism,
227+
compile_config=config.compile,
234228
)
235229

236230
model.to_empty(device=device_type)
@@ -242,20 +236,6 @@ def _build_model(
242236

243237
return model
244238

245-
def _compile_model(self, model: torch.nn.Module, backend: str) -> torch.nn.Module:
246-
"""Compile each transformer layer with torch.compile.
247-
248-
Args:
249-
model: The model whose layers will be compiled.
250-
backend: torch.compile backend (e.g. 'eager', 'aot_eager', 'inductor').
251-
"""
252-
for layer_id in model.layers:
253-
model.layers[layer_id].compile(backend=backend, fullgraph=True)
254-
logger.info(
255-
f"Compiled {len(model.layers)} transformer layers with {backend} backend"
256-
)
257-
return model
258-
259239
@endpoint
260240
async def push_model_state_dict(self) -> None:
261241
"""Publish model weights for generator consumption via TorchStore.
@@ -271,12 +251,10 @@ async def push_model_state_dict(self) -> None:
271251
means "skip StorageVolumes and let the destination read directly
272252
from the source's GPU memory".
273253
"""
274-
from monarch.rdma import is_rdma_available
275-
276254
await ts.put_state_dict(
277255
self.model.state_dict(),
278256
"model_state_dict",
279-
direct_rdma=is_rdma_available(),
257+
direct_rdma=False,
280258
transfer_dtype=self._transfer_dtype,
281259
)
282260

torchtitan/experiments/rl/config_registry.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,13 @@
1313

1414
from torchtitan.components.lr_scheduler import LRSchedulersContainer
1515
from torchtitan.components.optimizer import OptimizersContainer
16-
from torchtitan.config.configs import ParallelismConfig, TrainingConfig
16+
from torchtitan.config.configs import CompileConfig, ParallelismConfig, TrainingConfig
1717
from torchtitan.experiments.rl.actors.generator import (
1818
GeneratorCompileConfig,
1919
SamplingConfig,
2020
VLLMGenerator,
2121
)
22-
from torchtitan.experiments.rl.actors.trainer import (
23-
PolicyTrainer,
24-
TrainerCompileConfig,
25-
)
22+
from torchtitan.experiments.rl.actors.trainer import PolicyTrainer
2623
from torchtitan.experiments.rl.simple_grpo_sum_digits import RLTrainer
2724
from torchtitan.models.qwen3 import model_registry
2825

@@ -44,7 +41,7 @@ def rl_grpo_qwen3_0_6b() -> RLTrainer.Config:
4441
parallelism=ParallelismConfig(
4542
tensor_parallel_degree=2,
4643
),
47-
compile=TrainerCompileConfig(enable=True, backend="aot_eager"),
44+
compile=CompileConfig(enable=True, backend="aot_eager"),
4845
),
4946
generator=VLLMGenerator.Config(
5047
model_dtype="bfloat16",
@@ -84,7 +81,7 @@ def rl_grpo_qwen3_1_7b() -> RLTrainer.Config:
8481
parallelism=ParallelismConfig(
8582
tensor_parallel_degree=2,
8683
),
87-
compile=TrainerCompileConfig(enable=True, backend="aot_eager"),
84+
compile=CompileConfig(enable=True, backend="aot_eager"),
8885
),
8986
generator=VLLMGenerator.Config(
9087
model_dtype="bfloat16",
@@ -124,7 +121,7 @@ def rl_grpo_qwen3_debug() -> RLTrainer.Config:
124121
tensor_parallel_degree=1,
125122
data_parallel_replicate_degree=1,
126123
),
127-
compile=TrainerCompileConfig(enable=True, backend="aot_eager"),
124+
compile=CompileConfig(enable=True, backend="aot_eager"),
128125
),
129126
generator=VLLMGenerator.Config(
130127
compile=GeneratorCompileConfig(

torchtitan/experiments/rl/models/parallelize.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
)
2525

2626
from torchtitan.config import ParallelismConfig
27+
from torchtitan.config.configs import CompileConfig
2728
from torchtitan.distributed import ParallelDims
2829

2930
logger = logging.getLogger(__name__)
@@ -34,6 +35,7 @@ def parallelize_qwen3(
3435
*,
3536
parallel_dims: ParallelDims,
3637
parallelism: ParallelismConfig,
38+
compile_config: CompileConfig | None = None,
3739
has_position_id: bool = False,
3840
):
3941
"""
@@ -44,6 +46,8 @@ def parallelize_qwen3(
4446
TODO: Change to core torchtitan's Qwen3 parallel plan when full DTensor is ready
4547
4648
Args:
49+
compile_config: If provided and enabled, applies per-layer torch.compile
50+
after TP (matching the pattern in torchtitan/models/llama3/parallelize.py).
4751
has_position_id: Whether position IDs are passed as an explicit argument
4852
to the attention module. True for vLLM inference (generator),
4953
False for training (trainer).
@@ -60,9 +64,31 @@ def parallelize_qwen3(
6064
has_position_id=has_position_id,
6165
)
6266

67+
if (
68+
compile_config is not None
69+
and compile_config.enable
70+
and "model" in compile_config.components
71+
):
72+
apply_compile(model, compile_config)
73+
6374
return model
6475

6576

77+
def apply_compile(model: nn.Module, compile_config: CompileConfig):
78+
"""Apply torch.compile to each TransformerBlock.
79+
80+
Follows the same pattern as torchtitan/models/llama3/parallelize.py.
81+
"""
82+
# NOTE: we MUST use `.compile()` instead of `torch.compile()` here, because
83+
# of compatibility with weight naming between this model and the vLLM definition.
84+
# `.compile()` modifies the module in-place and returns None, so we must
85+
# NOT reassign or re-register the module.
86+
for transformer_block in model.layers.values():
87+
transformer_block.compile(backend=compile_config.backend, fullgraph=True)
88+
89+
logger.info("Compiling each TransformerBlock with torch.compile")
90+
91+
6692
def apply_non_moe_tp(
6793
model: nn.Module,
6894
tp_mesh: DeviceMesh,

torchtitan/experiments/rl/models/vllm_compat_attention.py

Lines changed: 20 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,6 @@
1212
from torch.distributed._tensor import DTensor
1313

1414
from torchtitan.protocols.module import Module
15-
from vllm.v1.attention.backends.fa_utils import flash_attn_varlen_func
16-
17-
18-
# ---------------------------------------------------------------------------
19-
# Custom op wrapping vLLM's flash-attention varlen forward.
20-
#
21-
# Registering as a ``torch.library.custom_op`` with a fake implementation
22-
# and explicit autograd lets AOT Autograd trace through the op with
23-
# FakeTensors (required by the compiler_toolkit's joint-graph export path)
24-
# and correctly capture the backward graph.
25-
# ---------------------------------------------------------------------------
2615

2716

2817
@torch.library.custom_op("rl::flash_attn_varlen_fwd", mutates_args=())
@@ -34,7 +23,6 @@ def _flash_attn_varlen_fwd(
3423
seq_len: int,
3524
scale: float,
3625
num_splits: int,
37-
enable_gqa: bool,
3826
) -> torch.Tensor:
3927
from vllm.v1.attention.backends.fa_utils import (
4028
flash_attn_varlen_func as _flash_fn,
@@ -69,38 +57,37 @@ def _flash_attn_varlen_fwd_fake(
6957
seq_len: int,
7058
scale: float,
7159
num_splits: int,
72-
enable_gqa: bool,
7360
) -> torch.Tensor:
74-
# Output shape matches Q: (total_tokens, num_heads, head_dim)
7561
return torch.empty(
7662
(q.shape[0], q.shape[1], q.shape[2]), dtype=q.dtype, device=q.device
7763
)
7864

7965

80-
class FlashAttnVarlenFunction(torch.autograd.Function):
81-
"""autograd.Function wrapping the vLLM flash-attention custom op.
82-
83-
The forward calls the ``rl::flash_attn_varlen_fwd`` custom op (which
84-
has a registered fake implementation for torch.compile tracing).
85-
The backward is a manual PyTorch attention recompute.
86-
"""
87-
88-
@staticmethod
89-
def forward(q, k, v, cu_seqlens, seq_len, scale, num_splits, enable_gqa):
90-
return _flash_attn_varlen_fwd(
91-
q, k, v, cu_seqlens, seq_len, scale, num_splits, enable_gqa
92-
)
93-
66+
class FlashAttnWithBackward(torch.autograd.Function):
9467
@staticmethod
95-
def setup_context(ctx, inputs, output):
96-
q, k, v, cu_seqlens, seq_len, scale, num_splits, enable_gqa = inputs
68+
def forward(
69+
ctx: torch.autograd.function.FunctionCtx,
70+
q: torch.Tensor,
71+
k: torch.Tensor,
72+
v: torch.Tensor,
73+
cu_seqlens: torch.Tensor,
74+
seq_len: int,
75+
scale: float,
76+
num_splits: int,
77+
enable_gqa: bool,
78+
) -> torch.Tensor:
79+
output = _flash_attn_varlen_fwd(q, k, v, cu_seqlens, seq_len, scale, num_splits)
80+
# Save for backward
9781
ctx.save_for_backward(q, k, v, output)
9882
ctx.scale = scale
9983
ctx.seq_len = seq_len
10084
ctx.enable_gqa = enable_gqa
85+
return output
10186

10287
@staticmethod
103-
def backward(ctx, grad_output):
88+
def backward(
89+
ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor
90+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, None, None, None, None, None,]:
10491
q, k, v, output = ctx.saved_tensors
10592
scale = ctx.scale
10693
seq_len = ctx.seq_len
@@ -207,12 +194,9 @@ class VLLMCompatibleFlashAttention(Module):
207194

208195
def __init__(self) -> None:
209196
super().__init__()
210-
self.flash_attn_varlen_func = flash_attn_varlen_func
211197
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
212-
from vllm.v1.attention.backends.fa_utils import get_flash_attn_version
213198

214199
self.vllm_is_batch_invariant = vllm_is_batch_invariant
215-
self.fa_version = get_flash_attn_version()
216200

217201
def forward(
218202
self,
@@ -267,8 +251,8 @@ def forward(
267251
if scale is None:
268252
scale = 1.0 / math.sqrt(q.size(-1))
269253

270-
# Call flash attention via autograd.Function (which wraps the custom op)
271-
output_varlen = FlashAttnVarlenFunction.apply(
254+
# Call Flash Attention varlen with custom backward
255+
output_varlen = FlashAttnWithBackward.apply(
272256
q_varlen,
273257
k_varlen,
274258
v_varlen,

0 commit comments

Comments (0)