44# This source code is licensed under the BSD-style license found in the
55# LICENSE file in the root directory of this source tree.
66
7- from dataclasses import dataclass , field
7+ from collections .abc import Callable
8+ from dataclasses import dataclass , field , fields
89
10+ from torchtitan .config import ActivationCheckpointConfig
911from torchtitan .config .configs import CompileConfig
12+ from torchtitan .protocols .model_spec import ModelSpec
1013from torchtitan .trainer import Trainer
1114
1215
@@ -23,3 +26,30 @@ class AutoParallelConfig(Trainer.Config):
2326 compile : AutoParallelCompileConfig = field (
2427 default_factory = AutoParallelCompileConfig
2528 )
29+
30+
def to_autoparallel_config(
    base_config: Trainer.Config,
    model_registry: Callable[[str], ModelSpec],
    flavor: str | None = None,
) -> AutoParallelConfig:
    """Convert a base Trainer.Config to an AutoParallelConfig.

    Copies the constructor-settable fields of the base config and replaces
    ``model_spec`` with one obtained from the autoparallel ``model_registry``.
    The ``compile`` field is dropped so that it takes the AutoParallelConfig
    default; callers should explicitly set it afterwards.

    If the base config has an ``activation_checkpoint`` value whose ``mode``
    is anything other than ``"none"``, it is replaced by a fresh
    ``ActivationCheckpointConfig(mode="selective")``. Only ``mode`` is passed
    to that new instance, so other settings on the original activation
    checkpoint config are not carried over.

    Args:
        base_config: The base Trainer.Config to convert.
        model_registry: A callable that returns a ModelSpec for a given flavor.
        flavor: Optional flavor override. If None (or empty), the flavor of
            the base config's ``model_spec`` is used.

    Returns:
        A new AutoParallelConfig built from the copied fields.
    """
    # Only init=True fields are accepted as constructor keywords; passing an
    # init=False field to AutoParallelConfig(**kwargs) would raise TypeError.
    kwargs = {
        f.name: getattr(base_config, f.name) for f in fields(base_config) if f.init
    }
    kwargs["model_spec"] = model_registry(flavor or base_config.model_spec.flavor)

    # Drop the base compile settings so the AutoParallelConfig default applies.
    # The default argument keeps this a no-op when the field is absent.
    kwargs.pop("compile", None)

    ac = kwargs.get("activation_checkpoint")
    if ac is not None and ac.mode != "none":
        # Any enabled checkpointing mode is normalized to "selective".
        kwargs["activation_checkpoint"] = ActivationCheckpointConfig(mode="selective")

    return AutoParallelConfig(**kwargs)
0 commit comments