-class MaxFactor(Optimizer):
-    def __init__(self, params, lr=0.01, beta2_decay=-0.8, eps=(1e-10, 1e-3), d=1.0,
-                 weight_decay=0.01, gamma=0.99, eps_rms=1e-8, maximize=False):
+class MaxFactor(torch.optim.Optimizer):
+    """
+    MaxFactor optimizer that combines adaptive learning rates with factorized second moments.
+
+    Args:
+        params (iterable): Iterable of parameters to optimize
+        lr (float, optional): Maximum learning rate (default: 0.01)
+        beta2_decay (float, optional): Decay exponent for second moments (default: -0.8)
+        eps (tuple, optional): Small constants for numerical stability; pass None for
+            eps[0] to use the float32 machine epsilon (default: (1e-12, 1e-8))
+        d (float, optional): Scaling factor for updates (default: 1.0)
+        weight_decay (float, optional): Weight decay factor (default: 0.0)
+        gamma (float, optional): EMA factor for non-matrix parameters (default: 0.99)
+        max (bool, optional): Maximize the objective instead of minimizing (default: False)
+        full_matrix (bool, optional): Use full (unfactorized) second moments (default: False)
+        clip (float, optional): Gradient clipping norm (default: 1.0)
+    """
+    def __init__(self, params, lr=0.01, beta2_decay=-0.8, eps=(1e-12, 1e-8), d=1.0,
+                 weight_decay=0.0, gamma=0.99, max=False,
+                 full_matrix=False, clip=1.0):
+
+        if lr <= 0.0:
+            raise ValueError(f"Learning rate must be positive, got {lr}")
+
+        eps1, eps2 = eps
+        if eps1 is None:
+            eps1 = torch.finfo(torch.float32).eps
+
+        defaults = dict(
+            lr=lr, beta2_decay=beta2_decay, eps=(eps1, eps2), d=d, weight_decay=weight_decay,
+            gamma=gamma, max=max, full_matrix=full_matrix, clip=clip)
 
-        defaults = dict(lr=lr, beta2_decay=beta2_decay, eps=eps, d=d, weight_decay=weight_decay,
-                        gamma=gamma, eps_rms=eps_rms, maximize=maximize)
         super().__init__(params=params, defaults=defaults)
-
+
     def _get_lr(self, param_group, param_state):
         step = param_state["step"]
         step_float = step.item()
-        decay_factor = min(1.0, 1.0 / (step_float ** 0.5 + 1e-8))
+        decay_factor = min(1.0, 1.0 / (step_float ** 0.4 + 1e-12))
         param_scale = max(param_group["eps"][1], param_state["RMS"])
         return min(param_group["lr"], param_scale * decay_factor)
 
     @staticmethod
     def _rms(tensor):
-        return tensor.norm() / (tensor.numel() ** 0.5)
+        if tensor.numel() == 0:
+            return torch.tensor(0.0, device=tensor.device)
+        return tensor.norm() / (tensor.numel() ** 0.5 + 1e-12)
 
     @torch.no_grad()
     def step(self, closure=None):
+
         loss = None
         if closure is not None:
             with torch.enable_grad():
                 loss = closure()
 
         for group in self.param_groups:
-            params_with_grad, grads, row_vars, col_vars, v, state_steps = [], [], [], [], [], []
+            params_with_grad = []
+            grads = []
+            row_vars = []
+            col_vars = []
+            v = []
+            state_steps = []
             eps1, eps2 = group["eps"]
             for p in group["params"]:
                 if p.grad is None:
                     continue
+
                 grad = p.grad
                 if grad.dtype in {torch.float16, torch.bfloat16}:
                     grad = grad.float()
 
                 state = self.state[p]
                 if len(state) == 0:
                     state["step"] = torch.tensor(0.0, dtype=torch.float32)
-                    if p.grad.dim() > 1:
-                        row_shape, col_shape = list(p.grad.shape), list(p.grad.shape)
-                        row_shape[-1], col_shape[-2] = 1, 1
-                        state["row_var"], state["col_var"] = p.grad.new_zeros(row_shape), p.grad.new_zeros(col_shape)
+
+                    if p.dim() > 1 and not group["full_matrix"]:
+                        row_shape = list(p.shape)
+                        row_shape[-1] = 1
+                        state["row_var"] = torch.zeros(row_shape, dtype=torch.float32, device=p.device)
+
+                        col_shape = list(p.shape)
+                        col_shape[-2] = 1
+                        state["col_var"] = torch.zeros(col_shape, dtype=torch.float32, device=p.device)
+
                     state["v"] = torch.zeros_like(p, memory_format=torch.preserve_format)
+
+                    state["RMS"] = self._rms(p).item()
 
                 row_vars.append(state.get("row_var", None))
                 col_vars.append(state.get("col_var", None))
@@ -52,43 +94,76 @@ def step(self, closure=None):
 
             for i, param in enumerate(params_with_grad):
                 grad = grads[i]
-
-                if group["maximize"]:
+                state = self.state[param]
+
+                # if self.use_fam and param.dim() > 1:
+                #     grad = frequency_adaptive_momentum(
+                #         grad,
+                #         state,
+                #         alpha=self.fam_alpha,
+                #         beta=self.fam_beta
+                #     )
+
+                if group["max"]:
                     grad = -grad
-                step_t, row_var, col_var, vi = state_steps[i], row_vars[i], col_vars[i], v[i]
-
-                if eps1 is None:
-                    eps1 = torch.finfo(param.dtype).eps
 
+                step_t = state_steps[i]
+                row_var = row_vars[i]
+                col_var = col_vars[i]
+                vi = v[i]
+
                 step_t += 1
                 step_float = step_t.item()
 
-                one_minus_beta2_t = step_float ** group["beta2_decay"]
+                one_minus_beta2_t = min(0.999, step_float ** group["beta2_decay"])
+
                 state["RMS"] = self._rms(param).item()
                 adaptive_lr = self._get_lr(group, state)
-                rho_t = min(group["lr"], 1 / (step_float ** 0.5))
-                alpha = max(eps2, param.norm(2).item() / (param.numel() ** 0.5)) * rho_t
-
+
                 if group["weight_decay"] != 0:
-                    param.mul_(1 - group["lr"] * group["weight_decay"])
+                    param.mul_(1 - group["lr"] * group["weight_decay"] + eps1)
 
-                if grad.dim() > 1:
-                    row_mean = torch.norm(grad, dim=-1, keepdim=True).square_().div_(grad.size(-1) + 1e-8)
+                if param.dim() > 1 and not group["full_matrix"]:
+                    row_mean = torch.norm(grad, dim=-1, keepdim=True).square_()
+                    row_mean.div_(grad.size(-1) + eps1)
                     row_var.lerp_(row_mean, one_minus_beta2_t)
-                    col_mean = torch.norm(grad, dim=-2, keepdim=True).square_().div_(grad.size(-2) + 1e-8)
+
+                    col_mean = torch.norm(grad, dim=-2, keepdim=True).square_()
+                    col_mean.div_(grad.size(-2) + eps1)
                     col_var.lerp_(col_mean, one_minus_beta2_t)
+
                     var_estimate = row_var @ col_var
                     max_row_var = row_var.max(dim=-2, keepdim=True)[0]
                     var_estimate.div_(max_row_var.clamp_(min=eps1))
                 else:
-                    vi.mul_(group["gamma"]).add_(grad ** 2, alpha=1 - group["gamma"])
+
+                    # square out of place: grad is reused below to form the update
+                    vi.mul_(group["gamma"]).add_(grad.square(), alpha=1 - group["gamma"])
                     var_estimate = vi
-
-
-
+
                 update = var_estimate.clamp(min=eps1 * eps1).rsqrt_().mul_(grad)  # out-of-place clamp keeps vi intact
-                update = update.div_(torch.norm(update, float('inf')).clamp_(min=eps1))
-                denom = max(1.0, update.norm(2).item() / ((update.numel() ** 0.5) * group["d"]))
 
-                param.add_(-adaptive_lr / denom * update.sign() * update.abs().max(dim=-1, keepdim=True)[0])
+                inf_norm = torch.norm(update, float('inf'))
+                if inf_norm > 0:
+                    update.div_(inf_norm.clamp_(min=eps1))
+
+                if group.get("clip", 0) > 0:
+                    # clip the update tensor directly: clip_grad_norm_ reads .grad
+                    # attributes and would silently be a no-op on a plain tensor
+                    update_norm = update.norm(2)
+                    if update_norm > group["clip"]:
+                        update.mul_(group["clip"] / update_norm.clamp(min=eps1))
+
+                l2_norm = update.norm(2).item()
+                denom = max(1.0, l2_norm / ((update.numel() ** 0.5) * group["d"]))
+
+                if param.dim() > 1:
+                    param.add_(
+                        update.sign() * update.abs().max(dim=-1, keepdim=True)[0],
+                        alpha=-adaptive_lr / denom
+                    )
+                else:
+                    param.add_(update, alpha=-adaptive_lr / denom)
+
+                state["step"] = step_t
+
         return loss
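
Two notes on the updated step. The row_var/col_var buffers implement an Adafactor-style factorization of the squared-gradient statistics: rather than storing a full second-moment tensor, the optimizer keeps per-row and per-column means and rebuilds the estimate from their outer product. A standalone sketch of one such update follows; every name in it is illustrative and not part of the commit.

import torch

grad = torch.randn(3, 5)
one_minus_beta2_t = 0.2  # the optimizer derives this from step_float ** beta2_decay

row_var = torch.zeros(3, 1)  # per-row mean of squared gradients
col_var = torch.zeros(1, 5)  # per-column mean of squared gradients

row_var.lerp_(grad.square().mean(dim=-1, keepdim=True), one_minus_beta2_t)
col_var.lerp_(grad.square().mean(dim=-2, keepdim=True), one_minus_beta2_t)

# the outer product rebuilds a full-size estimate; dividing by the max row
# statistic keeps its scale comparable to an elementwise EMA of grad ** 2
var_estimate = (row_var @ col_var) / row_var.max(dim=-2, keepdim=True)[0].clamp(min=1e-12)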
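
For reference, a minimal usage sketch, assuming the class above is in scope; the linear model and data are placeholders.

import torch
import torch.nn as nn
import torch.nn.functional as F

model = nn.Linear(16, 4)
opt = MaxFactor(model.parameters(), lr=0.01)

x, target = torch.randn(8, 16), torch.randn(8, 4)

for _ in range(5):
    opt.zero_grad()
    loss = F.mse_loss(model(x), target)
    loss.backward()
    opt.step()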