
Commit 9150f56

ds-hwang authored and copybara-github committed
Save the current EMA variables to the checkpoint, instead of t-1.

Currently, Lingvo saves EMA(t-1) to the t-th checkpoint, because the train_op is built as:

  def ConstructFPropBPropGraph(self):
    self.ApplyExponentialMovingAverage()  # <-- updates the EMA vars from var_t-1 (i.e. ema_t-1)
    self._task.FPropDefaultTheta()
    self._task.BProp()                    # <-- updates var_t-1 to var_t

Running ema.apply() after BProp makes more sense, since it would update var_t first and then compute ema_t at global step t. Lingvo runs ema.apply() first because this method is used both for graph construction and for the train step: during graph construction, self._task.FPropDefaultTheta() may need the EMA variables (e.g. EMBR, an EMA teacher, and so on), so ConstructFPropBPropGraph() calls ema.apply() first to ensure the EMA variables exist. This causes the following confusion:

* After train_op, var_t and ema_t-1 are saved into the checkpoint at step t.
* The Evaler/Decoder step_ops are very confusing.

This issue can be solved by separating EMA variable creation from the EMA update, although both use the same TF API, ema.apply(). In particular, it does not make sense for a graph-construction method to construct EMA variables. EMA variables are variables, so it makes more sense for the variable-creation method (i.e. mdl.Instantiate()) to create them. This CL makes that change.

After the change, ConstructFPropBPropGraph() no longer needs to call ema.apply() first, because the EMA variables already exist, and the odd ordering of the train_op is fixed. The effect is even more dramatic for the evaler and decoder step_ops: currently, the Evaler's ConstructFPropGraph() and the Decoder's ConstructDecodeGraph() have to call ema.apply() only because the EMA variables need to be created. The Evaler/Decoder must not update EMA variables; that is the Trainer's job. They use the EMA variables from the checkpoint or from the trainer. Now we can remove the odd ema.apply() calls from the evaler/decoder step_ops (i.e. pay down that tech debt).

PiperOrigin-RevId: 488507349
1 parent 3bdda3f commit 9150f56
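
To make the ordering issue concrete, here is a minimal runnable sketch in plain TF1-style graph mode (plain TensorFlow, not Lingvo code; the variable, decay value, and numbers are illustrative). It shows the two facts behind this change: ema.apply() both creates the shadow (EMA) variables and returns their update op, and only an update op sequenced after the gradient step averages var_t rather than var_t-1:

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

var = tf.get_variable('w', initializer=0.0)
ema = tf.train.ExponentialMovingAverage(decay=0.5)

grad_step = var.assign_add(1.0)  # Stand-in for BProp: var_t-1 -> var_t.
with tf.control_dependencies([grad_step]):
  # apply() creates the shadow variable AND returns its update op. Because the
  # update op is created under the control dependency, it averages var_t.
  ema_update = ema.apply([var])
train_op = tf.group(grad_step, ema_update, name='train_op')

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(train_op)
  # var_t == 1.0 and ema_t == 0.5 * 0.0 + 0.5 * 1.0 == 0.5, so a checkpoint taken
  # now would hold ema_t instead of the stale ema_t-1 == 0.0.
  print(sess.run([var, ema.average(var)]))

Because creation and update share the single ema.apply() call, the old code had to invoke it during graph construction just to bring the EMA variables into existence; this commit splits those two roles so the update can be sequenced where it belongs.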

File tree

4 files changed: +84 -79 lines changed

  lingvo/core/base_layer.py
  lingvo/core/base_model.py
  lingvo/core/ema_test.py
  lingvo/core/program.py


lingvo/core/base_layer.py
Lines changed: 15 additions & 2 deletions

@@ -1113,18 +1113,31 @@ def MatchKeys(x, y):
     for k in self.theta.keys():
       assert k in self.vars or k in self._extra_theta
 
-  def PostTrainingStepUpdate(self):
+  def PostTrainingStepUpdate(self) -> tf.Operation:
     """Returns a TF op which will be invoked at each training step.
 
     Subclasses of `BaseLayer` can implement this method. The method should
-    return a TF op to be invoked during training after gradients are applied.
+    return a TF op to be invoked during training after gradients are applied and
+    before EMA is updated.
     """
     update_ops = [
         child.PostTrainingStepUpdate()
         for child in py_utils.Flatten(self._private_children)
     ]
     return tf.group(*update_ops)
 
+  def PostEmaUpdate(self) -> tf.Operation:
+    """Returns a TF op which will be invoked at each training step.
+
+    Subclasses of `BaseLayer` can implement this method. The method should
+    return a TF op to be invoked during training after EMA is updated.
+    """
+    update_ops = [
+        child.PostEmaUpdate()
+        for child in py_utils.Flatten(self._private_children)
+    ]
+    return tf.group(*update_ops)
+
   def _CastToFPropDtype(self, value):
 
     def _Cast(x):
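
The new PostEmaUpdate() hook mirrors PostTrainingStepUpdate() but fires after the EMA variables have been refreshed. A hedged sketch of how a subclass might use it; the layer and its quantization-flavored use case are illustrative, not part of this commit:

from lingvo import compat as tf
from lingvo.core import base_layer


class ExampleQuantLayer(base_layer.BaseLayer):
  """Hypothetical layer that consumes freshly updated EMA values."""

  def PostEmaUpdate(self) -> tf.Operation:
    # Runs after PostTrainingStepUpdate() and ApplyExponentialMovingAverage(),
    # i.e. once ema_t for this step is available (e.g. to refresh quant_vars).
    child_ops = super().PostEmaUpdate()         # Preserve the children's ops.
    my_op = tf.no_op(name='post_ema_example')   # Placeholder for real work.
    return tf.group(child_ops, my_op)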

lingvo/core/base_model.py
Lines changed: 65 additions & 59 deletions

@@ -16,6 +16,7 @@
 
 import collections
 import dataclasses
+import functools
 import re
 from typing import Dict, Union
 

@@ -748,12 +749,21 @@ def _BPropGenTrainOps(self, vmap, metrics=None, add_summary=True):
     var_update_ops = [
         tf.group(*tf.nest.flatten(train_ops), name='var_update_ops')
     ]
-    # Post training step update.
+    # Post training step update. It may update non-trainable vars, which have
+    # EMA variables.
     with tf.control_dependencies(var_update_ops):
       post_step_op = self.PostTrainingStepUpdate()
 
-    train_ops = {}
+    # EMA update after all EMA reference variables are updated.
     with tf.control_dependencies([post_step_op]):
+      ema_update_op = self.ApplyExponentialMovingAverage()
+
+    # Post EMA update, which depends on the updated EMA vars. e.g. quant_vars
+    with tf.control_dependencies([ema_update_op]):
+      post_ema_op = self.PostEmaUpdate()
+
+    train_ops = {}
+    with tf.control_dependencies([post_ema_op]):
       # Get the op to update the weight masks and thresholds
       mask_update_op = self._GetMaskUpdateOp()
       train_ops['mask_updates'] = mask_update_op

@@ -813,30 +823,28 @@ def _ComputeGradientMask(self, bprop_variable_filters):
       self._per_input_gradient_mask[var.name] += (
           tf.one_hot(i, len(bprop_variable_filters), dtype=tf.float32))
 
-  def ApplyExponentialMovingAverage(self, ema):
-    """Wraps `self.train_op` with an op updating exponential moving average."""
+  def CreateExponentialMovingAverage(self, ema):
+    """Create exponential moving average variables."""
     if not ema:
       # EMA not enabled.
       return
 
-    all_vars = _VariablesForEMA(self.params, self.vars.Flatten())
-    # For ExecutorTpu: `ema.apply()` below creates stateful variable update
-    # operations, and due to the use of tf.function in the tpu training loop,
-    # these update ops will be added as (implicit) control dependencies to
-    # the step function of eval/decode program. To avoid updating EMA variables,
-    # we run `ema.apply()` only in two cases: 1) in train program, or
-    # 2) in the first eval or decode program when there is no train program.
-    # It'll still apply the update in every eval/decode step even though
-    # the update is not materialized into checkpoint, but experiment shows it
-    # doesn't affect eval/decode metrics.
-    if self.do_eval:
-      need_ema_apply = any([ema.average(var) is None for var in all_vars])
-      if need_ema_apply:
-        assert all([ema.average(var) is None for var in all_vars
-                   ]), ('We never update EMA partially.')
-      else:
-        # Trainer already created EMA variables.
-        return
+    tf.logging.info('CreateExponentialMovingAverage on %s', self)
+    # Use empty name here so no prefix is added to the EMA variable names.
+    # Pin EMA variables to CPU if needed.
+    # The scope: GetLingvoVariableCreator(MaybePinVarsToCpu(_ApplyEMA(...)))
+    scoped_apply_ema = self._ApplyEMA
+    for scoped_creator in (py_utils.MaybePinVarsToCpu,
+                           py_utils.GetLingvoVariableCreator('', '')):
+      scoped_apply_ema = functools.partial(scoped_creator, scoped_apply_ema)
+    scoped_apply_ema(ema=ema)
+
+  def ApplyExponentialMovingAverage(self):
+    """Wraps `self.train_op` with an op updating exponential moving average."""
+    ema = self.ema
+    if not ema:
+      # EMA not enabled.
+      return tf.no_op()
 
     # Make sure this is called at most once in a graph. In eager mode, the outer
     # tf.function will be traced multiple times in different function graphs.

@@ -849,12 +857,14 @@ def ApplyExponentialMovingAverage(self, ema):
       self._graphs_applied_ema.add(graph)
 
     tf.logging.info('ApplyExponentialMovingAverage on %s', self)
-
-    def ApplyEma():
-      with tf.name_scope('moving_average'):
-        self._post_train_ops.append(ema.apply(all_vars))
     # Use empty name here so no prefix is added to the EMA variable names.
-    py_utils.GetLingvoVariableCreator('', '')(ApplyEma)
+    scoped_creator = py_utils.GetLingvoVariableCreator('', '')
+    return scoped_creator(self._ApplyEMA, ema=ema)
+
+  def _ApplyEMA(self, ema):
+    all_vars = _VariablesForEMA(self.params, self.vars.Flatten())
+    with tf.name_scope('moving_average'):
+      return ema.apply(all_vars)
 
   # TODO(blee): Rename Decode->DecodeWithDefaultTheta, DecodeWithTheta->Decode.
   def Decode(self, input_batch):

@@ -1199,18 +1209,15 @@ def _MakeEMAVariablesDict(self):
   def ConstructFPropBPropGraph(self):
     raise NotImplementedError('Abstract method')
 
-  def ConstructFPropGraph(self, apply_ema=False):
+  def ConstructFPropGraph(self):
     raise NotImplementedError('Abstract method')
 
-  def ConstructDecodeGraph(self, task_name=None, apply_ema=False):
+  def ConstructDecodeGraph(self, task_name=None):
     raise NotImplementedError('Abstract method')
 
   def ConstructPostTrainingLoop(self, outfeed=None):
     raise NotImplementedError('Abstract method')
 
-  def ApplyExponentialMovingAverage(self):
-    raise NotImplementedError('Abstract method')
-
   @property
   def tasks(self):
     """Returns a list of all tasks."""

@@ -1276,6 +1283,16 @@ class SingleTaskBase(BaseModel):
   def __init__(self, params, **kwargs):
     super().__init__(params, **kwargs)
 
+  def _CreateLayerVariables(self) -> None:
+    super()._CreateLayerVariables()
+    # CPU evaler doesn't create EMA variables. It loads EMA variables to
+    # regular variables.
+    use_ema = self.ema and (not self.do_eval or self.use_ema_for_theta)
+    # All variables of the model are created. Now create EMA variables.
+    if use_ema:
+      self._task.CreateExponentialMovingAverage(self.ema)
+      self._MakeEMAVariablesDict()
+
   @property
   def tasks(self):
     return [self._task]

@@ -1287,29 +1304,16 @@ def GetTask(self, task_name=None):
   def SampleTask(self, global_step):
     return self._task
 
-  def ApplyExponentialMovingAverage(self):
-    if self.ema:
-      self._task.ApplyExponentialMovingAverage(self.ema)
-      # ConstructFPropGraph/ConstructDecodeGraph also need this to ensure that
-      # ema vars are loaded from checkpoint even when no training is done.
-      self._MakeEMAVariablesDict()
-
   def ConstructFPropBPropGraph(self):
-    self.ApplyExponentialMovingAverage()
     self._task.FPropDefaultTheta()
     self._task.BProp()
 
-  def ConstructFPropGraph(self, apply_ema=False):
-    if apply_ema:
-      self.ApplyExponentialMovingAverage()
+  def ConstructFPropGraph(self):
     self._task.FPropDefaultTheta()
 
   def ConstructDecodeGraph(self,
                            task_name=None,
-                           apply_ema=False,
                            input_batch=None):
-    if apply_ema:
-      self.ApplyExponentialMovingAverage()
     with py_utils.TaskCallScope(self._task):
       if not input_batch:
         input_batch = self._task.GetInputBatch()

@@ -1511,6 +1515,19 @@ def __init__(self, params, **kwargs):
 
     self.CreateChild('task_schedule', p.task_schedule)
 
+  def _CreateLayerVariables(self) -> None:
+    super()._CreateLayerVariables()
+    # CPU evaler doesn't create EMA variables. It loads EMA variables to
+    # regular variables.
+    use_ema = self.ema and (not self.do_eval or self.use_ema_for_theta)
+    # All variables of the model are created. Now create EMA variables.
+    if use_ema:
+      for task_name in self.task_names:
+        with tf.name_scope(task_name):
+          task = self.GetTask(task_name)
+          task.CreateExponentialMovingAverage(self.ema)
+      self._MakeEMAVariablesDict()
+
   def _child_variable_scope_override(self):
     p = self.params
     res = super()._child_variable_scope_override()

@@ -1545,33 +1562,22 @@ def SampleTask(self, global_step):
     tf.logging.info('Sampled task: %s', sampled_task)
     return self.children[sampled_task]
 
-  def ApplyExponentialMovingAverage(self):
-    if self.ema:
-      for task_name in self.task_names:
-        with tf.name_scope(task_name):
-          task = self.GetTask(task_name)
-          task.ApplyExponentialMovingAverage(self.ema)
-      self._MakeEMAVariablesDict()
-
   def ConstructFPropBPropGraph(self):
     for task_name in self.task_names:
       with tf.name_scope(task_name):
-        self.ApplyExponentialMovingAverage()
         task = self.GetTask(task_name)
         task.FPropDefaultTheta()
         task.BProp()
 
-  def ConstructFPropGraph(self, apply_ema=False):
-    assert not apply_ema
+  def ConstructFPropGraph(self):
     for task_name in self.task_names:
       with tf.name_scope(task_name):
        task = self.GetTask(task_name)
        # Note: this is for CPU-based eval only where the variables are already
        # loaded as EMA variables, so we don't need to apply EMA.
        task.FPropDefaultTheta()
 
-  def ConstructDecodeGraph(self, task_name=None, apply_ema=False):
-    assert not apply_ema
+  def ConstructDecodeGraph(self, task_name=None):
     if not task_name:
       raise ValueError(
           'It can decode only one task at a time, but task_name is not set.')
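
The loop in CreateExponentialMovingAverage() above builds the nested call GetLingvoVariableCreator('', '')(MaybePinVarsToCpu(_ApplyEMA, ema=ema)) out of functools.partial. A small self-contained toy of that wrapping pattern, assuming each scoped helper takes a callable plus keyword arguments and invokes it inside its scope (the print-based scope functions below are stand-ins, not the real py_utils helpers):

import functools


def outer_scope(fn, **kwargs):   # Stand-in for py_utils.GetLingvoVariableCreator('', '').
  print('enter outer scope')
  result = fn(**kwargs)
  print('exit outer scope')
  return result


def inner_scope(fn, **kwargs):   # Stand-in for py_utils.MaybePinVarsToCpu.
  print('enter inner scope')
  result = fn(**kwargs)
  print('exit inner scope')
  return result


def apply_ema(ema):              # Stand-in for the diff's _ApplyEMA.
  print('ema.apply() would run here with', ema)


wrapped = apply_ema
for scoped_creator in (inner_scope, outer_scope):
  wrapped = functools.partial(scoped_creator, wrapped)

wrapped(ema='ema-object')
# Prints: enter outer scope / enter inner scope / ema.apply() would run here ...
#         / exit inner scope / exit outer scope

The outermost wrapper is the last entry in the tuple, which is why GetLingvoVariableCreator ends up as the outer scope in the comment above.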

lingvo/core/ema_test.py
Lines changed: 1 addition & 7 deletions

@@ -56,8 +56,6 @@ def testBatchNormLayer(self):
     model = p.Instantiate()
     self.assertIsNotNone(model.ema)
     task = model._task
-    task._train_op = tf.no_op()
-    task.ApplyExponentialMovingAverage(model.ema)
 
     layer = task.encoder
     self.assertLen(layer.vars, 4)

@@ -77,7 +75,7 @@ def testBatchNormLayer(self):
     self.evaluate(tf.assign(py_utils.GetOrCreateGlobalStepVar(), global_step))
     self.evaluate(tf.assign(beta, beta_1))
     self.evaluate(tf.assign(mean, mean_1))
-    self.evaluate(task._post_train_ops)
+    self.evaluate(task.ApplyExponentialMovingAverage())
 
     self.assertAllClose([beta_1, beta_1_ema, mean_1, mean_1_ema],
                         self.evaluate([

@@ -101,8 +99,6 @@ def testBatchNormLayer(self):
     model = p.Instantiate()
     self.assertIsNotNone(model.ema)
     task = model._task
-    task._train_op = tf.no_op()
-    task.ApplyExponentialMovingAverage(model.ema)
     layer = task.encoder
     for var in layer.vars.Flatten():
       self.assertIsNotNone(model.ema.average(var), msg=var.name)

@@ -155,8 +151,6 @@ def testBatchNormLayer(self):
     model = p.Instantiate(executor_ema=executor_ema)
     self.assertIsNotNone(model.ema)
     task = model._task
-    task._train_op = tf.no_op()
-    task.ApplyExponentialMovingAverage(model.ema)
     layer = task.encoder
     for var in layer.vars.Flatten():
       self.assertIsNotNone(model.ema.average(var), msg=var.name)

lingvo/core/program.py
Lines changed: 3 additions & 11 deletions

@@ -798,8 +798,7 @@ def TpuEvalStep(self, *args):
       Summed eval metrics.
     """
     with tf.name_scope('tpu_eval'):
-      # Applies EMA if applicable to support running only eval/decode programs.
-      self._model.ConstructFPropGraph(apply_ema=True)
+      self._model.ConstructFPropGraph()
       per_step_eval_metrics = self._eval_metrics.SetMetrics(
           self._task.eval_metrics, args)
       summed_metrics = []

@@ -1044,7 +1043,6 @@ def __init__(self, params, **kwargs):
     super().__init__(params, **kwargs)
     self._program_name = 'DecodeProgram'
     self._decode_out_dict_lst = []
-    self._ema_applied = False
     self._dataset_summaries = {}
     # TODO(xingwu): fully deprecate decode_until_out_of_range
     if self.params.decode_until_out_of_range:

@@ -1152,11 +1150,8 @@ def DecodeFunc(self, inp_instance):
 
     def _DecodeFn():
       """Decode call to be compiled for TPU."""
-      # Applies EMA if applicable to support running only eval/decode programs.
       _, decode_dict = self._model.ConstructDecodeGraph(
-          apply_ema=(not self._ema_applied),
           input_batch=inp_instance.TpuDequeueBatch())
-      self._ema_applied = True
       self.decode_nm = py_utils.NestedMap(decode_dict)
       return self.decode_nm.Flatten()
 

@@ -1636,8 +1631,7 @@ def DecodeFunc(self):
 
     def _DecodeStep():
       """Decode call to be compiled for TPU."""
-      # Applies EMA if applicable to support running only eval/decode programs.
-      _, decode_dict = self._model.ConstructDecodeGraph(apply_ema=True)
+      _, decode_dict = self._model.ConstructDecodeGraph()
       self.decode_nm = py_utils.NestedMap(decode_dict)
       return [self._OutfeedEnqueue(decode_dict)]
 

@@ -1911,9 +1905,7 @@ def TpuTrain():
     def _DecodeFn():
       """Decode call to be compiled for TPU."""
       with cluster_factory.SetEval(True):
-        # Applies EMA if applicable to support running only eval/decode
-        # programs.
-        _, decode_dict = self._decode_model.ConstructDecodeGraph(apply_ema=True)
+        _, decode_dict = self._decode_model.ConstructDecodeGraph()
       self.decode_nm = py_utils.NestedMap(decode_dict)
       return self.decode_nm.Flatten()
