Skip to content

Commit 0ca2767

Browse files
ds-hwangcopybara-github
authored andcommitted
Reuse the mechanism of py_utils to have empty name for EMA variable_scope.
Currently, EMA variable scope name is the same to tf.get_variable_scope(). The implementation relies on who is caller. If the caller's variable_scope has a name, the code assumption is broken. py_utils.CreateVariable has same requirement, which removes variable_scope prefix (e.g. train, evaler_cpu). So py_utils.GetLingvoVariableCreator() is used when creating tf.variable, which works no matter graph or eager mode. Let EMA variable scope reuse py_utils.GetLingvoVariableCreator(). EMA variable is just variable, so it doesn't need special treatment. In addition, remove unused _EMAVariableScope(). PiperOrigin-RevId: 487879527
1 parent 2c75e99 commit 0ca2767

File tree

2 files changed

+4
-21
lines changed

2 files changed

+4
-21
lines changed

lingvo/core/base_layer.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -893,23 +893,6 @@ def _SelfVariableScope(self, params=None, enter_name_scope=True):
893893
tf.name_scope(self._self_variable_scope.original_name_scope))
894894
yield stack
895895

896-
@contextlib.contextmanager
897-
def _EMAVariableScope(self):
898-
"""Internal. Used to ensure that the EMA variable names are reused."""
899-
if not hasattr(self, '_ema_variable_scope'):
900-
# Use empty name here so no prefix is added to the EMA variable names.
901-
# This ensures compatibility between Graph and Eager checkpoint names.
902-
# https://www.tensorflow.org/api_docs/python/tf/compat/v1/variable_scope
903-
with tf.variable_scope('') as scope:
904-
self._ema_variable_scope = scope
905-
with contextlib.ExitStack() as stack:
906-
# Entering a variable scope so that the EMA variables can reuse the names
907-
# in separate trainer runs. This turned out to be useful in the case of
908-
# model fine tuning, when the same model is run multiple times in the same
909-
# Eager context.
910-
stack.enter_context(tf.variable_scope(self._ema_variable_scope))
911-
yield stack
912-
913896
def _child_variable_scope_override(self):
914897
"""Override the variable scope for individual children.
915898

lingvo/core/base_model.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
"""Base model."""
1616

1717
import collections
18-
import contextlib
1918
import dataclasses
2019
import re
2120
from typing import Dict, Union
@@ -850,11 +849,12 @@ def ApplyExponentialMovingAverage(self, ema):
850849
self._graphs_applied_ema.add(graph)
851850

852851
tf.logging.info('ApplyExponentialMovingAverage on %s', self)
853-
with contextlib.ExitStack() as context_stack:
854-
if py_utils.IsEagerMode():
855-
context_stack.enter_context(self._EMAVariableScope())
852+
853+
def ApplyEma():
856854
with tf.name_scope('moving_average'):
857855
self._post_train_ops.append(ema.apply(all_vars))
856+
# Use empty name here so no prefix is added to the EMA variable names.
857+
py_utils.GetLingvoVariableCreator('', '')(ApplyEma)
858858

859859
# TODO(blee): Rename Decode->DecodeWithDefaultTheta, DecodeWithTheta->Decode.
860860
def Decode(self, input_batch):

0 commit comments

Comments
 (0)