
Commit 42d870c

Changes to enable compilation
1 parent 00d6e6a commit 42d870c

3 files changed: +224 −174 lines changed

torchtitan/experiments/rl/unified/actors/trainer.py

Lines changed: 27 additions & 0 deletions
@@ -38,6 +38,16 @@
 logger = logging.getLogger(__name__)


+@dataclass(kw_only=True, slots=True)
+class TrainerCompileConfig:
+    """Compilation settings for the PolicyTrainer."""
+
+    enable: bool = False
+    """Enable per-layer torch.compile on the training model."""
+    backend: str = "eager"
+    """torch.compile backend (e.g. 'eager', 'aot_eager', 'inductor')."""
+
+
 class PolicyTrainer(Actor, Configurable):
     """
     Updates policy based on collected Episode using TorchTitan components.
@@ -64,6 +74,7 @@ class Config(Configurable.Config):
         parallelism: ParallelismConfig = field(default_factory=ParallelismConfig)
         comm: CommConfig = field(default_factory=CommConfig)
         """Communication configuration for distributed initialization."""
+        compile: TrainerCompileConfig = field(default_factory=TrainerCompileConfig)

     def __init__(
         self,
@@ -109,6 +120,8 @@ def __init__(
             model_spec, config, device_type, batch_invariant_mode, hf_assets_path
         )
         model.train()
+        if config.compile.enable:
+            model = self._compile_model(model, config.compile.backend)
         self.model = model
         self.model_parts = [model]

@@ -223,6 +236,20 @@ def _build_model(

         return model

+    def _compile_model(self, model: torch.nn.Module, backend: str) -> torch.nn.Module:
+        """Compile each transformer layer with torch.compile.
+
+        Args:
+            model: The model whose layers will be compiled.
+            backend: torch.compile backend (e.g. 'eager', 'aot_eager', 'inductor').
+        """
+        for layer_id in model.layers:
+            model.layers[layer_id].compile(backend=backend, fullgraph=True)
+        logger.info(
+            f"Compiled {len(model.layers)} transformer layers with {backend} backend"
+        )
+        return model
+
     @endpoint
     async def get_weights(self) -> dict:
         """Get model weights for generator.

torchtitan/experiments/rl/unified/config_registry.py

Lines changed: 7 additions & 1 deletion
@@ -19,7 +19,10 @@
     SamplingConfig,
     VLLMGenerator,
 )
-from torchtitan.experiments.rl.unified.actors.trainer import PolicyTrainer
+from torchtitan.experiments.rl.unified.actors.trainer import (
+    PolicyTrainer,
+    TrainerCompileConfig,
+)
 from torchtitan.experiments.rl.unified.simple_grpo_sum_digits import RLTrainer
 from torchtitan.models.qwen3 import model_registry

@@ -41,6 +44,7 @@ def rl_grpo_qwen3_0_6b() -> RLTrainer.Config:
            parallelism=ParallelismConfig(
                tensor_parallel_degree=2,
            ),
+           compile=TrainerCompileConfig(enable=True, backend="aot_eager"),
        ),
        generator=VLLMGenerator.Config(
            model_dtype="bfloat16",
@@ -80,6 +84,7 @@ def rl_grpo_qwen3_1_7b() -> RLTrainer.Config:
            parallelism=ParallelismConfig(
                tensor_parallel_degree=2,
            ),
+           compile=TrainerCompileConfig(enable=True, backend="aot_eager"),
        ),
        generator=VLLMGenerator.Config(
            model_dtype="bfloat16",
@@ -119,6 +124,7 @@ def rl_grpo_qwen3_debug() -> RLTrainer.Config:
                tensor_parallel_degree=1,
                data_parallel_replicate_degree=1,
            ),
+           compile=TrainerCompileConfig(enable=True, backend="aot_eager"),
        ),
        generator=VLLMGenerator.Config(
            compile=GeneratorCompileConfig(
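
As a hypothetical usage sketch (not part of this commit): a registered config can be built and its compile settings adjusted before launching training. The trainer attribute name below is an assumption, since only the keyword arguments inside PolicyTrainer.Config are visible in this diff.

# Hypothetical usage; the "trainer" attribute name is assumed, not shown in this diff.
from torchtitan.experiments.rl.unified.config_registry import rl_grpo_qwen3_debug

cfg = rl_grpo_qwen3_debug()
cfg.trainer.compile.enable = True         # already True for the debug registry entry
cfg.trainer.compile.backend = "inductor"  # e.g. swap "aot_eager" for Inductor codegen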
