Skip to content

Commit 520d314

Browse files
committed
Incorporate feedback on reusing config
1 parent 42d870c commit 520d314

File tree

3 files changed

+34
-36
lines changed

3 files changed

+34
-36
lines changed

torchtitan/experiments/rl/unified/actors/trainer.py

Lines changed: 3 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from torchtitan.components.lr_scheduler import LRSchedulersContainer
2222
from torchtitan.components.optimizer import OptimizersContainer
2323
from torchtitan.config import CommConfig, Configurable, TORCH_DTYPE_MAP
24-
from torchtitan.config.configs import ParallelismConfig, TrainingConfig
24+
from torchtitan.config.configs import CompileConfig, ParallelismConfig, TrainingConfig
2525
from torchtitan.distributed import ParallelDims, utils as dist_utils
2626
from torchtitan.experiments.rl.unified.actors.utils import (
2727
compute_policy_gradient_loss,
@@ -38,16 +38,6 @@
3838
logger = logging.getLogger(__name__)
3939

4040

41-
@dataclass(kw_only=True, slots=True)
42-
class TrainerCompileConfig:
43-
"""Compilation settings for the PolicyTrainer."""
44-
45-
enable: bool = False
46-
"""Enable per-layer torch.compile on the training model."""
47-
backend: str = "eager"
48-
"""torch.compile backend (e.g. 'eager', 'aot_eager', 'inductor')."""
49-
50-
5141
class PolicyTrainer(Actor, Configurable):
5242
"""
5343
Updates policy based on collected Episode using TorchTitan components.
@@ -74,7 +64,7 @@ class Config(Configurable.Config):
7464
parallelism: ParallelismConfig = field(default_factory=ParallelismConfig)
7565
comm: CommConfig = field(default_factory=CommConfig)
7666
"""Communication configuration for distributed initialization."""
77-
compile: TrainerCompileConfig = field(default_factory=TrainerCompileConfig)
67+
compile: CompileConfig = field(default_factory=CompileConfig)
7868

7969
def __init__(
8070
self,
@@ -120,8 +110,6 @@ def __init__(
120110
model_spec, config, device_type, batch_invariant_mode, hf_assets_path
121111
)
122112
model.train()
123-
if config.compile.enable:
124-
model = self._compile_model(model, config.compile.backend)
125113
self.model = model
126114
self.model_parts = [model]
127115

@@ -225,6 +213,7 @@ def _build_model(
225213
model,
226214
parallel_dims=self.parallel_dims,
227215
parallelism=config.parallelism,
216+
compile_config=config.compile,
228217
)
229218

230219
model.to_empty(device=device_type)
@@ -236,20 +225,6 @@ def _build_model(
236225

237226
return model
238227

239-
def _compile_model(self, model: torch.nn.Module, backend: str) -> torch.nn.Module:
240-
"""Compile each transformer layer with torch.compile.
241-
242-
Args:
243-
model: The model whose layers will be compiled.
244-
backend: torch.compile backend (e.g. 'eager', 'aot_eager', 'inductor').
245-
"""
246-
for layer_id in model.layers:
247-
model.layers[layer_id].compile(backend=backend, fullgraph=True)
248-
logger.info(
249-
f"Compiled {len(model.layers)} transformer layers with {backend} backend"
250-
)
251-
return model
252-
253228
@endpoint
254229
async def get_weights(self) -> dict:
255230
"""Get model weights for generator.

torchtitan/experiments/rl/unified/config_registry.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,13 @@
1313

1414
from torchtitan.components.lr_scheduler import LRSchedulersContainer
1515
from torchtitan.components.optimizer import OptimizersContainer
16-
from torchtitan.config.configs import ParallelismConfig, TrainingConfig
16+
from torchtitan.config.configs import CompileConfig, ParallelismConfig, TrainingConfig
1717
from torchtitan.experiments.rl.unified.actors.generator import (
1818
GeneratorCompileConfig,
1919
SamplingConfig,
2020
VLLMGenerator,
2121
)
22-
from torchtitan.experiments.rl.unified.actors.trainer import (
23-
PolicyTrainer,
24-
TrainerCompileConfig,
25-
)
22+
from torchtitan.experiments.rl.unified.actors.trainer import PolicyTrainer
2623
from torchtitan.experiments.rl.unified.simple_grpo_sum_digits import RLTrainer
2724
from torchtitan.models.qwen3 import model_registry
2825

@@ -44,7 +41,7 @@ def rl_grpo_qwen3_0_6b() -> RLTrainer.Config:
4441
parallelism=ParallelismConfig(
4542
tensor_parallel_degree=2,
4643
),
47-
compile=TrainerCompileConfig(enable=True, backend="aot_eager"),
44+
compile=CompileConfig(enable=True, backend="aot_eager"),
4845
),
4946
generator=VLLMGenerator.Config(
5047
model_dtype="bfloat16",
@@ -84,7 +81,7 @@ def rl_grpo_qwen3_1_7b() -> RLTrainer.Config:
8481
parallelism=ParallelismConfig(
8582
tensor_parallel_degree=2,
8683
),
87-
compile=TrainerCompileConfig(enable=True, backend="aot_eager"),
84+
compile=CompileConfig(enable=True, backend="aot_eager"),
8885
),
8986
generator=VLLMGenerator.Config(
9087
model_dtype="bfloat16",
@@ -124,7 +121,7 @@ def rl_grpo_qwen3_debug() -> RLTrainer.Config:
124121
tensor_parallel_degree=1,
125122
data_parallel_replicate_degree=1,
126123
),
127-
compile=TrainerCompileConfig(enable=True, backend="aot_eager"),
124+
compile=CompileConfig(enable=True, backend="aot_eager"),
128125
),
129126
generator=VLLMGenerator.Config(
130127
compile=GeneratorCompileConfig(

torchtitan/experiments/rl/unified/models/parallelize.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import logging
1313

14+
import torch
1415
import torch.nn as nn
1516

1617
from torch.distributed.device_mesh import DeviceMesh
@@ -24,6 +25,7 @@
2425
)
2526

2627
from torchtitan.config import ParallelismConfig
28+
from torchtitan.config.configs import CompileConfig
2729
from torchtitan.distributed import ParallelDims
2830

2931
logger = logging.getLogger(__name__)
@@ -34,6 +36,7 @@ def parallelize_qwen3(
3436
*,
3537
parallel_dims: ParallelDims,
3638
parallelism: ParallelismConfig,
39+
compile_config: CompileConfig | None = None,
3740
has_position_id: bool = False,
3841
):
3942
"""
@@ -44,6 +47,8 @@ def parallelize_qwen3(
4447
TODO: Change to core torchtitan's Qwen3 parallel plan when full DTensor is ready
4548
4649
Args:
50+
compile_config: If provided, enabled, and "model" is in its components, applies
51+
per-layer torch.compile after TP (matching the pattern in torchtitan/models/llama3/parallelize.py).
4752
has_position_id: Whether position IDs are passed as an explicit argument
4853
to the attention module. True for vLLM inference (generator),
4954
False for training (trainer).
@@ -60,9 +65,30 @@ def parallelize_qwen3(
6065
has_position_id=has_position_id,
6166
)
6267

68+
if (
69+
compile_config is not None
70+
and compile_config.enable
71+
and "model" in compile_config.components
72+
):
73+
apply_compile(model, compile_config)
74+
6375
return model
6476

6577

78+
def apply_compile(model: nn.Module, compile_config: CompileConfig):
    """Wrap every child of ``model.layers`` with ``torch.compile``.

    Each block is compiled on its own with ``fullgraph=True`` and the backend
    taken from ``compile_config.backend``. The compiled wrapper is registered
    back on ``model.layers`` under the block's original name, so the model is
    modified in place and nothing is returned.

    Follows the same pattern as torchtitan/models/llama3/parallelize.py.
    """
    layers = model.layers
    for name, block in layers.named_children():
        compiled_block = torch.compile(
            block, backend=compile_config.backend, fullgraph=True
        )
        # Re-register under the same key so the compiled wrapper replaces
        # the original block in the container.
        layers.register_module(name, compiled_block)

    logger.info("Compiling each TransformerBlock with torch.compile")
90+
91+
6692
def apply_non_moe_tp(
6793
model: nn.Module,
6894
tp_mesh: DeviceMesh,

0 commit comments

Comments
 (0)