16 changes: 15 additions & 1 deletion torchtitan/components/optimizer.py
@@ -21,6 +21,7 @@
from torchtitan.components.ft import FTManager, has_torchft
from torchtitan.config import Optimizer as OptimizerConfig
from torchtitan.distributed import ParallelDims
from torchtitan.experiments import distributed_scion

__all__ = [
"OptimizersContainer",
@@ -75,7 +76,12 @@ def __init__(
self.optimizers = []
self.model_parts = model_parts
for model in self.model_parts:
params = [p for p in model.parameters() if p.requires_grad]
if issubclass(optimizer_cls, distributed_scion.DistributedScion):
params, optimizer_kwargs = distributed_scion.create_scion_param_groups(
model, optimizer_kwargs
)
else:
params = [p for p in model.parameters() if p.requires_grad]
self.optimizers.append(optimizer_cls(params, **optimizer_kwargs))
all_params.extend(params)
self._validate_length(len(self.model_parts))
@@ -302,9 +308,17 @@ def build_optimizers(
"foreach": foreach,
}

if name in ["DistributedScion"]:
optimizer_kwargs = (
distributed_scion.create_scion_optimizer_kwargs_from_optimizer_config(
optimizer_config, parallel_dims
)
)

optimizer_classes = {
"Adam": torch.optim.Adam,
"AdamW": torch.optim.AdamW,
"DistributedScion": distributed_scion.DistributedScion,
}
if name not in optimizer_classes:
raise NotImplementedError(f"Optimizer {name} not added.")
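For reference, the body of `create_scion_param_groups` is not shown in this diff. The sketch below is a hypothetical illustration (the function name suffix, grouping criterion, and option keys are assumptions, not the PR's implementation) of the contract `OptimizersContainer.__init__` relies on above: the helper returns per-group parameter dicts plus the kwargs that should still go to the optimizer constructor.

```python
# Hypothetical sketch only -- the real create_scion_param_groups lives in
# torchtitan/experiments/distributed_scion/utils.py and may differ.
import torch.nn as nn


def split_params_for_scion_sketch(model: nn.Module, optimizer_kwargs: dict):
    # Split trainable parameters into matrix-shaped weights vs. everything else.
    matrix_params, other_params = [], []
    for p in model.parameters():
        if not p.requires_grad:
            continue
        (matrix_params if p.ndim == 2 else other_params).append(p)

    param_groups = [
        # Group-specific options (illustrative values) override the defaults.
        {"params": matrix_params,
         "norm_factor": optimizer_kwargs.get("norm_factor", "spectral")},
        {"params": other_params, "norm_factor": "none"},
    ]
    # Per-group options were moved into the groups; the rest (lr, momentum, ...)
    # are still passed to the optimizer constructor via **optimizer_kwargs.
    remaining_kwargs = {k: v for k, v in optimizer_kwargs.items() if k != "norm_factor"}
    return param_groups, remaining_kwargs
```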
30 changes: 30 additions & 0 deletions torchtitan/config/job_config.py
@@ -69,6 +69,9 @@ class Metrics:
enable_wandb: bool = False
"""Whether to log metrics to Weights & Biases"""

log_norm_freq: int = -1
"""How often to log norms, in training iterations"""


@dataclass
class Model:
@@ -122,6 +125,33 @@ class Optimizer:
weight_decay: float = 0.1
"""Weight decay to use"""

mup_width_multiplier: float = 1.0
"""
Model width multiplier used to apply μP scaling (only used
for Adam/Muon-based optimizers).
"""

is_light: bool = False
"""Whether to use Scion's light (memory-saving) version"""

norm_factor: str = "spectral"
"""Which norm factor to use"""

zeropower_backend: str = "newtonschulz5"
"""Which `zeropower_backend` to use"""

backend_steps: int = 5
"""Number of steps for the Scion backend"""

momentum: float = 0.95
"""Scion momentum to use"""

nesterov: bool = False
"""Whether to use Nesterov momentum in Scion"""

extra_splits_rules: list[dict[str, Any]] | None = None
"""Extra parameter group splitting rules for Scion optimizers"""

implementation: Literal["for-loop", "foreach", "fused"] = "fused"
"""
Specify which optimizer implementation to use:
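To make the new knobs concrete, here is a hedged example of building the extended `Optimizer` config in Python. In a real run these values would come from the `[optimizer]` table of the TOML job config; the values below are placeholders, not recommendations, and the `name`/`lr` fields are assumed from the surrounding torchtitan config (they are not part of this hunk).

```python
# Illustrative configuration for the Scion path; field names follow the
# dataclass above, but the specific values are placeholders.
from torchtitan.config import Optimizer as OptimizerConfig

scion_cfg = OptimizerConfig(
    name="DistributedScion",
    lr=1e-2,                      # placeholder learning rate
    is_light=False,               # full (non-memory-saving) Scion
    norm_factor="spectral",
    zeropower_backend="newtonschulz5",
    backend_steps=5,
    momentum=0.95,
    nesterov=False,
    extra_splits_rules=None,      # no extra param-group splitting
)
```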
10 changes: 7 additions & 3 deletions torchtitan/distributed/utils.py
@@ -388,7 +388,8 @@ def clip_grad_norm_(
dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=pp_mesh.get_group())
total_norm **= 1.0 / norm_type

torch.nn.utils.clip_grads_with_norm_(parameters, max_norm, total_norm, foreach)
if max_norm > 0:
torch.nn.utils.clip_grads_with_norm_(parameters, max_norm, total_norm, foreach)
return total_norm


@@ -444,7 +445,10 @@ def _clip_grad_norm_with_ep(
dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=pp_mesh.get_group())
total_norm **= 1.0 / norm_type

torch.nn.utils.clip_grads_with_norm_(ep_params, max_norm, total_norm, foreach)
torch.nn.utils.clip_grads_with_norm_(non_ep_params, max_norm, total_norm, foreach)
if max_norm > 0:
torch.nn.utils.clip_grads_with_norm_(ep_params, max_norm, total_norm, foreach)
torch.nn.utils.clip_grads_with_norm_(
non_ep_params, max_norm, total_norm, foreach
)

return total_norm
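The guard above changes the semantics slightly: with a non-positive max_norm the total gradient norm is still computed, reduced, and returned (e.g. so it can be logged via log_norm_freq), but no clipping is applied. Below is a minimal standalone sketch of that behavior, assuming a PyTorch version that provides get_total_norm / clip_grads_with_norm_ (as the code above already does); it omits the EP/PP all-reduces of the torchtitan functions.

```python
# Standalone sketch of the max_norm > 0 guard, not the torchtitan function itself.
import torch


def norm_then_maybe_clip(parameters, max_norm: float, norm_type: float = 2.0):
    parameters = list(parameters)
    grads = [p.grad for p in parameters if p.grad is not None]
    # Always compute the total gradient norm so it can be reported.
    total_norm = torch.nn.utils.get_total_norm(grads, norm_type=norm_type)
    if max_norm > 0:
        # Scale gradients in place so their combined norm is at most max_norm.
        torch.nn.utils.clip_grads_with_norm_(parameters, max_norm, total_norm)
    return total_norm
```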
12 changes: 12 additions & 0 deletions torchtitan/experiments/distributed_scion/__init__.py
@@ -0,0 +1,12 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .distributed_scion import DistributedScion # noqa: F401
from .utils import ( # noqa: F401
create_scion_optimizer_kwargs_from_optimizer_config,
create_scion_param_groups,
remove_orig_mod_and_weight_for_p_name,
)
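With this `__init__.py`, the pieces used by `build_optimizers` above are importable either through the package or directly; a quick check of the import surface:

```python
# Import-surface check for the new experiment package; mirrors the re-exports
# in the __init__.py above.
from torchtitan.experiments import distributed_scion
from torchtitan.experiments.distributed_scion import (
    DistributedScion,
    create_scion_optimizer_kwargs_from_optimizer_config,
    create_scion_param_groups,
    remove_orig_mod_and_weight_for_p_name,
)

assert distributed_scion.DistributedScion is DistributedScion
```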