# Module-level logger, named after this module per the stdlib logging convention.
logger = logging.getLogger(__name__)
3939
4040
41+ @dataclass (kw_only = True , slots = True )
42+ class TrainerCompileConfig :
43+ """Compilation settings for the PolicyTrainer."""
44+
45+ enable : bool = False
46+ """Enable per-layer torch.compile on the training model."""
47+ backend : str = "eager"
48+ """torch.compile backend (e.g. 'eager', 'aot_eager', 'inductor')."""
49+
50+
4151class PolicyTrainer (Actor , Configurable ):
4252 """
4353 Updates policy based on collected Episode using TorchTitan components.
@@ -64,6 +74,7 @@ class Config(Configurable.Config):
6474 parallelism : ParallelismConfig = field (default_factory = ParallelismConfig )
6575 comm : CommConfig = field (default_factory = CommConfig )
6676 """Communication configuration for distributed initialization."""
77+ compile : TrainerCompileConfig = field (default_factory = TrainerCompileConfig )
6778
6879 def __init__ (
6980 self ,
@@ -109,6 +120,8 @@ def __init__(
109120 model_spec , config , device_type , batch_invariant_mode , hf_assets_path
110121 )
111122 model .train ()
123+ if config .compile .enable :
124+ model = self ._compile_model (model , config .compile .backend )
112125 self .model = model
113126 self .model_parts = [model ]
114127
@@ -223,6 +236,20 @@ def _build_model(
223236
224237 return model
225238
def _compile_model(self, model: torch.nn.Module, backend: str) -> torch.nn.Module:
    """Compile each transformer layer in-place with torch.compile.

    Compiling per layer rather than the whole model keeps each compiled
    graph small and self-contained (fullgraph=True is enforced per layer).

    Args:
        model: The model whose ``layers`` (a mapping of layer id -> layer
            module, e.g. an ``nn.ModuleDict``) will be compiled in-place.
        backend: torch.compile backend (e.g. 'eager', 'aot_eager', 'inductor').

    Returns:
        The same model instance, with each layer set up for compilation.
    """
    # Iterate the layer modules directly instead of keys + re-indexing.
    for layer in model.layers.values():
        layer.compile(backend=backend, fullgraph=True)
    # Lazy %-style args so the message is only formatted if INFO is enabled.
    logger.info(
        "Compiled %d transformer layers with %s backend", len(model.layers), backend
    )
    return model
252+
226253 @endpoint
227254 async def get_weights (self ) -> dict :
228255 """Get model weights for generator.
0 commit comments