Skip to content

Commit 9a823ba

Browse files
lingvo-botcopybara-github
authored andcommitted
Update the V1 Eager checkpointer path so it's compatible with the Graph mode trainer
PiperOrigin-RevId: 476806931
1 parent 7ee3b36 commit 9a823ba

File tree

3 files changed

+3
-7
lines changed

3 files changed

+3
-7
lines changed

lingvo/core/checkpointer.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -541,8 +541,7 @@ def __init__(self,
541541
super().__init__(train_dir, models, train_params, save_only,
542542
check_loading_status)
543543
tf.logging.info('Starting eager checkpointer v1')
544-
# Distinct from EagerCheckpointerV2
545-
self._train_dir = os.path.join(self._train_dir, 'ckpt_V1')
544+
self._train_dir = self._train_dir
546545
if not tf.io.gfile.exists(self._train_dir):
547546
tf.io.gfile.makedirs(self._train_dir)
548547

lingvo/core/checkpointer_eager_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ def testEagerMultiLearnerCheckpointCompatibility(self):
177177
checkpointer.EagerCheckpointerV1(eager_v1_logdir, mdl).Save(gsteps=0)
178178
checkpointer.EagerCheckpointerV2(eager_v2_logdir, mdl).Save(gsteps=0)
179179
eager_v1_keys = _GetCheckpointKeys(
180-
os.path.join(eager_v1_logdir, 'ckpt_V1', 'ckpt-00000000'))
180+
os.path.join(eager_v1_logdir, 'ckpt-00000000'))
181181
eager_v2_keys = _GetCheckpointKeys(
182182
os.path.join(eager_v2_logdir, 'ckpt_V2', 'ckpt-0'))
183183
# Expecting two more variables in V2 checkpoints:

lingvo/core/checkpointer_test.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -254,10 +254,7 @@ def testSaveOnly(self):
254254

255255
def testSortCheckpointPaths(self):
256256
ckpts_no_pad = ['/ckpt_V2/ckpt-100', '/ckpt_V2/ckpt-20', '/ckpt_V2/ckpt-3']
257-
ckpts_pad = [
258-
'/ckpt_V1/ckpt-00000100', '/ckpt_V1/ckpt-00000020',
259-
'/ckpt_V1/ckpt-00000003'
260-
]
257+
ckpts_pad = ['/ckpt-00000100', '/ckpt-00000020', '/ckpt-00000003']
261258
ckpts_no_pad_sorted = checkpointer.SortCheckpointPaths(ckpts_no_pad)
262259
self.assertIn('3', ckpts_no_pad_sorted[0])
263260
self.assertIn('100', ckpts_no_pad_sorted[2])

0 commit comments

Comments
 (0)