#### MaxFactor is best described as a thoughtful integration of existing optimization techniques, with implementation choices tailored to encoder-decoder ASR transformer models. It combines proven ideas from several established algorithms, tuned specifically for the transformer architectures used in speech recognition, and makes practical engineering tradeoffs that work well empirically. This particular combination of approaches addresses real challenges in training large speech and multimodal LLMs.
#### MaxFactor Family Tree

```
Adam
├── Adaptive learning rates
└── EMA of second moments

Adafactor
├── Factorized second moments
└── Relative step sizing

SignSGD
└── Sign-based updates

LAMB/LARS
├── Layer-wise adaptivity
└── Gradient normalization

AdamW
└── Decoupled weight decay

Adamax
└── Infinity normalization

RMSprop
└── Root mean squared gradient scaling

Gradient Clipping
└── Max norm constraints

MaxFactor
└── Combines all of the above features, with a couple of unique twists (and FAM)
```
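To make the family tree concrete, here is a rough sketch of how two of those ingredients, Adafactor's factorized second moments and Adamax's infinity normalization, can be combined in a single update. The helper below is purely illustrative (hypothetical function name, simplified math); it is not MaxFactor's actual implementation, which appears later in this README.

```python
import torch


def factored_inf_step(grad, row_var, col_var, beta2=0.999, eps=1e-3):
    """Illustrative only: Adafactor-style factorized second moments
    combined with an Adamax-style infinity-norm cap on the step."""
    # EMAs of per-row / per-column mean squared gradients: O(m + n) memory
    # instead of O(m * n) for a full second-moment matrix.
    row_var.mul_(beta2).add_(grad.pow(2).mean(dim=1), alpha=1 - beta2)
    col_var.mul_(beta2).add_(grad.pow(2).mean(dim=0), alpha=1 - beta2)

    # Rank-1 reconstruction of the full second moment, as in Adafactor.
    v = torch.outer(row_var, col_var) / row_var.mean().clamp_min(eps)

    # RMS-style preconditioning, then bound the step by its infinity norm.
    update = grad / v.sqrt().clamp_min(eps)
    return update / update.abs().max().clamp_min(1.0)
```

The infinity-norm cap plays the same role as the "max norm constraints" branch above: no single coordinate of the preconditioned step can exceed 1 in magnitude.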
Coming soon -
## Frequency-Adaptive Momentum (FAM)
### Core Concept
- Speech signals have inherent frequency structure, with different parts of the model responding to different frequency bands. The frequency structure of speech doesn't just disappear when converted to log-mel spectrograms; it's transformed and preserved in ways that the model's parameters adapt to capture.
- The chain of frequency information: Original Audio → Log-Mel Spectrogram → Encoder Parameters → Gradient Updates.

This isn't just a theoretical connection; it's empirically observable in how transformer-based speech models learn:

- Lower encoder layers develop filters that respond to specific frequency bands in the mel spectrogram.
- Attention heads specialize in tracking particular acoustic patterns across time.
- The model inherently develops a hierarchical representation, from acoustic features to phonetic units to words.
- The idea is to integrate a momentum scheme that adapts based on the "frequency signature" of gradient updates.

What's compelling about the Frequency-Adaptive Momentum approach is that it acknowledges this structure in the optimization process itself. Rather than treating all gradient dimensions equally, it recognizes that:
- **Gradient Frequencies Matter:** The Fourier transform of gradient updates reveals patterns related to what the model is currently learning.
- **Different Parameters Process Different Bands:** Just as our ears have frequency-specific receptors, different parts of the model specialize in different acoustic frequencies.
- **Temporal Structure in Learning:** Speech learning happens in stages: first basic acoustics, then phonetic patterns, then linguistic structures.
By applying different momentum factors to different frequency bands in parameter space, we're essentially giving the optimizer information about the audio domain that it wouldn't otherwise have.
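The band-wise momentum idea can be sketched as follows. Everything here is a hypothetical illustration: the function name, the three equal-width bands, and the per-band beta values are assumptions, not FAM's actual design.

```python
import torch


def fam_momentum(grad, momentum, band_betas=(0.99, 0.95, 0.9)):
    """Illustrative only: apply a different momentum coefficient to each
    frequency band of the gradient's spectrum, then transform back."""
    flat = grad.flatten()
    grad_spec = torch.fft.rfft(flat)            # gradient's "frequency signature"
    mom_spec = torch.fft.rfft(momentum.flatten())

    n = grad_spec.numel()
    edges = [0, n // 3, 2 * n // 3, n]          # three equal-width bands (assumed)
    new_spec = torch.empty_like(grad_spec)
    for i, beta in enumerate(band_betas):
        lo, hi = edges[i], edges[i + 1]
        # a larger beta on a band smooths that band's updates more heavily
        new_spec[lo:hi] = beta * mom_spec[lo:hi] + (1 - beta) * grad_spec[lo:hi]

    # back to parameter space; the caller uses this as the update direction
    return torch.fft.irfft(new_spec, n=flat.numel()).view_as(grad)
```

In an optimizer step this would replace the usual single-beta momentum buffer, e.g. `p.data.add_(fam_momentum(p.grad, state["m"]), alpha=-lr)`, so low- and high-frequency components of the update are smoothed at different rates.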
```python
import torch


class MaxFactor(torch.optim.Optimizer):
    """
    MaxFactor optimizer that combines adaptive learning rates with factorized second moments.

    Args:
        params (iterable): Iterable of parameters to optimize
        lr (float, optional): Maximum learning rate (default: 0.01)
        beta2_decay (float, optional): Decay exponent for second moments (default: -0.8)
        eps (tuple, optional): Small constants for numerical stability (default: (None, 1e-3))
        d (float, optional): Scaling factor for updates (default: 1.0)