
Commit 40e9cbb

wang55rakkit
authored and committed
init scion
1 parent cd337db commit 40e9cbb

File tree

8 files changed (+2116 / -14 lines)


torchtitan/components/optimizer.py

Lines changed: 236 additions & 14 deletions
@@ -5,6 +5,9 @@
 # LICENSE file in the root directory of this source tree.
 
 import functools
+import os
+import re
+from collections import OrderedDict
 from typing import Any, Generic, Iterator, TypeVar
 
 import torch
@@ -21,6 +24,9 @@
 from torchtitan.components.ft import FTManager, has_torchft
 from torchtitan.config import Optimizer as OptimizerConfig
 from torchtitan.distributed import ParallelDims
+from torchtitan.experiments.distributed_scion import DistributedScion, naive_param_norm
+from torchtitan.tools.logging import logger
+from torchtitan.tools.utils import Color
 
 __all__ = [
     "OptimizersContainer",
@@ -36,6 +42,55 @@
 T = TypeVar("T", bound=Optimizer)
 
 
+def _extract_param_groups(
+    model: torch.nn.Module,
+    optimizer_config: dict[str, Any] | None = None,
+):
+    param_groups_config: list[dict[str, Any]] | None = (
+        optimizer_config.pop("param_groups", None)
+        if optimizer_config is not None
+        else None
+    )
+    if param_groups_config is None:
+        param_groups_config = []
+
+    param_dict = OrderedDict(
+        (n, p) for n, p in model.named_parameters() if p.requires_grad
+    )
+    params = []
+
+    color = Color()
+    for param_group_config in param_groups_config:
+        str_match = param_group_config.pop("param_str_match")
+        filter_fn = functools.partial(re.search, str_match)
+        param_names = [n for n in param_dict.keys() if filter_fn(n)]
+        group_params = {
+            "params": [param_dict.pop(n) for n in param_names],
+            "param_names": param_names,
+        }
+        assert len(group_params["params"]) == len(group_params["param_names"])
+
+        if len(param_names) == 0:
+            logger.warning(
+                f'{color.red}Notice: No parameters found for `str_match` "{str_match}" on '
+                f"global rank {torch.distributed.get_rank()}{color.reset}"
+            )
+            continue
+        group_params.update(param_group_config)
+        params.append(group_params)
+
+    param_names = list(param_dict.keys())
+    params.insert(
+        0,
+        {
+            "params": [param_dict.pop(n) for n in param_names],
+            "param_names": param_names,
+        },
+    )
+    assert not param_dict
+    return params
+
+
 class OptimizersContainer(Optimizer, Stateful, Generic[T]):
     """A container for multiple optimizers.
 
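Note (not part of the diff): a minimal usage sketch of the new helper, assuming the post-commit module is importable; the toy module names and the regex are illustrative. Groups are formed by re.search-matching parameter names against each `param_str_match`, matched parameters are removed from the pool, and whatever remains becomes a default group inserted at index 0.

import torch
from torch import nn

from torchtitan.components.optimizer import _extract_param_groups

# Hypothetical toy model; real models come from torchtitan's model parts.
model = nn.Sequential()
model.add_module("tok_embeddings", nn.Embedding(100, 16))
model.add_module("layers", nn.Linear(16, 16))
model.add_module("output", nn.Linear(16, 100))

optimizer_config = {
    "lr": 3e-4,
    # Each entry is a plain dict; "param_str_match" is a regex over parameter names.
    "param_groups": [
        {"param_str_match": "tok_embeddings", "lr": 1e-3},
    ],
}
groups = _extract_param_groups(model, optimizer_config)
# groups[0]: default group holding all unmatched params ("layers.*", "output.*")
# groups[1]: {"params": [...], "param_names": ["tok_embeddings.weight"], "lr": 1e-3}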
@@ -74,11 +129,34 @@ def __init__(
         all_params = []
         self.optimizers = []
         self.model_parts = model_parts
+        param_groups_config = optimizer_kwargs.get("param_groups", None)
+        # Whether to keep old LR values when loading.
+        self.preserve_lrs_when_loading = False
+        self.norms_to_log: list[str] | None = None
+
         for model in self.model_parts:
-            params = [p for p in model.parameters() if p.requires_grad]
-            self.optimizers.append(optimizer_cls(params, **optimizer_kwargs))
+            # copy parts we will pop from to preserve settings across model parts
+            kwargs = optimizer_kwargs.copy()
+            if "param_groups" in optimizer_kwargs:
+                kwargs["param_groups"] = (
+                    param_groups_config.copy()
+                    if param_groups_config is not None
+                    else None
+                )
+
+            extra_kwargs = kwargs.pop("extra_kwargs")
+            params = _extract_param_groups(model, kwargs)
+
+            is_scion = issubclass(optimizer_cls, (DistributedScion))
+            if is_scion:
+                kwargs.update(extra_kwargs)
+            self.optimizers.append(optimizer_cls(params, **kwargs))
             all_params.extend(params)
         self._validate_length(len(self.model_parts))
+        # Do not separately save the external settings in
+        # optimizer defaults.
+        optimizer_kwargs.pop("param_groups", None)
+        optimizer_kwargs.update(optimizer_kwargs.pop("extra_kwargs", {}))
         self._post_init(all_params, optimizer_kwargs)
 
     def __iter__(self) -> Iterator[T]:
@@ -93,7 +171,12 @@ def step(self, *args, **kwargs) -> None:
 
     def zero_grad(self, *args, **kwargs) -> None:
         for optimizer in self.optimizers:
-            optimizer.zero_grad(*args, **kwargs)
+            if not (
+                isinstance(optimizer, (DistributedScion))
+                and optimizer.is_light
+                and optimizer.use_momentum
+            ):
+                optimizer.zero_grad(*args, **kwargs)
 
     def state_dict(self) -> dict[str, Any]:
         func = functools.partial(
@@ -107,13 +190,68 @@ def state_dict(self) -> dict[str, Any]:
         }
 
     def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+        if self.preserve_lrs_when_loading:
+            # Store current learning rates
+            prev_lrs = []
+            for optimizer in self.optimizers:
+                prev_lrs.append([group["lr"] for group in optimizer.param_groups])
+
         func = functools.partial(
             set_optimizer_state_dict,
             optim_state_dict=state_dict,
             options=StateDictOptions(flatten_optimizer_state_dict=True),
         )
         list(map(func, self.model_parts, self.optimizers))
 
+        if self.preserve_lrs_when_loading:
+            # Restore the original learning rates
+            for optimizer, optim_prev_lrs in zip(self.optimizers, prev_lrs):
+                for param_group, prev_lr in zip(optimizer.param_groups, optim_prev_lrs):
+                    if param_group["lr"] != prev_lr:
+                        logger.warning(
+                            f"Restoring lr from {param_group['lr']} to {prev_lr} | "
+                            f"for {param_group['param_names']}"
+                        )
+                        param_group["lr"] = prev_lr
+
+    def calculate_norm_at_next_step(self):
+        # for Dist-scion, we tell the optimizer to calculate the norm at next step
+        # in the step() function
+        for i, _ in enumerate(self.model_parts):
+            optimizer = self.optimizers[i]
+            if isinstance(optimizer, DistributedScion):
+                optimizer.calculate_norm_at_next_step(self.norms_to_log)
+
+    def get_parameter_norms(self):
+        all_norms = {}
+        for i, model_part in enumerate(self.model_parts):
+            # NB: assumes correspondences between model parts and optimizers
+            optimizer = self.optimizers[i]
+            for group in optimizer.param_groups:
+                if isinstance(optimizer, DistributedScion):
+                    all_norms.update(optimizer.get_norms_at_current_step())
+                else:
+                    all_norms.update(
+                        naive_param_norm.get_parameter_norms(
+                            [model_part],
+                            [optimizer],
+                            self.norms_to_log,
+                        )
+                    )
+            # # To Debug, we can force using naive_param_norm
+            # all_norms.update(
+            #     naive_param_norm.get_parameter_norms([model_part], [optimizer])
+            # )
+
+        return all_norms
+
+    def get_lrs(self):
+        lrs = {}
+        for i, optimizer in enumerate(self.optimizers):
+            for k, group in enumerate(optimizer.param_groups):
+                lrs[f"lr/opt_{i}/group_{k}"] = group["lr"]
+        return lrs
+
     def _validate_length(self, expected_length: int) -> None:
         assert expected_length == len(self.optimizers), (
             "Must pass one optimizer per model part or per param if "
@@ -246,6 +384,7 @@ def build_optimizers(
     optimizer_config: OptimizerConfig,
     parallel_dims: ParallelDims,
     ft_manager: FTManager | None = None,
+    extra_kwargs: dict[str, Any] | None = None,
 ) -> OptimizersContainer:
     """Create a OptimizersContainer for the given model parts and job config.
 
@@ -280,31 +419,114 @@
                 "TorchFT is not supported with optimizers in backward."
             )
 
+    extra_kwargs = extra_kwargs if extra_kwargs is not None else {}
+
     name = optimizer_config.name
     lr = optimizer_config.lr
     beta1 = optimizer_config.beta1
     beta2 = optimizer_config.beta2
     eps = optimizer_config.eps
     weight_decay = optimizer_config.weight_decay
 
-    optim_implementation = optimizer_config.implementation
-    assert optim_implementation in ["fused", "foreach", "for-loop"]
+    is_scion = name == "DistributedScion"
 
-    fused = optim_implementation == "fused"
-    foreach = optim_implementation == "foreach"
+    if name in ["Adam", "AdamW"]:
+        optim_implementation = optimizer_config.implementation
+        assert optim_implementation in ["fused", "foreach", "for-loop"]
 
-    optimizer_kwargs = {
-        "lr": lr,
-        "betas": (beta1, beta2),
-        "eps": eps,
-        "weight_decay": weight_decay,
-        "fused": fused,
-        "foreach": foreach,
+        fused = optim_implementation == "fused"
+        foreach = optim_implementation == "foreach"
+
+        if parallel_dims.ep_enabled:
+            # Because for Expert Parallel, we have two different device meshes.
+            fused, foreach = False, False
+
+        optimizer_kwargs = {
+            "lr": lr,
+            "betas": (beta1, beta2),
+            "eps": eps,
+            "weight_decay": weight_decay,
+            "fused": fused,
+            "foreach": foreach,
+        }
+    elif is_scion:
+        backend_steps = optimizer_config.backend_steps
+        zeropower_backend_algorithm = optimizer_config.zeropower_backend
+        momentum = optimizer_config.momentum
+        nesterov = optimizer_config.nesterov
+        is_light = optimizer_config.is_light
+        weight_decay = optimizer_config.weight_decay
+        if os.environ.get("SCION_DEBUG_GRAD") == "1":
+            # only if we want to debug the gradient, we dont run SVD
+            norm_factor = "none"
+            zeropower_backend_algorithm = "identity"
+            logger.warning(
+                '`SCION_DEBUG_GRAD` is set to 1, we will not run SVD and use the "identity" backend'
+            )
+        else:
+            norm_factor = "spectral"
+
+        optimizer_kwargs = {
+            "is_light": is_light,
+            "weight_decay": weight_decay,
+            "lr": lr,
+            "momentum": momentum,
+            "nesterov": nesterov,
+            "eps": eps,
+            "norm_factor": norm_factor,
+            "backend": zeropower_backend_algorithm,
+            "backend_steps": backend_steps,
+        }
+    else:
+        raise NotImplementedError(f"Optimizer {name} not added.")
+
+    # Configure parameter group settings
+    embed_lr = optimizer_config.embed_lr
+    embed_str_match = optimizer_config.embed_str_match
+    if embed_lr is not None and embed_str_match:
+        param_groups_config = optimizer_kwargs.setdefault("param_groups", [])
+        param_group_config = {
+            "param_str_match": embed_str_match,
+            "lr": embed_lr,
+        }
+        if is_scion:
+            param_group_config["norm_factor"] = "embed_sqrt"
+            param_group_config["backend"] = "identity"
+        param_groups_config.append(param_group_config)
+    unembed_lr = optimizer_config.unembed_lr
+    unembed_str_match = optimizer_config.unembed_str_match
+    if unembed_lr is not None and unembed_str_match:
+        param_groups_config = optimizer_kwargs.setdefault("param_groups", [])
+        param_group_config = {
+            "param_str_match": unembed_str_match,
+            "lr": unembed_lr,
+        }
+        if is_scion:
+            param_group_config["norm_factor"] = "unembed_sqrt"
+            param_group_config["backend"] = "identity"
+        param_groups_config.append(param_group_config)
+
+    router_str_match = optimizer_config.router_str_match
+    if router_str_match:
+        param_groups_config = optimizer_kwargs.setdefault("param_groups", [])
+        param_group_config = {
+            "param_str_match": router_str_match,
+            "lr": lr,
+        }
+        if is_scion:
+            param_group_config["norm_factor"] = "spectral"
+            param_group_config["backend"] = zeropower_backend_algorithm
+        param_groups_config.append(param_group_config)
+
+    optimizer_kwargs["extra_kwargs"] = {
+        "parallel_dims": parallel_dims,
+        **extra_kwargs,
     }
 
     optimizer_classes = {
         "Adam": torch.optim.Adam,
         "AdamW": torch.optim.AdamW,
+        "DistributedScion": DistributedScion,
     }
     if name not in optimizer_classes:
         raise NotImplementedError(f"Optimizer {name} not added.")
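Note (not part of the diff): with the defaults added to the Optimizer dataclass below and no SCION_DEBUG_GRAD override, the Scion branch above assembles keyword arguments roughly as follows; the lr, eps, and weight_decay values are illustrative stand-ins for the config values.

# Shape of optimizer_kwargs when name == "DistributedScion".
scion_kwargs = {
    "is_light": False,           # optimizer_config.is_light
    "weight_decay": 0.1,         # optimizer_config.weight_decay (illustrative)
    "lr": 3e-4,                  # optimizer_config.lr (illustrative)
    "momentum": 0.95,            # optimizer_config.momentum
    "nesterov": False,           # optimizer_config.nesterov
    "eps": 1e-8,                 # optimizer_config.eps (illustrative)
    "norm_factor": "spectral",   # "none" when SCION_DEBUG_GRAD=1
    "backend": "newtonschulz5",  # optimizer_config.zeropower_backend
    "backend_steps": 5,          # optimizer_config.backend_steps
}
# Plus optional "param_groups" entries from the embed/unembed/router matchers,
# and an "extra_kwargs" dict carrying parallel_dims (and any caller extras).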

torchtitan/config/job_config.py

Lines changed: 34 additions & 0 deletions
@@ -69,6 +69,9 @@ class Metrics:
     enable_wandb: bool = False
     """Whether to log metrics to Weights & Biases"""
 
+    log_norm_freq: int = -1
+    """How often to log norms in iterations"""
+
 
 @dataclass
 class Model:
@@ -138,6 +141,37 @@ class Optimizer:
     register_post_accumulate_grad_hook after the optimizer is built.
     """
 
+    # Below is Scion-specific configs
+    is_light: bool = False
+    """Whether to use Scion's light (memory-saving) version"""
+
+    zeropower_backend: str = "newtonschulz5"
+    "Which `zeropower_backend` to use."
+
+    backend_steps: int = 5
+    """Number of steps for the Scion backend"""
+
+    momentum: float = 0.95
+    """Scion momentum to use"""
+
+    nesterov: bool = False
+    """Whether to use Nesterov momentum in Scion"""
+
+    embed_lr: float | None = None
+    """Embedding layer learning rate"""
+
+    unembed_lr: float | None = None
+    """Unembedding layer learning rate"""
+
+    embed_str_match: str | None = None
+    """String to match for embedding layer parameter group"""
+
+    unembed_str_match: str | None = None
+    """String to match for unembedding layer parameter group"""
+
+    router_str_match: str | None = None
+    """String to match for MoE router layer parameter group"""
+
 
 @dataclass
 class LRScheduler:
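Note (not part of the diff): a hedged sketch of selecting Scion through this config in Python. The field names are the ones defined above (plus the pre-existing `name` and `lr` fields used by build_optimizers); the values and regex patterns are illustrative only.

from torchtitan.config import Optimizer

optimizer_config = Optimizer(
    name="DistributedScion",
    lr=3e-4,
    is_light=False,
    zeropower_backend="newtonschulz5",
    backend_steps=5,
    momentum=0.95,
    nesterov=False,
    embed_str_match="tok_embeddings",  # hypothetical parameter-name pattern
    embed_lr=1e-3,
    unembed_str_match="output",        # hypothetical parameter-name pattern
    unembed_lr=1e-3,
)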
Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .distributed_scion import DistributedScion  # noqa: F401
+from .utils import remove_orig_mod_and_weight_for_p_name  # noqa: F401
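Note (not part of the diff): this new file appears to be the package __init__ for the distributed_scion experiment; its re-export is what the flat import in components/optimizer.py above relies on.

# The import used by torchtitan/components/optimizer.py in this commit:
from torchtitan.experiments.distributed_scion import DistributedScion, naive_param_norm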
