
Commit 0586ffd

wang55rakkit authored and committed

init scion

1 parent: cd337db

File tree

11 files changed, +3049 −4 lines


torchtitan/components/optimizer.py

Lines changed: 15 additions & 1 deletion

@@ -21,6 +21,7 @@
 from torchtitan.components.ft import FTManager, has_torchft
 from torchtitan.config import Optimizer as OptimizerConfig
 from torchtitan.distributed import ParallelDims
+from torchtitan.experiments import distributed_scion
 
 __all__ = [
     "OptimizersContainer",

@@ -75,7 +76,12 @@ def __init__(
         self.optimizers = []
         self.model_parts = model_parts
         for model in self.model_parts:
-            params = [p for p in model.parameters() if p.requires_grad]
+            if issubclass(optimizer_cls, (distributed_scion.DistributedScion)):
+                params, optimizer_kwargs = distributed_scion.create_scion_param_groups(
+                    model, optimizer_kwargs
+                )
+            else:
+                params = [p for p in model.parameters() if p.requires_grad]
             self.optimizers.append(optimizer_cls(params, **optimizer_kwargs))
             all_params.extend(params)
         self._validate_length(len(self.model_parts))

@@ -302,9 +308,17 @@ def build_optimizers(
         "foreach": foreach,
     }
 
+    if name in ["DistributedScion"]:
+        optimizer_kwargs = (
+            distributed_scion.create_scion_optimizer_kwargs_from_optimizer_config(
+                optimizer_config, parallel_dims
+            )
+        )
+
     optimizer_classes = {
         "Adam": torch.optim.Adam,
         "AdamW": torch.optim.AdamW,
+        "DistributedScion": distributed_scion.DistributedScion,
     }
     if name not in optimizer_classes:
         raise NotImplementedError(f"Optimizer {name} not added.")
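
This hunk does two things: it registers "DistributedScion" in the name-to-class table used by build_optimizers, and it lets Scion-class optimizers build their parameter groups through distributed_scion.create_scion_param_groups instead of the flat requires_grad filter. Below is a minimal, self-contained sketch of that dispatch pattern; FakeScion, fake_create_param_groups, and the 2-D/non-2-D split are hypothetical stand-ins, not the torchtitan implementation.

```python
# Sketch only: FakeScion and fake_create_param_groups are hypothetical stand-ins
# for distributed_scion.DistributedScion / create_scion_param_groups.
import torch
import torch.nn as nn


class FakeScion(torch.optim.SGD):
    """Placeholder optimizer so the issubclass() dispatch can be demonstrated."""


def fake_create_param_groups(model, optimizer_kwargs):
    # Illustrative split: 2-D (matrix) params vs. everything else, with a
    # different lr for the non-matrix group. The real helper's rules differ.
    matrix = [p for p in model.parameters() if p.requires_grad and p.ndim == 2]
    other = [p for p in model.parameters() if p.requires_grad and p.ndim != 2]
    groups = [
        {"params": matrix},
        {"params": other, "lr": optimizer_kwargs["lr"] * 0.1},
    ]
    return groups, optimizer_kwargs


def build_optimizer(name, model, optimizer_kwargs):
    optimizer_classes = {
        "AdamW": torch.optim.AdamW,
        "FakeScion": FakeScion,
    }
    if name not in optimizer_classes:
        raise NotImplementedError(f"Optimizer {name} not added.")
    optimizer_cls = optimizer_classes[name]
    if issubclass(optimizer_cls, FakeScion):
        # Scion-style optimizers get structured param groups (and possibly
        # adjusted kwargs) from a helper instead of a flat parameter list.
        params, optimizer_kwargs = fake_create_param_groups(model, optimizer_kwargs)
    else:
        params = [p for p in model.parameters() if p.requires_grad]
    return optimizer_cls(params, **optimizer_kwargs)


if __name__ == "__main__":
    model = nn.Linear(8, 8)
    opt = build_optimizer("FakeScion", model, {"lr": 0.02, "momentum": 0.9})
    print(type(opt).__name__, [len(g["params"]) for g in opt.param_groups])
```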

torchtitan/config/job_config.py

Lines changed: 30 additions & 0 deletions

@@ -69,6 +69,9 @@ class Metrics:
     enable_wandb: bool = False
     """Whether to log metrics to Weights & Biases"""
 
+    log_norm_freq: int = -1
+    """How often to log norms in iterations"""
+
 
 @dataclass
 class Model:

@@ -122,6 +125,33 @@ class Optimizer:
     weight_decay: float = 0.1
     """Weight decay to use"""
 
+    mup_width_multiplier: float = 1.0
+    """
+    Width multiplier for the model to apply μP scaling (only used
+    for Adam/Muon-based optimizers).
+    """
+
+    is_light: bool = False
+    """Whether to use Scion's light (memory-saving) version"""
+
+    norm_factor: str = "spectral"
+    """Which norm factor to use"""
+
+    zeropower_backend: str = "newtonschulz5"
+    "Which `zeropower_backend` to use."
+
+    backend_steps: int = 5
+    """Number of steps for the Scion backend"""
+
+    momentum: float = 0.95
+    """Scion momentum to use"""
+
+    nesterov: bool = False
+    """Whether to use Nesterov momentum in Scion"""
+
+    extra_splits_rules: list[dict[str, Any]] | None = None
+    """Extra parameter group splitting rules for Scion optimizers"""
+
     implementation: Literal["for-loop", "foreach", "fused"] = "fused"
     """
     Specify which optimizer implementation to use:
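
The new Optimizer fields are plain dataclass attributes that torchtitan fills in from the job's TOML [optimizer] table. As a rough sketch (the trimmed-down class, its lr default, and the example values are illustrative only; the Scion field names and defaults mirror the diff above), a Scion run might populate them like this:

```python
# Trimmed-down sketch of the added Optimizer knobs; the added fields and their
# defaults mirror the diff above, the surrounding class and the example values
# are illustrative only.
from dataclasses import dataclass
from typing import Any


@dataclass
class OptimizerSketch:
    name: str = "AdamW"
    lr: float = 3e-4                      # illustrative default
    weight_decay: float = 0.1
    # Fields introduced in this commit:
    mup_width_multiplier: float = 1.0     # μP width scaling (Adam/Muon-style only)
    is_light: bool = False                # Scion's memory-saving variant
    norm_factor: str = "spectral"
    zeropower_backend: str = "newtonschulz5"
    backend_steps: int = 5                # iterations for the zeropower backend
    momentum: float = 0.95                # Scion momentum
    nesterov: bool = False
    extra_splits_rules: list[dict[str, Any]] | None = None


# e.g. what a parsed [optimizer] table might produce for a Scion run:
cfg = OptimizerSketch(
    name="DistributedScion",
    lr=0.05,
    momentum=0.95,
    is_light=True,
)
print(cfg.name, cfg.norm_factor, cfg.backend_steps)
```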

torchtitan/distributed/utils.py

Lines changed: 7 additions & 3 deletions

@@ -388,7 +388,8 @@ def clip_grad_norm_(
     dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=pp_mesh.get_group())
     total_norm **= 1.0 / norm_type
 
-    torch.nn.utils.clip_grads_with_norm_(parameters, max_norm, total_norm, foreach)
+    if max_norm > 0:
+        torch.nn.utils.clip_grads_with_norm_(parameters, max_norm, total_norm, foreach)
     return total_norm
 
 
@@ -444,7 +445,10 @@ def _clip_grad_norm_with_ep(
     dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=pp_mesh.get_group())
     total_norm **= 1.0 / norm_type
 
-    torch.nn.utils.clip_grads_with_norm_(ep_params, max_norm, total_norm, foreach)
-    torch.nn.utils.clip_grads_with_norm_(non_ep_params, max_norm, total_norm, foreach)
+    if max_norm > 0:
+        torch.nn.utils.clip_grads_with_norm_(ep_params, max_norm, total_norm, foreach)
+        torch.nn.utils.clip_grads_with_norm_(
+            non_ep_params, max_norm, total_norm, foreach
+        )
 
     return total_norm
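
Both hunks keep computing and returning the total gradient norm but perform the actual clipping only when max_norm is positive, so a run can disable clipping yet still expose the norm (e.g. for the new log_norm_freq metric). Here is a minimal single-process sketch of that pattern, without the distributed all-reduces, assuming a PyTorch version that provides torch.nn.utils.get_total_norm and clip_grads_with_norm_:

```python
# Single-process sketch of "measure always, clip only when max_norm > 0".
# Assumes torch.nn.utils.get_total_norm / clip_grads_with_norm_ are available
# (recent PyTorch); the real torchtitan code also all-reduces across meshes.
import torch
import torch.nn as nn


def clip_grad_norm_sketch(parameters, max_norm, norm_type=2.0, foreach=None):
    parameters = [p for p in parameters if p.grad is not None]
    total_norm = torch.nn.utils.get_total_norm(
        [p.grad for p in parameters], norm_type=norm_type, foreach=foreach
    )
    if max_norm > 0:  # max_norm <= 0 means "report the norm, do not clip"
        torch.nn.utils.clip_grads_with_norm_(parameters, max_norm, total_norm, foreach)
    return total_norm


model = nn.Linear(4, 4)
model(torch.randn(2, 4)).sum().backward()
print(clip_grad_norm_sketch(model.parameters(), max_norm=0.0))  # norm only
print(clip_grad_norm_sketch(model.parameters(), max_norm=1.0))  # norm + clip
```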
Lines changed: 12 additions & 0 deletions

@@ -0,0 +1,12 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .distributed_scion import DistributedScion  # noqa: F401
+from .utils import (  # noqa: F401
+    create_scion_optimizer_kwargs_from_optimizer_config,
+    create_scion_param_groups,
+    remove_orig_mod_and_weight_for_p_name,
+)
