
Commit 7ee6089

yeqingli authored and tensorflower-gardener committed
Adds the offset argument to the supported learning rate schedules.
PiperOrigin-RevId: 381301573
1 parent 169e405 commit 7ee6089

File tree

4 files changed: +116 −4 lines

official/modeling/optimization/configs/learning_rate_config.py
official/modeling/optimization/lr_schedule.py
official/modeling/optimization/lr_schedule_test.py
official/modeling/optimization/optimizer_factory.py

official/modeling/optimization/configs/learning_rate_config.py

Lines changed: 8 additions & 0 deletions
@@ -56,10 +56,12 @@ class StepwiseLrConfig(base_config.Config):
       values[0] [boundaries[0], boundaries[1]] -> values[1]
       [boundaries[n-1], boundaries[n]] -> values[n] [boundaries[n],
       end] -> values[n+1] Defaults to None.
+    offset: An int. The offset applied to steps. Defaults to 0.
   """
   name: str = 'PiecewiseConstantDecay'
   boundaries: Optional[List[int]] = None
   values: Optional[List[float]] = None
+  offset: int = 0


 @dataclasses.dataclass
@@ -76,12 +78,14 @@ class ExponentialLrConfig(base_config.Config):
     decay_rate: A float. Defaults to None.
     staircase: A boolean, if true, learning rate is decreased at discrete
       intervals. Defaults to False.
+    offset: An int. The offset applied to steps. Defaults to 0.
   """
   name: str = 'ExponentialDecay'
   initial_learning_rate: Optional[float] = None
   decay_steps: Optional[int] = None
   decay_rate: Optional[float] = None
   staircase: Optional[bool] = None
+  offset: int = 0


 @dataclasses.dataclass
@@ -99,13 +103,15 @@ class PolynomialLrConfig(base_config.Config):
     power: A float. The power of the polynomial. Defaults to linear, 1.0.
     cycle: A boolean, whether or not it should cycle beyond decay_steps.
       Defaults to False.
+    offset: An int. The offset applied to steps. Defaults to 0.
   """
   name: str = 'PolynomialDecay'
   initial_learning_rate: Optional[float] = None
   decay_steps: Optional[int] = None
   end_learning_rate: float = 0.0001
   power: float = 1.0
   cycle: bool = False
+  offset: int = 0


 @dataclasses.dataclass
@@ -122,11 +128,13 @@ class CosineLrConfig(base_config.Config):
       to None.
     alpha: A float. Minimum learning rate value as a fraction of
       initial_learning_rate.
+    offset: An int. The offset applied to steps. Defaults to 0.
   """
   name: str = 'CosineDecay'
   initial_learning_rate: Optional[float] = None
   decay_steps: Optional[int] = None
   alpha: float = 0.0
+  offset: int = 0


 @dataclasses.dataclass
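
For illustration, a minimal sketch (not part of the diff) of constructing one of these configs with the new field; it assumes the `official` package from the TensorFlow Model Garden is importable, and the boundary/value numbers are arbitrary:

```python
from official.modeling.optimization.configs import learning_rate_config as lr_cfg

# A stepwise schedule whose step count is shifted by 100 steps,
# e.g. to account for a warmup phase handled elsewhere.
stepwise = lr_cfg.StepwiseLrConfig(
    boundaries=[1000, 2000],
    values=[0.1, 0.01, 0.001],
    offset=100,
)
print(stepwise.offset)  # -> 100
```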

official/modeling/optimization/lr_schedule.py

Lines changed: 69 additions & 0 deletions
@@ -19,6 +19,75 @@
 import tensorflow as tf


+def _make_offset_wrapper(new_class_name: str, base_lr_class):
+  """Generates an offset wrapper of a learning rate schedule.
+
+  It returns a subclass of `base_lr_class` that takes an `offset` argument in
+  its constructor. When the new class instance is called, the behavior is:
+    new_class_object(step) = base_lr_class_object(step - offset)
+
+  Example:
+  CosineDecayWithOffset = _make_offset_wrapper(
+      'CosineDecayWithOffset', tf.keras.experimental.CosineDecay)
+  # Use the lr:
+  lr = CosineDecayWithOffset(offset=100, initial_learning_rate=0.1,
+                             decay_steps=1000)
+  lr(101)  # equals tf.keras.experimental.CosineDecay(...)(101 - 100)
+
+  Args:
+    new_class_name: the name of the new class.
+    base_lr_class: the base learning rate schedule class. Should be a subclass
+      of tf.keras.optimizers.schedules.LearningRateSchedule.
+
+  Returns:
+    A new class (subclass of base_lr_class) that can take an offset.
+  """
+  assert issubclass(base_lr_class,
+                    tf.keras.optimizers.schedules.LearningRateSchedule), (
+                        "base_lr_class should be a subclass of keras "
+                        f"LearningRateSchedule, got {base_lr_class}")
+
+  # pylint: disable=protected-access,pointless-statement
+  def offset_learning_rate_init(self, offset=0, **kwargs):
+    """Constructs the learning rate schedule object.
+
+    When this object is called, its behavior is
+      self.__call__(step) == base_lr_class.__call__(step - offset)
+
+    Args:
+      self: this object.
+      offset: The offset when computing the learning rate schedule.
+      **kwargs: Passed through to the base learning rate class constructor.
+    """
+    base_lr_class.__init__(self, **kwargs)
+    self._offset = offset
+
+  def offset_learning_rate_call(self, step):
+    step = tf.cast(step - self._offset, tf.float32)
+    return base_lr_class.__call__(self, step)
+
+  # pylint: enable=protected-access,pointless-statement
+
+  return type(
+      new_class_name, (base_lr_class,), {
+          "base_lr_class": base_lr_class,
+          "__init__": offset_learning_rate_init,
+          "__call__": offset_learning_rate_call
+      })
+
+
+PiecewiseConstantDecayWithOffset = _make_offset_wrapper(
+    "PiecewiseConstantDecayWithOffset",
+    tf.keras.optimizers.schedules.PiecewiseConstantDecay)
+PolynomialDecayWithOffset = _make_offset_wrapper(
+    "PolynomialDecayWithOffset", tf.keras.optimizers.schedules.PolynomialDecay)
+ExponentialDecayWithOffset = _make_offset_wrapper(
+    "ExponentialDecayWithOffset",
+    tf.keras.optimizers.schedules.ExponentialDecay)
+CosineDecayWithOffset = _make_offset_wrapper("CosineDecayWithOffset",
+                                             tf.keras.experimental.CosineDecay)
+
+
 class LinearWarmup(tf.keras.optimizers.schedules.LearningRateSchedule):
   """Linear warmup schedule."""
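
A minimal usage sketch of the generated wrapper classes (assuming TF 2.x, where the `tf.keras.experimental.CosineDecay` alias used in the diff above is available); it mirrors the equivalence the new tests check:

```python
import tensorflow as tf
from official.modeling.optimization import lr_schedule

base = tf.keras.experimental.CosineDecay(
    initial_learning_rate=0.1, decay_steps=1000)
shifted = lr_schedule.CosineDecayWithOffset(
    offset=100, initial_learning_rate=0.1, decay_steps=1000)

# The wrapper evaluates the base schedule at (step - offset).
assert abs(float(shifted(150)) - float(base(50))) < 1e-9
```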

official/modeling/optimization/lr_schedule_test.py

Lines changed: 35 additions & 0 deletions
@@ -70,5 +70,40 @@ def test_power_linear_lr_schedule(self, init_lr, power, linear_decay_fraction,
     self.assertAlmostEqual(lr(step).numpy(), value)


+class OffsetLearningRateTest(tf.test.TestCase, parameterized.TestCase):
+
+  @parameterized.parameters(
+      dict(class_name=lr_schedule.PiecewiseConstantDecayWithOffset),
+      dict(class_name=lr_schedule.PolynomialDecayWithOffset),
+      dict(class_name=lr_schedule.ExponentialDecayWithOffset),
+      dict(class_name=lr_schedule.CosineDecayWithOffset),
+  )
+  def test_generated_docstring(self, class_name):
+    self.assertNotEmpty(class_name.__init__.__doc__)
+
+  @parameterized.parameters(
+      dict(
+          class_name=lr_schedule.PiecewiseConstantDecayWithOffset,
+          kwarg=dict(boundaries=[50, 80], values=[1.0, 0.5, 0.1])),
+      dict(
+          class_name=lr_schedule.PolynomialDecayWithOffset,
+          kwarg=dict(initial_learning_rate=1.0, decay_steps=100)),
+      dict(
+          class_name=lr_schedule.ExponentialDecayWithOffset,
+          kwarg=dict(
+              initial_learning_rate=1.0, decay_steps=100, decay_rate=0.5)),
+      dict(
+          class_name=lr_schedule.CosineDecayWithOffset,
+          kwarg=dict(initial_learning_rate=1.0, decay_steps=100)),
+  )
+  def test_offset(self, class_name, kwarg):
+    offset = 10
+    offset_lr = class_name(offset=offset, **kwarg)
+    base_lr = class_name.base_lr_class(**kwarg)
+    self.assertIsInstance(offset_lr, class_name)
+    for step in range(10, 101, 10):
+      self.assertEqual(offset_lr(step), base_lr(step - offset))
+
+
 if __name__ == '__main__':
   tf.test.main()

official/modeling/optimization/optimizer_factory.py

Lines changed: 4 additions & 4 deletions
@@ -38,10 +38,10 @@
 }

 LR_CLS = {
-    'stepwise': tf.keras.optimizers.schedules.PiecewiseConstantDecay,
-    'polynomial': tf.keras.optimizers.schedules.PolynomialDecay,
-    'exponential': tf.keras.optimizers.schedules.ExponentialDecay,
-    'cosine': tf.keras.experimental.CosineDecay,
+    'stepwise': lr_schedule.PiecewiseConstantDecayWithOffset,
+    'polynomial': lr_schedule.PolynomialDecayWithOffset,
+    'exponential': lr_schedule.ExponentialDecayWithOffset,
+    'cosine': lr_schedule.CosineDecayWithOffset,
     'power': lr_schedule.DirectPowerDecay,
     'power_linear': lr_schedule.PowerAndLinearDecay,
     'power_with_offset': lr_schedule.PowerDecayWithOffset,
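
For context, a hedged sketch of how the `LR_CLS` lookup could be exercised; `build_lr` below is a hypothetical helper, not the factory's actual API, and only illustrates that a config's fields (now including `offset`) are forwarded to the wrapper class selected by name:

```python
from official.modeling.optimization import lr_schedule

# Hypothetical mirror of part of the LR_CLS mapping from the diff above.
LR_CLS = {
    'cosine': lr_schedule.CosineDecayWithOffset,
    'exponential': lr_schedule.ExponentialDecayWithOffset,
}

def build_lr(lr_type, **lr_kwargs):
  """Hypothetical helper: look up the schedule class by name, forward kwargs."""
  return LR_CLS[lr_type](**lr_kwargs)

lr = build_lr('cosine', offset=100, initial_learning_rate=0.1, decay_steps=1000)
```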
