# LICENSE file in the root directory of this source tree.

import functools
+ import os
+ import re
+ from collections import OrderedDict
from typing import Any, Generic, Iterator, TypeVar

import torch

from torchtitan.components.ft import FTManager, has_torchft
from torchtitan.config import Optimizer as OptimizerConfig
from torchtitan.distributed import ParallelDims
+ from torchtitan.experiments.distributed_scion import DistributedScion, naive_param_norm
+ from torchtitan.tools.logging import logger
+ from torchtitan.tools.utils import Color


__all__ = [
    "OptimizersContainer",

T = TypeVar("T", bound=Optimizer)


+ def _extract_param_groups(
+     model: torch.nn.Module,
+     optimizer_config: dict[str, Any] | None = None,
+ ):
+     param_groups_config: list[dict[str, Any]] | None = (
+         optimizer_config.pop("param_groups", None)
+         if optimizer_config is not None
+         else None
+     )
+     if param_groups_config is None:
+         param_groups_config = []
+
+     param_dict = OrderedDict(
+         (n, p) for n, p in model.named_parameters() if p.requires_grad
+     )
+     params = []
+
+     color = Color()
+     for param_group_config in param_groups_config:
+         str_match = param_group_config.pop("param_str_match")
+         filter_fn = functools.partial(re.search, str_match)
+         param_names = [n for n in param_dict.keys() if filter_fn(n)]
+         group_params = {
+             "params": [param_dict.pop(n) for n in param_names],
+             "param_names": param_names,
+         }
+         assert len(group_params["params"]) == len(group_params["param_names"])
+
+         if len(param_names) == 0:
+             logger.warning(
+                 f'{color.red}Notice: No parameters found for `str_match` "{str_match}" on '
+                 f"global rank {torch.distributed.get_rank()}{color.reset}"
+             )
+             continue
+         group_params.update(param_group_config)
+         params.append(group_params)
+
+     param_names = list(param_dict.keys())
+     params.insert(
+         0,
+         {
+             "params": [param_dict.pop(n) for n in param_names],
+             "param_names": param_names,
+         },
+     )
+     assert not param_dict
+     return params
+
+
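To illustrate what `_extract_param_groups` produces, here is a minimal, standalone sketch that re-implements just the regex grouping logic (it deliberately omits the `Color`/`logger` warning path and the rank lookup so it runs without distributed initialization); the toy model and names below are invented for the example:

```python
import re
from collections import OrderedDict

import torch


def extract_param_groups_sketch(model, param_groups_config):
    # Same idea as _extract_param_groups: carve matching params into groups,
    # leave everything else in a default group at index 0.
    param_dict = OrderedDict(
        (n, p) for n, p in model.named_parameters() if p.requires_grad
    )
    groups = []
    for cfg in param_groups_config:
        cfg = dict(cfg)  # do not mutate the caller's config
        pattern = cfg.pop("param_str_match")
        names = [n for n in param_dict if re.search(pattern, n)]
        group = {"params": [param_dict.pop(n) for n in names], "param_names": names}
        group.update(cfg)  # per-group overrides, e.g. "lr"
        groups.append(group)
    leftover = list(param_dict)
    groups.insert(
        0, {"params": [param_dict.pop(n) for n in leftover], "param_names": leftover}
    )
    return groups


model = torch.nn.Sequential(
    OrderedDict(embed=torch.nn.Embedding(10, 4), proj=torch.nn.Linear(4, 10))
)
groups = extract_param_groups_sketch(model, [{"param_str_match": r"embed", "lr": 1e-2}])
print([g["param_names"] for g in groups])
# [['proj.weight', 'proj.bias'], ['embed.weight']]
```

Every parameter that matches no `param_str_match` pattern ends up in the implicit default group, which is why the real function asserts that `param_dict` is empty at the end.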
class OptimizersContainer(Optimizer, Stateful, Generic[T]):
    """A container for multiple optimizers.

@@ -74,11 +129,34 @@ def __init__(
        all_params = []
        self.optimizers = []
        self.model_parts = model_parts
+         param_groups_config = optimizer_kwargs.get("param_groups", None)
+         # Whether to keep the old LR values when loading a checkpoint.
+         self.preserve_lrs_when_loading = False
+         self.norms_to_log: list[str] | None = None
+
        for model in self.model_parts:
-             params = [p for p in model.parameters() if p.requires_grad]
-             self.optimizers.append(optimizer_cls(params, **optimizer_kwargs))
+             # Copy the parts we will pop from, so the settings are preserved
+             # across model parts.
+             kwargs = optimizer_kwargs.copy()
+             if "param_groups" in optimizer_kwargs:
+                 kwargs["param_groups"] = (
+                     param_groups_config.copy()
+                     if param_groups_config is not None
+                     else None
+                 )
+
+             extra_kwargs = kwargs.pop("extra_kwargs")
+             params = _extract_param_groups(model, kwargs)
+
+             is_scion = issubclass(optimizer_cls, DistributedScion)
+             if is_scion:
+                 kwargs.update(extra_kwargs)
+             self.optimizers.append(optimizer_cls(params, **kwargs))
            all_params.extend(params)
        self._validate_length(len(self.model_parts))
+         # Do not separately save the external settings in the optimizer defaults.
+         optimizer_kwargs.pop("param_groups", None)
+         optimizer_kwargs.update(optimizer_kwargs.pop("extra_kwargs", {}))
        self._post_init(all_params, optimizer_kwargs)

    def __iter__(self) -> Iterator[T]:
@@ -93,7 +171,12 @@ def step(self, *args, **kwargs) -> None:

    def zero_grad(self, *args, **kwargs) -> None:
        for optimizer in self.optimizers:
-             optimizer.zero_grad(*args, **kwargs)
+             # DistributedScion in light mode with momentum relies on gradients
+             # persisting across steps, so skip zeroing them in that case.
+             if not (
+                 isinstance(optimizer, DistributedScion)
+                 and optimizer.is_light
+                 and optimizer.use_momentum
+             ):
+                 optimizer.zero_grad(*args, **kwargs)

    def state_dict(self) -> dict[str, Any]:
        func = functools.partial(
@@ -107,13 +190,68 @@ def state_dict(self) -> dict[str, Any]:
        }

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+         if self.preserve_lrs_when_loading:
+             # Store the current learning rates before loading.
+             prev_lrs = []
+             for optimizer in self.optimizers:
+                 prev_lrs.append([group["lr"] for group in optimizer.param_groups])
+
        func = functools.partial(
            set_optimizer_state_dict,
            optim_state_dict=state_dict,
            options=StateDictOptions(flatten_optimizer_state_dict=True),
        )
        list(map(func, self.model_parts, self.optimizers))

+         if self.preserve_lrs_when_loading:
+             # Restore the original learning rates.
+             for optimizer, optim_prev_lrs in zip(self.optimizers, prev_lrs):
+                 for param_group, prev_lr in zip(optimizer.param_groups, optim_prev_lrs):
+                     if param_group["lr"] != prev_lr:
+                         logger.warning(
+                             f"Restoring lr from {param_group['lr']} to {prev_lr} | "
+                             f"for {param_group['param_names']}"
+                         )
+                         param_group["lr"] = prev_lr
+
+     def calculate_norm_at_next_step(self):
+         # For DistributedScion, tell the optimizer to compute the parameter
+         # norms inside its next step() call.
+         for i, _ in enumerate(self.model_parts):
+             optimizer = self.optimizers[i]
+             if isinstance(optimizer, DistributedScion):
+                 optimizer.calculate_norm_at_next_step(self.norms_to_log)
+
+     def get_parameter_norms(self):
+         all_norms = {}
+         for i, model_part in enumerate(self.model_parts):
+             # NB: assumes a one-to-one correspondence between model parts and
+             # optimizers.
+             optimizer = self.optimizers[i]
+             for group in optimizer.param_groups:
+                 if isinstance(optimizer, DistributedScion):
+                     all_norms.update(optimizer.get_norms_at_current_step())
+                 else:
+                     all_norms.update(
+                         naive_param_norm.get_parameter_norms(
+                             [model_part],
+                             [optimizer],
+                             self.norms_to_log,
+                         )
+                     )
+             # To debug, we can force the naive computation:
+             # all_norms.update(
+             #     naive_param_norm.get_parameter_norms([model_part], [optimizer])
+             # )
+
+         return all_norms
+
+     def get_lrs(self):
+         lrs = {}
+         for i, optimizer in enumerate(self.optimizers):
+             for k, group in enumerate(optimizer.param_groups):
+                 lrs[f"lr/opt_{i}/group_{k}"] = group["lr"]
+         return lrs
+
    def _validate_length(self, expected_length: int) -> None:
        assert expected_length == len(self.optimizers), (
            "Must pass one optimizer per model part or per param if "
@@ -246,6 +384,7 @@ def build_optimizers(
    optimizer_config: OptimizerConfig,
    parallel_dims: ParallelDims,
    ft_manager: FTManager | None = None,
+     extra_kwargs: dict[str, Any] | None = None,
) -> OptimizersContainer:
    """Create a OptimizersContainer for the given model parts and job config.
@@ -280,31 +419,114 @@ def build_optimizers(
            "TorchFT is not supported with optimizers in backward."
        )

+     extra_kwargs = extra_kwargs if extra_kwargs is not None else {}
+
    name = optimizer_config.name
    lr = optimizer_config.lr
    beta1 = optimizer_config.beta1
    beta2 = optimizer_config.beta2
    eps = optimizer_config.eps
    weight_decay = optimizer_config.weight_decay

-     optim_implementation = optimizer_config.implementation
-     assert optim_implementation in ["fused", "foreach", "for-loop"]
+     is_scion = name == "DistributedScion"

-     fused = optim_implementation == "fused"
-     foreach = optim_implementation == "foreach"
+     if name in ["Adam", "AdamW"]:
+         optim_implementation = optimizer_config.implementation
+         assert optim_implementation in ["fused", "foreach", "for-loop"]

-     optimizer_kwargs = {
-         "lr": lr,
-         "betas": (beta1, beta2),
-         "eps": eps,
-         "weight_decay": weight_decay,
-         "fused": fused,
-         "foreach": foreach,
+         fused = optim_implementation == "fused"
+         foreach = optim_implementation == "foreach"
+
+         if parallel_dims.ep_enabled:
+             # Expert Parallel uses two different device meshes, so the fused
+             # and foreach implementations cannot be used.
+             fused, foreach = False, False
+
+         optimizer_kwargs = {
+             "lr": lr,
+             "betas": (beta1, beta2),
+             "eps": eps,
+             "weight_decay": weight_decay,
+             "fused": fused,
+             "foreach": foreach,
+         }
+     elif is_scion:
+         backend_steps = optimizer_config.backend_steps
+         zeropower_backend_algorithm = optimizer_config.zeropower_backend
+         momentum = optimizer_config.momentum
+         nesterov = optimizer_config.nesterov
+         is_light = optimizer_config.is_light
+         weight_decay = optimizer_config.weight_decay
+         if os.environ.get("SCION_DEBUG_GRAD") == "1":
+             # Only when debugging the gradients do we skip the SVD and use the
+             # "identity" backend.
+             norm_factor = "none"
+             zeropower_backend_algorithm = "identity"
+             logger.warning(
+                 '`SCION_DEBUG_GRAD` is set to 1; skipping SVD and using the "identity" backend'
+             )
+         else:
+             norm_factor = "spectral"
+
+         optimizer_kwargs = {
+             "is_light": is_light,
+             "weight_decay": weight_decay,
+             "lr": lr,
+             "momentum": momentum,
+             "nesterov": nesterov,
+             "eps": eps,
+             "norm_factor": norm_factor,
+             "backend": zeropower_backend_algorithm,
+             "backend_steps": backend_steps,
+         }
+     else:
+         raise NotImplementedError(f"Optimizer {name} not added.")
+
+     # Configure parameter-group settings.
+     embed_lr = optimizer_config.embed_lr
+     embed_str_match = optimizer_config.embed_str_match
+     if embed_lr is not None and embed_str_match:
+         param_groups_config = optimizer_kwargs.setdefault("param_groups", [])
+         param_group_config = {
+             "param_str_match": embed_str_match,
+             "lr": embed_lr,
+         }
+         if is_scion:
+             param_group_config["norm_factor"] = "embed_sqrt"
+             param_group_config["backend"] = "identity"
+         param_groups_config.append(param_group_config)
+     unembed_lr = optimizer_config.unembed_lr
+     unembed_str_match = optimizer_config.unembed_str_match
+     if unembed_lr is not None and unembed_str_match:
+         param_groups_config = optimizer_kwargs.setdefault("param_groups", [])
+         param_group_config = {
+             "param_str_match": unembed_str_match,
+             "lr": unembed_lr,
+         }
+         if is_scion:
+             param_group_config["norm_factor"] = "unembed_sqrt"
+             param_group_config["backend"] = "identity"
+         param_groups_config.append(param_group_config)
+
+     router_str_match = optimizer_config.router_str_match
+     if router_str_match:
+         param_groups_config = optimizer_kwargs.setdefault("param_groups", [])
+         param_group_config = {
+             "param_str_match": router_str_match,
+             "lr": lr,
+         }
+         if is_scion:
+             param_group_config["norm_factor"] = "spectral"
+             param_group_config["backend"] = zeropower_backend_algorithm
+         param_groups_config.append(param_group_config)
+
+     optimizer_kwargs["extra_kwargs"] = {
+         "parallel_dims": parallel_dims,
+         **extra_kwargs,
    }

    optimizer_classes = {
        "Adam": torch.optim.Adam,
        "AdamW": torch.optim.AdamW,
+         "DistributedScion": DistributedScion,
    }
    if name not in optimizer_classes:
        raise NotImplementedError(f"Optimizer {name} not added.")