Commit 2c30ae1: Update README.md
1 parent 2310282 commit 2c30ae1
1 file changed: README.md (+167 −48)
#### MaxFactor is best described as a thoughtful integration of existing optimization techniques, with implementation choices tailored for encoder-decoder ASR transformer models. It combines proven ideas from several established algorithms, tuned for the transformer architectures used in speech recognition, and makes practical engineering tradeoffs that work well empirically for speech models. Its particular combination of approaches addresses practical challenges in training large speech and multimodal LLMs.

#### MaxFactor Family Tree

```
Adam
├── Adaptive learning rates
└── EMA of second moments

Adafactor
├── Factorized second moments
└── Relative step sizing

SignSGD
└── Sign-based updates

LAMB/LARS
├── Layer-wise adaptivity
└── Gradient normalization

AdamW
└── Decoupled weight decay

Adamax
└── Infinity normalization

RMSprop
└── Root mean squared gradient scaling

Gradient Clipping
└── Max norm constraints

MaxFactor
└── Combines all of the above, with a couple of unique twists (and FAM)
```
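One branch of the tree above, Adafactor's factorized second moments, can be sketched concretely: instead of keeping a full matrix of squared-gradient EMAs, keep only per-row and per-column means and reconstruct their outer product. A minimal sketch, not MaxFactor's own code (the function and variable names here are illustrative):

```python
import torch

def factored_second_moment(grad, row_var, col_var, beta2=0.999):
    """Rank-1 approximation of the squared-gradient EMA (Adafactor-style).

    Keeps O(n + m) state for an n x m parameter instead of O(n * m).
    """
    # EMA of per-row and per-column mean squared gradients
    row_var.mul_(beta2).add_(grad.square().mean(dim=-1, keepdim=True), alpha=1 - beta2)
    col_var.mul_(beta2).add_(grad.square().mean(dim=-2, keepdim=True), alpha=1 - beta2)
    # Outer product, rescaled so the overall magnitude matches the gradient scale
    return row_var @ col_var / row_var.mean().clamp(min=1e-30)

g = torch.randn(4, 3)
r = torch.zeros(4, 1)
c = torch.zeros(1, 3)
v_hat = factored_second_moment(g, r, c)  # same shape as g, entries >= 0
```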
Coming soon:

## Frequency-Adaptive Momentum (FAM)

### Core Concept

- Speech signals have inherent frequency structure, with different parts of the model responding to different frequency bands. This structure doesn't disappear when audio is converted to log-mel spectrograms; it is transformed and preserved in ways that the model's parameters adapt to capture.
- The chain of frequency information: Original Audio → Log-Mel Spectrogram → Encoder Parameters → Gradient Updates.

This isn't just a theoretical connection; it's empirically observable in how transformer-based speech models learn:

- Lower encoder layers develop filters that respond to specific frequency bands in the mel spectrogram.
- Attention heads specialize in tracking particular acoustic patterns across time.
- The model develops a hierarchical representation, from acoustic features to phonetic units to words.
- The idea is to integrate a momentum scheme that adapts based on the "frequency signature" of gradient updates.
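The "frequency signature" of a gradient can be made concrete by looking at how its spectral energy splits across bands. A minimal sketch (the band count and the contiguous equal-width split are illustrative assumptions):

```python
import torch

def band_energies(grad, n_bands=4):
    """Return the fraction of spectral energy in each of n_bands contiguous
    frequency bands of a flattened gradient."""
    spectrum = torch.fft.rfft(grad.flatten().float()).abs().square()
    bands = spectrum.chunk(n_bands)  # contiguous, roughly equal-width bands
    energies = torch.stack([b.sum() for b in bands])
    return energies / energies.sum().clamp(min=1e-12)

g = torch.randn(256)
fractions = band_energies(g)  # four non-negative fractions summing to 1
```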
### Why This Optimizer Makes Sense

What's compelling about the Frequency-Adaptive Momentum approach is that it acknowledges this structure in the optimization process itself. Rather than treating all gradient dimensions equally, it recognizes that:

- **Gradient frequencies matter:** The Fourier transform of gradient updates reveals patterns related to what the model is currently learning.
- **Different parameters process different bands:** Just as our ears have frequency-specific receptors, different parts of the model specialize in different acoustic frequencies.
- **Learning has temporal structure:** Speech learning happens in stages: first basic acoustics, then phonetic patterns, then linguistic structures.

By applying different momentum factors to different frequency bands in parameter space, we give the optimizer information about the audio domain that it would not otherwise have.
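FAM itself is still marked as coming soon; the optimizer code below only carries a commented-out `frequency_adaptive_momentum` hook. As a purely illustrative sketch of what such a transform could look like, matching that hook's signature but with assumed semantics for `alpha` and `beta` (stronger smoothing on low-frequency bands, lighter on high-frequency ones):

```python
import torch

def frequency_adaptive_momentum(grad, state, alpha=0.9, beta=0.99, n_bands=4):
    """Illustrative sketch: keep a momentum buffer over the gradient's spectrum,
    interpolating the smoothing factor from beta (lowest band) down to alpha
    (highest band), then transform back to parameter space."""
    spectrum = torch.fft.rfft(grad.flatten().float())
    if "fam_buf" not in state:
        state["fam_buf"] = torch.zeros_like(spectrum)  # complex momentum buffer
    buf = state["fam_buf"]
    # Assign each frequency bin to a band index in [0, n_bands)
    bands = torch.arange(spectrum.numel()).float().div(spectrum.numel()).mul(n_bands).long()
    # Per-bin momentum factor: beta for band 0, alpha for the top band
    factors = beta - (beta - alpha) * bands.float() / max(n_bands - 1, 1)
    buf.mul_(factors).add_(spectrum * (1 - factors))
    smoothed = torch.fft.irfft(buf, n=grad.numel())
    return smoothed.view_as(grad).to(grad.dtype)

g = torch.randn(8, 8)
state = {}
out = frequency_adaptive_momentum(g, state)  # same shape as g; buffer persists in state
```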
```python
import torch


class MaxFactor(torch.optim.Optimizer):
    """
    MaxFactor optimizer that combines adaptive learning rates with factorized second moments.

    Args:
        params (iterable): Iterable of parameters to optimize
        lr (float, optional): Maximum learning rate (default: 0.01)
        beta2_decay (float, optional): Decay exponent for second moments (default: -0.8)
        eps (tuple, optional): Small constants for numerical stability (default: (1e-12, 1e-8))
        d (float, optional): Scaling factor for updates (default: 1.0)
        weight_decay (float, optional): Weight decay factor (default: 0.0)
        gamma (float, optional): EMA factor for non-matrix parameters (default: 0.99)
        max (bool, optional): Maximize the objective instead of minimizing (default: False)
        full_matrix (bool, optional): Use full matrix for second moments (default: False)
        clip (float, optional): Gradient clipping norm (default: 1.0)
    """
    def __init__(self, params, lr=0.01, beta2_decay=-0.8, eps=(1e-12, 1e-8), d=1.0,
                 weight_decay=0.0, gamma=0.99, max=False,
                 full_matrix=False, clip=1.0):

        if lr <= 0.0:
            raise ValueError(f"Learning rate must be positive, got {lr}")

        eps1, eps2 = eps
        if eps1 is None:
            eps1 = torch.finfo(torch.float32).eps

        defaults = dict(
            lr=lr, beta2_decay=beta2_decay, eps=(eps1, eps2), d=d, weight_decay=weight_decay,
            gamma=gamma, max=max, full_matrix=full_matrix, clip=clip)
        super().__init__(params=params, defaults=defaults)

    def _get_lr(self, param_group, param_state):
        step = param_state["step"]
        step_float = step.item()
        decay_factor = min(1.0, 1.0 / (step_float ** 0.4 + 1e-12))
        param_scale = max(param_group["eps"][1], param_state["RMS"])
        return min(param_group["lr"], param_scale * decay_factor)

    @staticmethod
    def _rms(tensor):
        if tensor.numel() == 0:
            return torch.tensor(0.0, device=tensor.device)
        return tensor.norm() / (tensor.numel() ** 0.5 + 1e-12)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            row_vars = []
            col_vars = []
            v = []
            state_steps = []
            eps1, eps2 = group["eps"]

            for p in group["params"]:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.dtype in {torch.float16, torch.bfloat16}:
                    grad = grad.float()

                state = self.state[p]
                if len(state) == 0:
                    state["step"] = torch.tensor(0.0, dtype=torch.float32)

                    if p.dim() > 1 and not group["full_matrix"]:
                        row_shape = list(p.shape)
                        row_shape[-1] = 1
                        state["row_var"] = torch.zeros(row_shape, dtype=torch.float32, device=p.device)

                        col_shape = list(p.shape)
                        col_shape[-2] = 1
                        state["col_var"] = torch.zeros(col_shape, dtype=torch.float32, device=p.device)

                    state["v"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    state["RMS"] = self._rms(p).item()

                row_vars.append(state.get("row_var", None))
                col_vars.append(state.get("col_var", None))
                v.append(state["v"])
                state_steps.append(state["step"])
                params_with_grad.append(p)
                grads.append(grad)

            for i, param in enumerate(params_with_grad):
                grad = grads[i]
                state = self.state[param]

                # if self.use_fam and param.dim() > 1:
                #     grad = frequency_adaptive_momentum(
                #         grad,
                #         state,
                #         alpha=self.fam_alpha,
                #         beta=self.fam_beta
                #     )

                if group["max"]:
                    grad = -grad

                step_t = state_steps[i]
                row_var = row_vars[i]
                col_var = col_vars[i]
                vi = v[i]

                step_t += 1
                step_float = step_t.item()

                one_minus_beta2_t = min(0.999, step_float ** group["beta2_decay"])

                state["RMS"] = self._rms(param).item()
                adaptive_lr = self._get_lr(group, state)

                if group["weight_decay"] != 0:
                    param.mul_(1 - group["lr"] * group["weight_decay"])

                if param.dim() > 1 and not group["full_matrix"]:
                    row_mean = torch.norm(grad, dim=-1, keepdim=True).square_()
                    row_mean.div_(grad.size(-1) + eps1)
                    row_var.lerp_(row_mean, one_minus_beta2_t)

                    col_mean = torch.norm(grad, dim=-2, keepdim=True).square_()
                    col_mean.div_(grad.size(-2) + eps1)
                    col_var.lerp_(col_mean, one_minus_beta2_t)

                    var_estimate = row_var @ col_var
                    max_row_var = row_var.max(dim=-2, keepdim=True)[0]
                    var_estimate.div_(max_row_var.clamp(min=eps1))
                else:
                    # grad.square() (out of place) keeps grad usable for the update below
                    vi.mul_(group["gamma"]).add_(grad.square(), alpha=1 - group["gamma"])
                    var_estimate = vi

                # Out-of-place clamp/rsqrt so the EMA buffer vi is never overwritten
                update = var_estimate.clamp(min=eps1 * eps1).rsqrt() * grad

                inf_norm = torch.norm(update, float('inf'))
                if inf_norm > 0:
                    update.div_(inf_norm.clamp_(min=eps1))

                # Clip the update tensor directly; it is a plain tensor with no
                # .grad attribute, so clip_grad_norm_ would be a no-op here
                if group.get("clip", 0) > 0:
                    update_norm = update.norm(2)
                    if update_norm > group["clip"]:
                        update.mul_(group["clip"] / (update_norm + eps1))

                l2_norm = update.norm(2).item()
                denom = max(1.0, l2_norm / ((update.numel() ** 0.5) * group["d"]))

                if param.dim() > 1:
                    param.add_(
                        update.sign() * update.abs().max(dim=-1, keepdim=True)[0],
                        alpha=-adaptive_lr / denom
                    )
                else:
                    param.add_(update, alpha=-adaptive_lr / denom)

                state["step"] = step_t

        return loss


# optional:
# Create scheduler with warmup
```

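Construction mirrors other `torch.optim` optimizers. A minimal usage sketch, assuming the `MaxFactor` class above is in scope (the toy `Linear` model and training step are illustrative):

```python
import torch

# Toy model for illustration only
model = torch.nn.Linear(16, 4)

optimizer = MaxFactor(
    model.parameters(),
    lr=0.01,
    beta2_decay=-0.8,
    eps=(1e-12, 1e-8),
    d=1.0,
    weight_decay=0.0,
    gamma=0.99,
    max=False,
    full_matrix=False,
    clip=1.0,
)

# One training step
x = torch.randn(8, 16)
loss = model(x).square().mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```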