Description
Often, certain model parameters—such as normalization layers and input/output embeddings—are configured to exclude weight decay during optimization. This is particularly important for input embeddings, since not all tokens appear in every batch. When a token is relatively rare, its embedding may be pushed toward zero by the optimizer due to weight decay, even though it has not received a gradient update.
Recent LLM architectures such as Trinity and OLMo3 follow this practice, so it could be a useful feature to support.
From an implementation perspective, this could be done by introducing a configuration parameter (e.g., `no_weight_decay_keywords` or `no_weight_decay_params`) in the optimizer component:
torchtitan/torchtitan/components/optimizer.py
Lines 64 to 88 in fa8e6cc
```python
class Config(Configurable.Config):
    name: str = "AdamW"
    """Optimizer to use"""
    lr: float = 8e-4
    """Learning rate to use"""
    beta1: float = 0.9
    beta2: float = 0.95
    """Exponential moving average hyperparameters to use"""
    eps: float = 1e-8
    """Epsilon value to use"""
    weight_decay: float = 0.1
    """Weight decay to use"""
    implementation: Literal["for-loop", "foreach", "fused"] = "fused"
    """
    Specify which optimizer implementation to use:
    - 'fused': Use fused implementation (CUDA only) for best performance.
    - 'foreach': Use some horizontal fusion of tensors for better performance.
    - 'for-loop': Use the default implementation for the optimizer (slowest).
    - more info: https://pytorch.org/docs/stable/optim.html
    """
```
One possible approach would be to exclude parameters based on their names. This could be implemented in different ways, for example:
- Hard match: exclude a parameter only if its name exactly matches one of the configured entries.
- Soft match: exclude a parameter if its name contains one of the provided keywords (e.g., `any(kw in param_name for kw in no_weight_decay_keywords)`).
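To make the trade-off concrete, here is a minimal comparison of the two strategies; the parameter names and keyword/entry lists below are illustrative, not taken from torchtitan:

```python
param_names = [
    "tok_embeddings.weight",
    "layers.0.attention_norm.weight",
    "layers.0.attention.wq.weight",
    "norm.weight",
]

# Hard match: the full parameter name must equal a configured entry exactly.
no_weight_decay_params = {"tok_embeddings.weight", "norm.weight"}
hard = [n for n in param_names if n in no_weight_decay_params]

# Soft match: any keyword appearing as a substring excludes the parameter.
no_weight_decay_keywords = ("norm", "embeddings")
soft = [n for n in param_names if any(kw in n for kw in no_weight_decay_keywords)]

print(hard)  # ['tok_embeddings.weight', 'norm.weight']
print(soft)  # ['tok_embeddings.weight', 'layers.0.attention_norm.weight', 'norm.weight']
```

Note that the soft match also catches the per-layer `attention_norm` without the user enumerating every layer, at the cost of potentially matching unintended parameters; the hard match is explicit but verbose for deep models.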
If this feature would be useful, I’d be happy to submit a PR. Is there a preferred approach between these matching strategies?