Skip to content

Commit 5e4b7fd

Browse files
ds-hwangcopybara-github
authored andcommitted
Add cyclical_step flag to TransformerSchedule.
Cyclic LR escape the current local minimum to explore other minima. [1] [1] https://arxiv.org/abs/1704.00109 PiperOrigin-RevId: 488968190
1 parent 2aef09d commit 5e4b7fd

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed

lingvo/core/schedule.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,12 +319,19 @@ def Params(cls):
319319
'decay_end-th step.')
320320
p.Define('decay_factor', -0.5, 'Decay factor after warmup.')
321321
p.Define('start_step', 0, 'Translate the function left in the step axis.')
322+
p.Define(
323+
'cyclical_step', None, 'Int, if set, at the step, the cycle restarts. '
324+
'Cyclic LR escapes a bad local minimum. arxiv.org/abs/1704.00109')
322325
return p
323326

324327
def Value(self, step=None):
325328
"""Returns the current learning rate decay."""
326329
p = self.params
327-
current_step = tf.cast(self.GetStep(step), tf.float32)
330+
current_step = self.GetStep(step)
331+
if p.cyclical_step is not None:
332+
assert isinstance(p.cyclical_step, int)
333+
current_step = tf.math.mod(current_step, p.cyclical_step)
334+
current_step = tf.cast(current_step, tf.float32)
328335
start_step = tf.cast(p.start_step, tf.float32)
329336
warmup_steps = tf.cast(p.warmup_steps * p.worker_replicas, tf.float32)
330337
if p.decay_end is not None:

lingvo/core/schedule_test.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,22 @@ def testTransformerScheduleNoWarmUp(self):
307307
self.assertAllClose(base_lrs.Value().eval(), lrs.Value().eval())
308308
self.assertAllClose(base_lrs.Value().eval(), lrs.Value().eval())
309309

310+
def testTransformerScheduleCycle(self):
311+
ref_p = schedule.TransformerSchedule.Params().Set(warmup_steps=0)
312+
cycle = 3000
313+
cyc_p = ref_p.Copy().Set(cyclical_step=cycle)
314+
with self.session():
315+
ref = ref_p.Instantiate()
316+
cyc = cyc_p.Instantiate()
317+
ref_pts = []
318+
cyc_pts = []
319+
for step in range(0, 1000, 10_000):
320+
with py_utils.GlobalStepContext(step % cycle):
321+
ref_pts.append(ref.Value().eval())
322+
with py_utils.GlobalStepContext(step):
323+
cyc_pts.append(cyc.Value().eval())
324+
self.assertAllClose(ref_pts, cyc_pts)
325+
310326
def testTransformerMLPerfSchedule(self):
311327
params = schedule.TransformerMLPerfSchedule.Params().Set(
312328
warmup_steps=4000, warmup_init_fraction=.3, model_dim=512)

0 commit comments

Comments
 (0)