Skip to content

Commit 72d4931

Browse files
Sanket Jayant Purandare (sanketpurandare)
authored and committed
Add graph PP infrastructure for autoparallel
stack-info: PR: #2726, branch: sanketpurandare/stack/5
1 parent cd1af1a commit 72d4931

File tree

3 files changed

+669
-1
lines changed

3 files changed

+669
-1
lines changed

torchtitan/experiments/autoparallel/configs.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,12 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
from dataclasses import dataclass, field
7+
from collections.abc import Callable
8+
from dataclasses import dataclass, field, fields
89

10+
from torchtitan.config import ActivationCheckpointConfig
911
from torchtitan.config.configs import CompileConfig
12+
from torchtitan.protocols.model_spec import ModelSpec
1013
from torchtitan.trainer import Trainer
1114

1215

@@ -23,3 +26,30 @@ class AutoParallelConfig(Trainer.Config):
2326
compile: AutoParallelCompileConfig = field(
2427
default_factory=AutoParallelCompileConfig
2528
)
29+
30+
31+
def to_autoparallel_config(
    base_config: Trainer.Config,
    model_registry: Callable[[str], ModelSpec],
    flavor: str | None = None,
) -> AutoParallelConfig:
    """Build an AutoParallelConfig from an existing Trainer.Config.

    Every field of ``base_config`` is carried over verbatim, with three
    adjustments:

    * ``model_spec`` is replaced by the spec returned from the autoparallel
      ``model_registry``.
    * ``compile`` is dropped so the AutoParallelConfig default applies;
      callers should explicitly set it.
    * any activation-checkpoint mode other than ``"none"`` is collapsed to a
      plain ``"selective"`` config.

    Args:
        base_config: The base Trainer.Config to convert.
        model_registry: A callable that returns a ModelSpec for a given flavor.
        flavor: Optional flavor override. If None, uses the base config's flavor.
    """
    # Snapshot every dataclass field of the base config into kwargs.
    kwargs = {
        spec.name: getattr(base_config, spec.name) for spec in fields(base_config)
    }

    # A falsy flavor (None) falls back to the base config's own flavor.
    chosen_flavor = flavor or base_config.model_spec.flavor
    kwargs["model_spec"] = model_registry(chosen_flavor)

    # Remove compile so AutoParallelConfig's default_factory supplies it.
    del kwargs["compile"]

    # Normalize activation checkpointing: any enabled mode becomes "selective".
    ac = kwargs.get("activation_checkpoint")
    if ac is not None and ac.mode != "none":
        kwargs["activation_checkpoint"] = ActivationCheckpointConfig(mode="selective")

    return AutoParallelConfig(**kwargs)

0 commit comments

Comments
 (0)