Commit 43b5203

Publishing internal doc Optimizer and Learning Rate Scheduler (Optimization.md) into github
PiperOrigin-RevId: 665720689
# Optimizer and Learning Rate Scheduler

This page describes the
[optimization package](https://github.com/tensorflow/models/tree/master/official/modeling/optimization/)
for TensorFlow Official Models (TFM), which includes optimizers and learning
rate schedulers.

## Building Optimizer and LR Scheduler

We use an optimizer factory class to manage optimizer and learning rate
creation. The optimizer factory takes a config as an input, and it has member
functions that are used to build the optimizer and the learning rate schedule.
To create an optimizer and an LR schedule through OptimizerFactory, you need to
do the following:

1. Define the optimization config; this includes the optimizer and the learning
   rate schedule.
2. Initialize the OptimizerFactory instance using the optimization config.
3. Build the learning rate and the optimizer using the class member functions.

The following is an example of creating an SGD optimizer with a stepwise LR
scheduler and linear warmup:

```python
from official.modeling import optimization

params = {'optimizer': {'type': 'sgd',
                        'sgd': {'momentum': 0.9}},
          'learning_rate': {'type': 'stepwise',
                            'stepwise': {
                                'boundaries': [10000, 20000],
                                'values': [0.1, 0.01, 0.001]}},
          'warmup': {'type': 'linear',
                     'linear': {'warmup_steps': 500,
                                'warmup_learning_rate': 0.01}}}
# Defines the optimization config from a dictionary.
opt_config = optimization.OptimizationConfig(params)
# Initializes an optimizer factory from the optimization config.
opt_factory = optimization.OptimizerFactory(opt_config)
# Builds the desired learning rate scheduling instance.
lr = opt_factory.build_learning_rate()
# Builds the optimizer instance with the desired learning rate schedule.
optimizer = opt_factory.build_optimizer(lr)
```

To initialize an OptimizerFactory, the `optimizer` and `learning_rate` fields
must be defined, while `warmup` is an optional field. The field `type` is used
to define the type of each optimization component. The set of available types
is explained in detail in the following sections.

In the following sections, we explain how to create different optimizers,
learning rate schedules, and warmup schedules. We also explain how to add new
optimizers or learning rate schedulers.

## Optimizers

The list of supported optimizers can be found
[here](https://github.com/tensorflow/models/blob/7f239d8ec19b5c2d44e0d5aa2a09dbea0da6d737/official/modeling/optimization/optimizer_factory.py#L43).

```python
OPTIMIZERS_CLS = {
    'sgd': tf.keras.optimizers.SGD,
    'adam': tf.keras.optimizers.Adam,
    'adamw': nlp_optimization.AdamWeightDecay,
    'lamb': tfa_optimizers.LAMB,
    'rmsprop': tf.keras.optimizers.RMSprop
}
```

You can specify the type of optimizer to be one of the above using the
[oneof](https://github.com/tensorflow/models/blob/master/official/modeling/hyperparams/oneof.py)
config. The available config fields can be found
[here](https://github.com/tensorflow/models/blob/master/official/modeling/optimization/configs/optimizer_config.py).

All optimizers support gradient clipping methods: clip by value, clip by norm,
and clip by global norm. To specify which method to use, set the appropriate
field listed
[here](https://github.com/tensorflow/models/blob/7f239d8ec19b5c2d44e0d5aa2a09dbea0da6d737/official/modeling/optimization/configs/optimizer_config.py#L38).

### Example

We will specify an RMSprop optimizer with a discounting factor (`rho`) of 0.9
and global norm gradient clipping of 10.0. Below is the config to be used.

```python
params = {'optimizer': {'type': 'rmsprop',
                        'rmsprop': {'rho': 0.9,
                                    'global_clipnorm': 10.0}}}
```
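
To actually build an optimizer from this config, the flow is the same as in the
first example. The following sketch is illustrative only: it reuses the
stepwise schedule from the earlier example as a placeholder, because
`OptimizerFactory` also requires a `learning_rate` field.

```python
from official.modeling import optimization

params = {'optimizer': {'type': 'rmsprop',
                        'rmsprop': {'rho': 0.9,
                                    'global_clipnorm': 10.0}},
          # Placeholder schedule; OptimizerFactory requires `learning_rate`.
          'learning_rate': {'type': 'stepwise',
                            'stepwise': {'boundaries': [10000, 20000],
                                         'values': [0.1, 0.01, 0.001]}}}
opt_config = optimization.OptimizationConfig(params)
opt_factory = optimization.OptimizerFactory(opt_config)
lr = opt_factory.build_learning_rate()
# The returned optimizer applies global norm clipping before each update.
optimizer = opt_factory.build_optimizer(lr)
```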

### Adding a New Optimizer

To add a new optimizer, you need to do the following (a minimal sketch follows
the list):

1. Create a
   [custom subclass](https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Optimizer#creating_a_custom_optimizer_2)
   of tf.keras.optimizers.Optimizer.
2. Add the required config fields under
   [optimization/configs/optimizer_config.py](https://github.com/tensorflow/models/blob/master/official/modeling/optimization/configs/optimizer_config.py).
3. Add the optimizer class to the list of available optimizer classes in
   [optimizer_factory](https://github.com/tensorflow/models/blob/master/official/modeling/optimization/optimizer_factory.py).

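The sketch below shows the rough shape of these steps. The names
(`MyOptimizer`, `MyOptimizerConfig`, `'my_optimizer'`) and the config fields are
illustrative assumptions; follow the existing entries in the linked files for
the exact base classes and registration points.

```python
import dataclasses

import tensorflow as tf


# 1. A custom optimizer. Here it is just a thin subclass of SGD standing in for
#    a real implementation of tf.keras.optimizers.Optimizer.
class MyOptimizer(tf.keras.optimizers.SGD):
  pass


# 2. A config dataclass mirroring the constructor arguments of the optimizer.
#    In the repo this would live in optimization/configs/optimizer_config.py
#    and extend its base optimizer config; the fields here are placeholders.
@dataclasses.dataclass
class MyOptimizerConfig:
  momentum: float = 0.9
  nesterov: bool = False


# 3. Registration. optimizer_factory.py maps the `type` string to the optimizer
#    class, so adding an entry lets {'optimizer': {'type': 'my_optimizer', ...}}
#    resolve to MyOptimizer.
OPTIMIZERS_CLS = {
    'my_optimizer': MyOptimizer,
}
```
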
## Learning Rate and Warmup Schedules

Learning rate with an optional warmup can be configured by specifying the
`learning_rate` and `warmup` fields in the optimization config. `learning_rate`
is a required field, while `warmup` is an optional one. The list of supported
`learning_rate` and `warmup` schedules can be found
[here](https://github.com/tensorflow/models/blob/7f239d8ec19b5c2d44e0d5aa2a09dbea0da6d737/official/modeling/optimization/optimizer_factory.py#L51).

```python
LR_CLS = {
    'stepwise': tf.keras.optimizers.schedules.PiecewiseConstantDecay,
    'polynomial': tf.keras.optimizers.schedules.PolynomialDecay,
    'exponential': tf.keras.optimizers.schedules.ExponentialDecay,
    'cosine': tf.keras.experimental.CosineDecay,
    'power': lr_schedule.DirectPowerDecay,
}

WARMUP_CLS = {
    'linear': lr_schedule.LinearWarmup,
    'polynomial': lr_schedule.PolynomialWarmUp
}
```

In addition, a `constant` learning rate can be specified.

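For example, a config along the following lines selects a fixed learning rate.
This is a sketch that assumes the constant schedule exposes a `learning_rate`
field; check the config definitions linked above for the exact field name.

```python
params = {'optimizer': {'type': 'sgd',
                        'sgd': {'momentum': 0.9}},
          'learning_rate': {'type': 'constant',
                            'constant': {'learning_rate': 0.01}}}
```
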
## How Learning Rate Works

Learning rate takes `step` as an input and returns the learning rate value. As
training progresses, the learning rate value usually decays. A warmup schedule
is often used to stabilize training. The warmup schedule starts from a low
learning rate value and gradually increases it until it reaches the initial
value of the regular learning rate decay schedule. We combine the
`learning_rate` (lr) and `warmup` (warmup) schedules as follows:

* Steps [0, warmup_steps): `learning_rate = warmup(step)`
* Steps [warmup_steps, train_steps): `learning_rate = lr(step)`
* We designed the warmup schedule such that the final warmup learning rate is
  inferred from the learning rate schedule (i.e.,
  `learning_rate(warmup_steps) = warmup(warmup_steps)`). Note that the warmup
  schedule doesn't delay the regular learning rate decay by `warmup_steps`;
  instead, it replaces it.

The learning rate value is logged every
[summary_interval](https://github.com/tensorflow/models/blob/7f239d8ec19b5c2d44e0d5aa2a09dbea0da6d737/official/core/config_definitions.py#L181).
If `warmup_steps` is less than the `summary_interval`, you won't be able to see
the warmup values.

### Example

We want to specify a cosine learning rate decay with `decay_steps` of 20000 and
a linear warmup schedule for the first 500 steps.

```python
params = {'learning_rate': {'type': 'cosine',
                            'cosine': {'decay_steps': 20000}},
          'warmup': {'type': 'linear',
                     'linear': {'warmup_steps': 500}}}
```
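
To see the combination rule from the previous section in action, you can build
the schedule and evaluate it at a few steps. The sketch below is illustrative:
it wraps the config above in a full `OptimizationConfig` (which also needs an
optimizer) and assumes the unspecified cosine initial learning rate falls back
to its default.

```python
from official.modeling import optimization

params = {'optimizer': {'type': 'sgd',
                        'sgd': {'momentum': 0.9}},
          'learning_rate': {'type': 'cosine',
                            'cosine': {'decay_steps': 20000}},
          'warmup': {'type': 'linear',
                     'linear': {'warmup_steps': 500}}}
opt_config = optimization.OptimizationConfig(params)
lr = optimization.OptimizerFactory(opt_config).build_learning_rate()

# Steps [0, 500) follow the linear warmup; steps [500, 20000) follow the cosine
# decay directly (the decay is not shifted by the warmup steps).
for step in [0, 250, 500, 10000, 20000]:
  print(step, float(lr(step)))
```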

## Customizing Optimizer Inside Task

The optimizer and learning rate are created inside the
[task](https://github.com/tensorflow/models/blob/7f239d8ec19b5c2d44e0d5aa2a09dbea0da6d737/official/core/base_task.py#L73).
If different optimizers or learning rate schedulers are needed, they can be
defined by overriding the corresponding class method, as sketched below.

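The following is a rough sketch of such an override. The method name and
signature are assumptions based on the linked `base_task.py`; check that file
for the actual hook and arguments.

```python
from official.core import base_task
from official.modeling import optimization


class MyTask(base_task.Task):
  """A task that builds its optimizer with custom logic."""

  @classmethod
  def create_optimizer(cls, optimizer_config, runtime_config=None):
    # Build the learning rate and optimizer explicitly instead of relying on
    # the default implementation, e.g. to post-process the optimizer.
    opt_factory = optimization.OptimizerFactory(optimizer_config)
    lr = opt_factory.build_learning_rate()
    return opt_factory.build_optimizer(lr)
```
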
## Important Factors To Consider

* Batch size: Changing the batch size usually requires scaling the learning
  rate values and the number of training steps. Make sure that you change the
  appropriate values as the batch size changes.
* Train steps: The number of train steps is highly correlated with fields such
  as `decay_steps` for cosine learning rate decay. Changing one without
  changing the other might result in undesired behavior.
