
Commit 8ca46ae

ds-hwang authored and copybara-github committed
Implement EMA schedule
tf.train.ExponentialMovingAverage [1] accepts decay=tf.Variable. To have fine-grained control over the value of the decay parameter during training, pass a scalar tf.Variable as the decay value to the constructor and update the variable as needed.

For example, we can easily obtain a pure weight average over t by setting ema_schedule to t**-1, which gives theta = mean(theta_i for i in (0, t)).

[1] https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage

PiperOrigin-RevId: 490081739
1 parent 4c0f0c2 commit 8ca46ae

File tree

lingvo/core/base_model.py
lingvo/core/ema_test.py
lingvo/core/program.py
lingvo/core/py_utils.py
lingvo/executor.py

5 files changed: +167 -27 lines

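For context on the mechanism (not part of this commit): a minimal sketch, assuming TF2 eager execution and a recent TF version whose ExponentialMovingAverage accepts a variable decay, of passing a scalar tf.Variable as the decay and updating it between EMA applications. With decay = 1 - 1/step the shadow variable tracks the plain running mean, the "pure weight average" case mentioned in the commit message.

import tensorflow as tf

theta = tf.Variable(0.0)
ema_decay = tf.Variable(0.0, trainable=False)  # stand-in for a scheduled decay
ema = tf.train.ExponentialMovingAverage(decay=ema_decay)

for step in range(1, 4):
  theta.assign(float(step))            # stand-in for a training update
  ema_decay.assign(1.0 - 1.0 / step)   # running-mean style schedule
  ema.apply([theta])                   # shadow = decay * shadow + (1 - decay) * theta

print(ema.average(theta).numpy())      # 2.0 == mean(1.0, 2.0, 3.0)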

lingvo/core/base_model.py

Lines changed: 47 additions & 12 deletions
@@ -18,7 +18,7 @@
 import dataclasses
 import functools
 import re
-from typing import Dict, Union
+from typing import Dict, Optional, Union
 
 import lingvo.compat as tf
 from lingvo.core import base_input_generator
@@ -64,6 +64,15 @@ class DecodeEmailOptions:
   global_step: int
 
 
+@dataclasses.dataclass(frozen=True)
+class ExecutorEma:
+  """EMA related instances which an executor prepares."""
+  # EMA object.
+  ema: Optional[tf.train.ExponentialMovingAverage] = None
+  # ema_decay variable.
+  ema_decay: Optional[tf.Variable] = None
+
+
 def _VariablesForEMA(params, model_var_list):
   """Gets a list of variables that need to apply exponential moving average."""
   # Use variable reference since variable is not hashable in eager mode.
@@ -168,6 +177,7 @@ def Params(cls):
     tp.Define(
         'ema_decay_moving_vars', None,
         'If True, include variables from collection "moving_vars" in ema.')
+    tp.Define('ema_schedule', None, 'EMA decay schedule over global_step.')
     tp.Define(
         'init_from_checkpoint_rules', {},
         'If not None, a dictionary with keys corresponding to a checkpoint '
@@ -436,6 +446,9 @@ def __init__(self, params):
       self.CreateChildren('learners', tp.learner)
     else:
       self.CreateChildren('learners', [tp.learner])
+
+    if tp.ema_schedule:
+      self.CreateChild('ema_schedule', tp.ema_schedule)
     self._UpdateVnConfig()
 
     if (tp and tp.pruning_hparams_dict and
@@ -857,9 +870,22 @@ def ApplyExponentialMovingAverage(self):
     self._graphs_applied_ema.add(graph)
 
     tf.logging.info('ApplyExponentialMovingAverage on %s', self)
+    pre_op = tf.no_op()
+
+    # Update EMA decay.
+    tp = self.params.train
+    if tp.ema_schedule:
+      assert isinstance(self.parent, BaseModel)
+      ema_decay_var = self.parent.ema_decay
+      assert ema_decay_var is not None
+      ema_decay = self.ema_schedule.Value(step=self._global_step_var)
+      ema_decay = tf.minimum(ema_decay, 1.0)
+      pre_op = ema_decay_var.assign(ema_decay, read_value=False)
+
     # Use empty name here so no prefix is added to the EMA variable names.
     scoped_creator = py_utils.GetLingvoVariableCreator('', '')
-    return scoped_creator(self._ApplyEMA, ema=ema)
+    with tf.control_dependencies([pre_op]):
+      return scoped_creator(self._ApplyEMA, ema=ema)
 
   def _ApplyEMA(self, ema):
     all_vars = _VariablesForEMA(self.params, self.vars.Flatten())
@@ -1134,6 +1160,7 @@ def Params(cls):
         'ema_decay_moving_vars', None,
         'If True, include variables from collection "moving_vars" in ema. '
         'Must be set consistent across all tasks.')
+    tp.Define('ema_schedule', None, 'EMA decay schedule over global_step.')
     tp.Define('init_from_checkpoint_rules', {},
               'See BaseTask documentation for details.')
     tp.Define('init_from_checkpoint_override', '',
@@ -1163,7 +1190,7 @@ def Params(cls):
              'checkpoints. Currently only support custom saver.')
     return p
 
-  def __init__(self, params, executor_ema=None):
+  def __init__(self, params, executor_ema=ExecutorEma()):
     """Initializes this Model."""
     assert issubclass(params.cls, BaseModel)
     super().__init__(params)
@@ -1172,21 +1199,23 @@ def __init__(self, params, executor_ema=None):
     self._global_step_var = py_utils.GetOrCreateGlobalStepVar()
 
     tp = self.params.train
-    if tp.ema_decay > 0:
+    if tp.ema_decay > 0 or tp.ema_schedule:
       assert tp.ema_decay < 1.0
-      assert self.cluster.is_executor_tpu == (executor_ema is not None)
-      if executor_ema is not None:
+      assert self.cluster.is_executor_tpu == (executor_ema.ema is not None)
+      if executor_ema.ema is not None:
         # Use the EMA for executor training if set.
-        self._ema = executor_ema
+        self._ema, self._ema_decay = executor_ema.ema, executor_ema.ema_decay
       else:
-        self._ema = py_utils.CreateEMAForModel(self.params, self.global_step)
+        self._ema_decay = py_utils.CreateEMADecayVar(self.params)
+        self._ema = py_utils.CreateEMAForModel(self.params, self.global_step,
+                                               self._ema_decay)
     else:
       # Evaler/Decoder may disable EMA while ExecutorTpu uses EMA. executor_ema
       # depends on the trainer task params, but Evaler/Decoder may have
       # different task params (e.g. ema_decay=0). See model_registry.py
       if not self.do_eval:
-        assert not executor_ema
-      self._ema = None
+        assert executor_ema.ema is None
+      self._ema = self._ema_decay = None
     self._ema_variables_dict = {}
 
   @property
@@ -1197,6 +1226,10 @@ def global_step(self):
   def variables_for_ema(self):
     return _VariablesForEMA(self.params, self.vars.Flatten())
 
+  @property
+  def ema_decay(self):
+    return self._ema_decay
+
   def _MakeEMAVariablesDict(self):
     if self.ema:
       res = {}
@@ -1364,6 +1397,7 @@ def CopyTaskParams(cls, task_params, p):
     tp.checkpoint_finite_check = p.task.train.checkpoint_finite_check
     tp.ema_decay = p.task.train.ema_decay
     tp.ema_decay_moving_vars = p.task.train.ema_decay_moving_vars
+    tp.ema_schedule = p.task.train.ema_schedule
 
   def __init__(self, params, **kwargs):
     assert issubclass(params.cls, SingleTaskModel)
@@ -1498,10 +1532,11 @@ def __init__(self, params, **kwargs):
       p.task_schedule = task_scheduler.ConstantScheduler.Params()
       p.task_schedule.task_probs = sorted(list(p.task_probs.IterParams()))
 
-    if p.train.ema_decay > 0:
+    tp = p.train
+    if tp.ema_decay > 0 or tp.ema_schedule:
       for task_name, task_params in sorted_task_params:
         for field in ['ema_decay', 'ema_decay_moving_vars']:
-          if task_params.train.Get(field) != p.train.Get(field):
+          if task_params.train.Get(field) != tp.Get(field):
             raise ValueError('Params did not match for field %s in task %s' %
                              (field, task_name))
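
A hedged usage sketch of the new knob (task_params here is a placeholder for some BaseTask params; PiecewiseConstantSchedule is one possible schedule, mirroring the new test in ema_test.py below):

from lingvo.core import schedule

tp = task_params.train   # task_params: your BaseTask params (placeholder)
tp.ema_decay = 0         # the fixed decay is overridden when a schedule is set
tp.ema_decay_moving_vars = True
tp.ema_schedule = schedule.PiecewiseConstantSchedule.Params().Set(
    boundaries=[99, 199], values=[1.0, 0.9, 0.0])  # decay per global_step range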

lingvo/core/ema_test.py

Lines changed: 82 additions & 2 deletions
@@ -23,6 +23,7 @@
 from lingvo.core import cluster_factory
 from lingvo.core import layers
 from lingvo.core import py_utils
+from lingvo.core import schedule
 from lingvo.core import test_utils
 import mock
 import numpy as np
@@ -146,8 +147,11 @@ def testBatchNormLayer(self):
         graph=tf.Graph()) as sess, cluster_factory.ForTestingWorker(
             job='executor_tpu', do_eval=True), mock.patch(
                 'lingvo.core.py_utils.use_tpu', return_value=True):
-      executor_ema = py_utils.CreateEMAForModel(
-          p, py_utils.GetOrCreateGlobalStepVar())
+      ema_decay_var = None
+      ema_var = py_utils.CreateEMAForModel(p,
+                                           py_utils.GetOrCreateGlobalStepVar(),
+                                           ema_decay_var)
+      executor_ema = base_model.ExecutorEma(ema_var, ema_decay_var)
       model = p.Instantiate(executor_ema=executor_ema)
       self.assertIsNotNone(model.ema)
       task = model._task
@@ -166,5 +170,81 @@ def testBatchNormLayer(self):
       self.assertAllClose([beta_1, beta_1_ema, mean_1, mean_1_ema],
                           self.evaluate([beta, beta_ema, mean, mean_ema]))
 
+  def testEmaSchedule(self):
+    task = self.TestParams(layers.BatchNormLayer.Params().Set(dim=1))
+    task.train.ema_decay = 0
+    # Note: EMA = decay * EMA + (1 - decay) * var
+    ema_off = 1.0  # ema stays constant.
+    ema_is_var = 0.0  # ema copies the var value.
+    task.train.ema_schedule = schedule.PiecewiseConstantSchedule.Params().Set(
+        boundaries=[99, 199], values=[ema_off, 0.9, ema_is_var])
+    task.train.ema_decay_moving_vars = True
+    p = base_model.SingleTaskModel.Params(task)
+    model = p.Instantiate()
+    self.assertIsNotNone(model.ema)
+    self.assertIsNotNone(model.ema_decay)
+    task = model._task
+
+    layer = task.encoder
+    self.assertLen(layer.vars, 4)
+    for var in layer.vars.Flatten():
+      self.assertIsNotNone(model.ema.average(var), msg=var.name)
+    beta = layer.vars.beta
+    mean = layer.vars.moving_mean
+
+    beta_0 = np.asarray([0.])
+    mean_0 = np.asarray([0.])
+    beta_1 = np.asarray([.2])
+    mean_1 = np.asarray([.03])
+    beta_1_ema = beta_1 * .1
+    mean_1_ema = mean_1 * .1
+    # Check the EMA decay schedule in Train.
+    with self.session():
+      # Test EMA values.
+      self.evaluate(tf.global_variables_initializer())
+      # var is initialized as 0, and EMA assigns the var value.
+      self.assertAllClose([beta_0, beta_0, mean_0, mean_0],
+                          self.evaluate([
+                              beta,
+                              model.ema.average(beta), mean,
+                              model.ema.average(mean)
+                          ]))
+
+      # At step=1, ema_decay=1.0 by ema_schedule. EMA update is off.
+      global_step = 1
+      self.evaluate(tf.assign(py_utils.GetOrCreateGlobalStepVar(), global_step))
+      self.evaluate(tf.assign(beta, beta_1))
+      self.evaluate(tf.assign(mean, mean_1))
+      ema_op = task.ApplyExponentialMovingAverage()
+      self.evaluate(ema_op)
+      self.assertAllClose([beta_1, beta_0, mean_1, mean_0],
+                          self.evaluate([
+                              beta,
+                              model.ema.average(beta), mean,
+                              model.ema.average(mean)
+                          ]))
+
+      # At step=100, ema_decay=0.9 by ema_schedule.
+      global_step = 100
+      self.evaluate(tf.assign(py_utils.GetOrCreateGlobalStepVar(), global_step))
+      self.evaluate(ema_op)
+      self.assertAllClose([beta_1, beta_1_ema, mean_1, mean_1_ema],
+                          self.evaluate([
+                              beta,
+                              model.ema.average(beta), mean,
+                              model.ema.average(mean)
+                          ]))
+
+      # At step=200, ema_decay=0.0 by ema_schedule. EMA copies the var value.
+      global_step = 200
+      self.evaluate(tf.assign(py_utils.GetOrCreateGlobalStepVar(), global_step))
+      self.evaluate(ema_op)
+      self.assertAllClose([beta_1, beta_1, mean_1, mean_1],
+                          self.evaluate([
+                              beta,
+                              model.ema.average(beta), mean,
+                              model.ema.average(mean)
+                          ]))
+
 if __name__ == '__main__':
   test_utils.main()
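
A note on the step=100 assertion above: the shadow variables are still at their initial value 0 (the step=1 update ran with decay 1.0), so a single update with decay 0.9 yields EMA = 0.9 * 0 + (1 - 0.9) * var, i.e. beta_1_ema = 0.1 * 0.2 = 0.02 and mean_1_ema = 0.1 * 0.03 = 0.003, hence the beta_1 * .1 and mean_1 * .1 expectations.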

lingvo/core/program.py

Lines changed: 5 additions & 5 deletions
@@ -108,7 +108,7 @@ def __init__(self,
                params,
                shared_model=None,
                trial=base_trial.NoOpTrial(),
-               ema=None,
+               executor_ema=base_model.ExecutorEma(),
                **kwargs):
     self.params = params.Copy()
     p = self.params
@@ -121,7 +121,7 @@ def __init__(self,
     self._tf_master = kwargs.pop('tf_master', None)
     self._write_train_input_stats = p.write_train_input_stats
     self._trial = trial
-    self._ema = ema
+    self._executor_ema = executor_ema
 
     self._SetProgramDir()
     # Initialized on use; access via self._summary_writer property only.
@@ -394,8 +394,8 @@ def _InstantiateTaskModel(self, task_params):
     """
     if issubclass(task_params.cls, base_model.MultiTaskSubModel):
       return task_params.Instantiate(
-          shared_model=self._shared_model, executor_ema=self._ema)
-    return task_params.Instantiate(executor_ema=self._ema)
+          shared_model=self._shared_model, executor_ema=self._executor_ema)
+    return task_params.Instantiate(executor_ema=self._executor_ema)
 
   def _OutfeedEnqueue(self, per_example_tensors):
     if not per_example_tensors:
@@ -1875,7 +1875,7 @@ def BuildTpuSubgraph(self):
     self._eval_metrics = metrics.TpuEvalMetrics(max_metrics=p.max_metrics)
     with py_utils.OpportunisticVariableReuseScope(True):
       self._train_model = self._train_task_params.Instantiate(
-          executor_ema=self._ema)
+          executor_ema=self._executor_ema)
       self._train_task = self._train_model.GetTask()
       self._train_task.input.InfeedSetupGraph()
       self._model = self._train_model

lingvo/core/py_utils.py

Lines changed: 25 additions & 4 deletions
@@ -751,13 +751,27 @@ def GradientTape(*args, **kwargs):
     _GRADIENT_TAPE_STACK.stack.pop()
 
 
-def CreateEMAForModel(model_params, global_step):
+def CreateEMADecayVar(model_params):
+  """Creates an EMA decay variable."""
+  p = model_params
+  tp = p.train
+  if tp.ema_schedule:
+    tf.logging.log_if(tf.logging.WARNING,
+                      f'ema_schedule overrides ema_decay:{tp.ema_decay}.',
+                      tp.ema_decay > 0)
+    wp = WeightParams(
+        shape=[], init=WeightInit.Constant(tp.ema_decay), dtype=p.dtype)
+    return CreateVariable('ema_decay', wp, trainable=False)
+  return None
+
+
+def CreateEMAForModel(model_params, global_step, ema_decay):
   """Creates an EMA object for model with param `model_params` if applicable."""
   p = model_params
 
   # Check that EMA settings for the model and subtasks match.
   def CheckEMA(task_name, task_params):
-    for field in ['ema_decay', 'ema_decay_moving_vars']:
+    for field in ['ema_decay', 'ema_decay_moving_vars', 'ema_schedule']:
       model_value = p.train.Get(field)
       task_value = task_params.train.Get(field)
       if task_value != model_value:
@@ -774,9 +788,16 @@ def CheckEMA(task_name, task_params):
   # SingleTaskModel.
   CheckEMA(p.task.name, p.task)
 
-  if p.train.ema_decay > 0:
+  tp = p.train
+  if tp.ema_decay > 0 or tp.ema_schedule:
+    if tp.ema_schedule:
+      assert isinstance(ema_decay, tf.Variable)
+      # ema_decay takes all control. Otherwise, global_step affects ema_decay.
+      global_step = None
+    else:
+      ema_decay = p.train.ema_decay
     return tf.train.ExponentialMovingAverage(
-        decay=p.train.ema_decay, num_updates=global_step)
+        decay=ema_decay, num_updates=global_step)
   return None
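
For context on the global_step = None branch above: when num_updates is set, tf.train.ExponentialMovingAverage uses min(decay, (1 + num_updates) / (10 + num_updates)), which would clip a schedule-driven decay early in training. A minimal standalone sketch of the selection logic (illustrative names, not the Lingvo API):

import tensorflow as tf

def make_ema(ema_decay_value, ema_decay_var, global_step):
  """Illustrative only: choose between fixed and schedule-driven decay."""
  if ema_decay_var is not None:
    # Schedule-driven: the variable fully controls the decay, so drop the
    # num_updates cap that global_step would otherwise impose.
    return tf.train.ExponentialMovingAverage(decay=ema_decay_var,
                                             num_updates=None)
  return tf.train.ExponentialMovingAverage(decay=ema_decay_value,
                                           num_updates=global_step)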

lingvo/executor.py

Lines changed: 8 additions & 4 deletions
@@ -292,9 +292,13 @@ def _WaitTillInit(job=None):
 
     self._checkpoint_to_load = None
     with self._cluster:
-      # Create the ExponentialMovingAverage singleton shared by all programs, if
-      # applicable.
-      ema = py_utils.CreateEMAForModel(train_cfg, self._global_step_var)
+      with tf.container(self._container_id), contextlib.ExitStack() as stack:
+        if not py_utils.IsEagerMode():
+          stack.enter_context(self._graph.as_default())
+        ema_decay_var = py_utils.CreateEMADecayVar(train_cfg)
+        ema_obj = py_utils.CreateEMAForModel(train_cfg, self._global_step_var,
+                                             ema_decay_var)
+      executor_ema = base_model.ExecutorEma(ema_obj, ema_decay_var)
       tf.logging.info('ps_params_dict=%s',
                       {k: v.ToText() for k, v in ps_params_dict.items()})
       for task_string, program_schedule_params in ps_params_dict.items():
@@ -306,7 +310,7 @@ def _WaitTillInit(job=None):
         ps = program_schedule_params.Instantiate(
             shared_model=shared_model,
             trial=self._trial,
-            ema=ema,
+            executor_ema=executor_ema,
             tf_master=self._tf_master)
         self._program_schedule_dict[task_string] = ps
         tf.logging.info('program_schedule_params: %s',
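
In effect, the executor now builds one shared ExecutorEma (the EMA object plus the decay variable) and hands it to every program schedule; the train program's ApplyExponentialMovingAverage assigns the scheduled value to that shared variable before each EMA update, so all programs see averages produced under the same decay schedule.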

0 commit comments
