Skip to content

Commit 48acd7a

Browse files
authored
Update MaxFactor.py
1 parent 2c30ae1 commit 48acd7a

File tree

1 file changed

+109
-34
lines changed

1 file changed

+109
-34
lines changed

MaxFactor/MaxFactor.py

Lines changed: 109 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,89 @@
1-
class MaxFactor(torch.optim.Optimizer):
    """
    MaxFactor optimizer combining an adaptive per-parameter learning rate
    with factorized (row/column) second-moment estimates, Adafactor-style.

    Args:
        params (iterable): Iterable of parameters to optimize.
        lr (float, optional): Maximum learning rate (default: 0.01).
        beta2_decay (float, optional): Decay exponent for the second-moment
            EMA coefficient, applied as ``step ** beta2_decay`` (default: -0.8).
        eps (tuple, optional): ``(eps1, eps2)`` numerical-stability constants;
            ``eps1=None`` falls back to float32 machine epsilon
            (default: ``(1e-12, 1e-8)``).
        d (float, optional): Scaling factor for updates (default: 1.0).
        weight_decay (float, optional): Weight-decay factor (default: 0.0).
        gamma (float, optional): EMA factor for non-matrix parameters
            (default: 0.99).
        max (bool, optional): Maximize the objective instead of minimizing
            (default: False). NOTE: name shadows the ``max`` builtin but is
            kept for caller compatibility.
        full_matrix (bool, optional): Use full (unfactorized) second moments
            (default: False).
        clip (float, optional): Update-clipping norm; values <= 0 disable
            clipping (default: 1.0).

    Raises:
        ValueError: If ``lr`` is not strictly positive.
    """
    def __init__(self, params, lr=0.01, beta2_decay=-0.8, eps=(1e-12, 1e-8), d=1.0,
                 weight_decay=0.0, gamma=0.99, max=False,
                 full_matrix=False, clip=1.0):
        # Fixed: parameter was misspelled ``ull_matrix`` while the defaults
        # dict below referenced ``full_matrix`` — construction always raised
        # NameError.
        if lr <= 0.0:
            raise ValueError(f"Learning rate must be positive, got {lr}")

        eps1, eps2 = eps
        if eps1 is None:
            # Fall back to float32 machine epsilon when eps1 is not given.
            eps1 = torch.finfo(torch.float32).eps

        defaults = dict(
            lr=lr, beta2_decay=beta2_decay, eps=(eps1, eps2), d=d,
            weight_decay=weight_decay, gamma=gamma, max=max,
            full_matrix=full_matrix, clip=clip)

        super().__init__(params=params, defaults=defaults)
8-
33+
934
def _get_lr(self, param_group, param_state):
1035
step = param_state["step"]
1136
step_float = step.item()
12-
decay_factor = min(1.0, 1.0 / (step_float ** 0.5 + 1e-8))
37+
decay_factor = min(1.0, 1.0 / (step_float ** 0.4 + 1e-12))
1338
param_scale = max(param_group["eps"][1], param_state["RMS"])
1439
return min(param_group["lr"], param_scale * decay_factor)
1540

1641
@staticmethod
1742
def _rms(tensor):
18-
return tensor.norm() / (tensor.numel() ** 0.5)
43+
if tensor.numel() == 0:
44+
return torch.tensor(0.0, device=tensor.device)
45+
return tensor.norm() / (tensor.numel() ** 0.5 + 1e-12)
1946

