Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion torchtitan/experiments/autoparallel/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field
from collections.abc import Callable
from dataclasses import dataclass, field, fields

from torchtitan.config import ActivationCheckpointConfig
from torchtitan.config.configs import CompileConfig
from torchtitan.protocols.model_spec import ModelSpec
from torchtitan.trainer import Trainer


Expand All @@ -23,3 +26,30 @@ class AutoParallelConfig(Trainer.Config):
compile: AutoParallelCompileConfig = field(
default_factory=AutoParallelCompileConfig
)


def to_autoparallel_config(
base_config: Trainer.Config,
model_registry: Callable[[str], ModelSpec],
flavor: str | None = None,
) -> AutoParallelConfig:
"""Convert a base Trainer.Config to an AutoParallelConfig.

Copies all fields from the base config and replaces the model_spec with one
from the autoparallel model_registry. The compile field is removed and
left as the AutoParallelConfig default; callers should explicitly set it.

Args:
base_config: The base Trainer.Config to convert.
model_registry: A callable that returns a ModelSpec for a given flavor.
flavor: Optional flavor override. If None, uses the base config's flavor.
"""
d = {f.name: getattr(base_config, f.name) for f in fields(base_config)}
d["model_spec"] = model_registry(flavor or base_config.model_spec.flavor)
d.pop("compile")

ac = d.get("activation_checkpoint")
if ac is not None and ac.mode != "none":
d["activation_checkpoint"] = ActivationCheckpointConfig(mode="selective")

return AutoParallelConfig(**d)
Loading
Loading