Skip to content

Commit 3e19cc5

Browse files
ds-hwangcopybara-github
authored andcommitted
Add EmaDecaySchedule.
Mimic EMA decay schedule of tf.train.ExponentialMovingAverage [1]. [1] https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage In addition, add `pass_absolute_step` to CycleSchedule. PiperOrigin-RevId: 490624052
1 parent 523ff4a commit 3e19cc5

File tree

2 files changed

+83
-18
lines changed

2 files changed

+83
-18
lines changed

lingvo/core/schedule.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -705,6 +705,26 @@ def Value(self, step=None):
705705
return self.combine.Value(step)
706706

707707

708+
class EmaDecaySchedule(BaseSchedule):
709+
"""Mimic EMA decay schedule of tf.train.ExponentialMovingAverage.
710+
711+
https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
712+
"""
713+
714+
@classmethod
715+
def Params(cls):
716+
p = super().Params()
717+
p.Define('ema_decay', 0.9999, 'The EMA decay parameter.')
718+
return p
719+
720+
def Value(self, step=None):
721+
p = self.params
722+
x = tf.cast(self.GetStep(step), dtype=p.dtype)
723+
# https://github.com/tensorflow/tensorflow/blob/v2.11.0/tensorflow/python/training/moving_averages.py#L578-L582
724+
warmup = (1.0 + x) / (10.0 + x)
725+
return tf.minimum(warmup, p.ema_decay)
726+
727+
708728
class DevBasedSchedule(BaseSchedule):
709729
"""Decay triggered by lack of improvement on the dev set.
710730
@@ -928,10 +948,10 @@ class CycleSchedule(BaseSchedule):
928948
@classmethod
929949
def Params(cls):
930950
p = super().Params()
931-
p.Define(
932-
'schedules', None, 'A list of sub-schedules. Unlike PiecewiseSchedule, '
933-
'the absolute step is passed to the sub-schedule.')
951+
p.Define('schedules', None, 'A list of sub-schedules.')
934952
p.Define('steps', None, 'The number of steps to run each sub-schedule.')
953+
p.Define('pass_absolute_step', True,
954+
'Whether to pass the absolute step to the sub-schedule.')
935955
return p
936956

937957
def __init__(self, params):
@@ -949,9 +969,11 @@ def __init__(self, params):
949969

950970
def Value(self, step=None):
951971
values = []
972+
step = self.GetStep(step)
973+
relative_step = tf.math.mod(step, self._period)
974+
schedule_step = step if self.params.pass_absolute_step else relative_step
952975
for schedule in self.schedules:
953-
values.append(schedule.Value(step))
954-
relative_step = tf.math.mod(self.GetStep(step), self._period)
976+
values.append(schedule.Value(schedule_step))
955977
return py_utils.PiecewiseConstant(relative_step, self._boundaries, values,
956978
values[0].dtype)
957979

lingvo/core/schedule_test.py

Lines changed: 56 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -642,6 +642,25 @@ def testLinearRampupSqrtDecayByBatchSizeAndReplicasSchedule(self):
642642
with py_utils.GlobalStepContext(1599):
643643
self.assertAllClose(lrs.Value().eval(), 0.025)
644644

645+
def testEmaDecaySchedule(self):
646+
p = schedule.EmaDecaySchedule.Params().Set(ema_decay=0.9999)
647+
with self.session():
648+
lrs = p.Instantiate()
649+
pts = []
650+
for step in range(0, 130_000, 25_000):
651+
with py_utils.GlobalStepContext(step):
652+
pts.append([step, lrs.Value().eval()])
653+
self.assertAllClose(
654+
pts,
655+
[
656+
[0, 0.1], # (1.0 + x) / (10.0 + x)
657+
[25000, 0.99964017], # (1.0 + x) / (10.0 + x)
658+
[50000, 0.99982005], # (1.0 + x) / (10.0 + x)
659+
[75000, 0.99988], # (1.0 + x) / (10.0 + x)
660+
[100000, 0.9999], # 0.9999
661+
[125000, 0.9999], # 0.9999
662+
])
663+
645664
def testDevBasedSchedule(self):
646665
logdir = tf.test.get_temp_dir()
647666
tf.io.gfile.mkdir(os.path.join(logdir, 'eval_dev'))
@@ -838,25 +857,49 @@ def testPiecewiseSchedule(self):
838857
[60000, 0.0], # pi.
839858
])
840859

841-
def testCycleSchedule(self):
842-
p0 = schedule.LinearSchedule.Params().Set(start=(0, 0.), limit=(1000, 1.))
860+
@parameterized.named_parameters(
861+
{
862+
'testcase_name':
863+
'RelCycle',
864+
'pass_absolute_step':
865+
False,
866+
'expected_step_value': [
867+
[0, 0.0],
868+
[2, 2 / 10.],
869+
[4, 4 / 10.],
870+
[6, 5.0],
871+
[8, 5.0],
872+
[10, 0.0], # rel_step = 10%10 = 0
873+
[12, 2 / 10.],
874+
]
875+
},
876+
{
877+
'testcase_name':
878+
'AbsCycle',
879+
'pass_absolute_step':
880+
True,
881+
'expected_step_value': [
882+
[0, 0.0],
883+
[2, 2 / 10.],
884+
[4, 4 / 10.],
885+
[6, 5.0],
886+
[8, 5.0],
887+
[10, 1.0], # abs_step, LinearSchedule is stuck at limit.
888+
[12, 1.0],
889+
]
890+
})
891+
def testCycleSchedule(self, pass_absolute_step, expected_step_value):
892+
p0 = schedule.LinearSchedule.Params().Set(start=(0, 0.), limit=(10, 1.))
843893
p1 = schedule.Constant.Params().Set(value=5.0)
844-
p = schedule.CycleSchedule.Params().Set(schedules=[p0, p1], steps=[4, 1])
894+
p = schedule.CycleSchedule.Params().Set(
895+
schedules=[p0, p1], steps=[6, 4], pass_absolute_step=pass_absolute_step)
845896
with self.session():
846897
lrs = p.Instantiate()
847898
pts = []
848-
for step in [0, 1, 4, 5, 998, 999, 1000]:
899+
for step in range(0, 13, 2):
849900
with py_utils.GlobalStepContext(step):
850901
pts.append([step, lrs.Value().eval()])
851-
self.assertAllClose(pts, [
852-
[0, 0.0],
853-
[1, 1.0 / 1000.0],
854-
[4, 5.0],
855-
[5, 5.0 / 1000.0],
856-
[998, 998.0 / 1000.0],
857-
[999, 5.0],
858-
[1000, 1.0],
859-
])
902+
self.assertAllClose(pts, expected_step_value)
860903

861904
def testAnnealingSchedule(self):
862905
p = schedule.AnnealingSchedule.Params().Set(

0 commit comments

Comments
 (0)