2047
@torch.no_grad()
def step(self, closure=None):
    """Perform a single optimization step.

    Args:
        closure (callable, optional): Re-evaluates the model and returns
            the loss.

    Returns:
        The loss returned by ``closure``, or ``None``.
    """
    loss = None
    if closure is not None:
        with torch.enable_grad():
            loss = closure()

    for group in self.param_groups:
        params_with_grad = []
        grads = []
        row_vars = []
        col_vars = []
        v = []
        state_steps = []
        eps1, eps2 = group["eps"]

        for p in group["params"]:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.dtype in {torch.float16, torch.bfloat16}:
                # Accumulate statistics in float32 for low-precision params.
                grad = grad.float()

            state = self.state[p]
            if len(state) == 0:
                state["step"] = torch.tensor(0.0, dtype=torch.float32)

                if p.dim() > 1 and not group["full_matrix"]:
                    # Factorized second moment: one row vector and one
                    # column vector per matrix parameter (Adafactor-style).
                    row_shape = list(p.shape)
                    row_shape[-1] = 1
                    state["row_var"] = torch.zeros(row_shape, dtype=torch.float32, device=p.device)

                    col_shape = list(p.shape)
                    col_shape[-2] = 1
                    state["col_var"] = torch.zeros(col_shape, dtype=torch.float32, device=p.device)

                state["v"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                state["RMS"] = self._rms(p).item()

            row_vars.append(state.get("row_var", None))
            col_vars.append(state.get("col_var", None))
            # NOTE(review): these four appends are hidden by the diff hunk
            # boundary; they are forced by the indexed loop below — confirm
            # against the full file.
            v.append(state["v"])
            state_steps.append(state["step"])
            params_with_grad.append(p)
            grads.append(grad)

        for i, param in enumerate(params_with_grad):
            grad = grads[i]
            state = self.state[param]

            if group["max"]:
                grad = -grad

            step_t = state_steps[i]
            row_var = row_vars[i]
            col_var = col_vars[i]
            vi = v[i]

            step_t += 1
            step_float = step_t.item()

            # EMA coefficient shrinks with the step count, capped below 1.
            one_minus_beta2_t = min(0.999, step_float ** group["beta2_decay"])

            state["RMS"] = self._rms(param).item()
            adaptive_lr = self._get_lr(group, state)

            if group["weight_decay"] != 0:
                # Decoupled (AdamW-style) decay; the +eps1 term is kept from
                # the original — presumably to avoid an exact-zero factor.
                param.mul_(1 - group["lr"] * group["weight_decay"] + eps1)

            if param.dim() > 1 and not group["full_matrix"]:
                row_mean = torch.norm(grad, dim=-1, keepdim=True).square_()
                row_mean.div_(grad.size(-1) + eps1)
                row_var.lerp_(row_mean, one_minus_beta2_t)

                col_mean = torch.norm(grad, dim=-2, keepdim=True).square_()
                col_mean.div_(grad.size(-2) + eps1)
                col_var.lerp_(col_mean, one_minus_beta2_t)

                var_estimate = row_var @ col_var
                max_row_var = row_var.max(dim=-2, keepdim=True)[0]
                var_estimate.div_(max_row_var.clamp_(min=eps1))
            else:
                # Fixed: was ``grad.square_()`` — the in-place square
                # clobbered ``p.grad`` and fed grad**2 into the update below.
                vi.mul_(group["gamma"]).add_(grad.square(), alpha=1 - group["gamma"])
                # Fixed: ``var_estimate = vi`` aliased state["v"], so the
                # in-place clamp_/rsqrt_ below destroyed the EMA each step.
                var_estimate = vi.clone()

            update = var_estimate.clamp_(min=eps1 * eps1).rsqrt_().mul_(grad)

            inf_norm = torch.norm(update, float('inf'))
            if inf_norm > 0:
                update.div_(inf_norm.clamp_(min=eps1))

            # Fixed: clip_grad_norm_ reads the ``.grad`` attribute of the
            # tensors it receives, so passing [update] was a silent no-op.
            # Clip the update tensor's own L2 norm directly instead.
            clip = group.get("clip", 0)
            if clip > 0:
                update_norm = update.norm(2)
                if update_norm > clip:
                    update.mul_(clip / (update_norm + eps1))

            l2_norm = update.norm(2).item()
            denom = max(1.0, l2_norm / ((update.numel() ** 0.5) * group["d"]))

            if param.dim() > 1:
                # Sign-based update scaled by the per-row max magnitude.
                param.add_(
                    update.sign() * update.abs().max(dim=-1, keepdim=True)[0],
                    alpha=-adaptive_lr / denom
                )
            else:
                param.add_(update, alpha=-adaptive_lr / denom)

            state["step"] = step_t

    return loss

0 commit comments

Comments
 (0)