Skip to content

Commit e2012c4

Browse files
authored
consistent checkpoint filepattern (#671)
1 parent 3f1554d commit e2012c4

File tree

3 files changed

+28
-20
lines changed

3 files changed

+28
-20
lines changed

returnn/engine/base.py

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
import os
99
import sys
1010
import typing
11-
from returnn.util.basic import BackendEngine, model_epoch_from_filename, get_model_filename_postfix
1211
from returnn.log import log
1312
from returnn.pretrain import Pretrain
13+
from returnn.util import basic as util
1414

1515

1616
class EngineBase(object):
@@ -52,7 +52,7 @@ def get_existing_models(cls, config):
5252
if os.path.exists(fn):
5353
file_list[epoch] = fn
5454
break
55-
if BackendEngine.is_tensorflow_selected():
55+
if util.BackendEngine.is_tensorflow_selected():
5656
if os.path.exists(fn + ".index"):
5757
file_list[epoch] = fn
5858
break
@@ -72,32 +72,24 @@ def get_epoch_model(cls, config):
7272
start_epoch = int(start_epoch_mode)
7373
assert start_epoch >= 1
7474

75-
load_model_epoch_filename = config.value('load', '')
76-
if load_model_epoch_filename.endswith(".meta"):
77-
load_model_epoch_filename = load_model_epoch_filename[:-len(".meta")]
78-
elif load_model_epoch_filename.endswith(".index"):
79-
load_model_epoch_filename = load_model_epoch_filename[:-len(".index")]
75+
load_model_epoch_filename = util.get_checkpoint_filepattern(config.value('load', ''))
8076
if load_model_epoch_filename:
81-
assert os.path.exists(load_model_epoch_filename + get_model_filename_postfix())
77+
assert os.path.exists(load_model_epoch_filename + util.get_model_filename_postfix())
8278

83-
import_model_train_epoch1 = config.value('import_model_train_epoch1', '')
84-
if import_model_train_epoch1.endswith(".meta"):
85-
import_model_train_epoch1 = import_model_train_epoch1[:-len(".meta")]
86-
elif import_model_train_epoch1.endswith(".index"):
87-
import_model_train_epoch1 = import_model_train_epoch1[:-len(".index")]
79+
import_model_train_epoch1 = util.get_checkpoint_filepattern(config.value('import_model_train_epoch1', ''))
8880
if import_model_train_epoch1:
89-
assert os.path.exists(import_model_train_epoch1 + get_model_filename_postfix())
81+
assert os.path.exists(import_model_train_epoch1 + util.get_model_filename_postfix())
9082

9183
existing_models = cls.get_existing_models(config)
9284
load_epoch = config.int("load_epoch", -1)
9385
if load_model_epoch_filename:
9486
if load_epoch <= 0:
95-
load_epoch = model_epoch_from_filename(load_model_epoch_filename)
87+
load_epoch = util.model_epoch_from_filename(load_model_epoch_filename)
9688
else:
9789
if load_epoch > 0: # ignore if load_epoch == 0
9890
assert load_epoch in existing_models
9991
load_model_epoch_filename = existing_models[load_epoch]
100-
assert model_epoch_from_filename(load_model_epoch_filename) == load_epoch
92+
assert util.model_epoch_from_filename(load_model_epoch_filename) == load_epoch
10193

10294
# Only use this when we don't train.
10395
# For training, we first consider existing models before we take the 'load' into account when in auto epoch mode.

returnn/tf/network.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import returnn.tf.compat as tf_compat
1717
import returnn.tf.util.basic as tf_util
1818
from returnn.tf.util.basic import Data, DimensionTag, reuse_name_scope, VariableAssigner
19+
from returnn.util import basic as util
1920

2021

2122
class DataNotFound(Exception):
@@ -3473,7 +3474,7 @@ def __init__(self, filename, saveable_params, params_prefix="", load_if_prefix="
34733474
ignore_params=(), ignore_params_prefixes=(), var_name_mapping=None,
34743475
network=None):
34753476
"""
3476-
:param str filename: filepattern for NewCheckpointReader
3477+
:param str filename: filepattern for NewCheckpointReader or .index/.meta file path
34773478
:param list[tf.Variable|tensorflow.python.training.saver.BaseSaverBuilder.SaveableObject] saveable_params:
34783479
:param str params_prefix: expect that all vars in saveable_params have this prefix, and remove it
34793480
:param str load_if_prefix: if given, only load variables with a name containing this string.
@@ -3486,7 +3487,7 @@ def __init__(self, filename, saveable_params, params_prefix="", load_if_prefix="
34863487
renamed vars in the checkpoint
34873488
:param TFNetwork network:
34883489
"""
3489-
self.filename = filename
3490+
self.filepattern = util.get_checkpoint_filepattern(filename)
34903491
self.network = network
34913492
self.ignore_missing = ignore_missing
34923493
self.params_prefix = params_prefix
@@ -3510,7 +3511,7 @@ def __init__(self, filename, saveable_params, params_prefix="", load_if_prefix="
35103511
continue
35113512
self.saveable_params.append(param)
35123513
assert count > 0, "%s: no saveable vars" % self
3513-
self.reader = tf_compat.v1.train.NewCheckpointReader(filename)
3514+
self.reader = tf_compat.v1.train.NewCheckpointReader(self.filepattern)
35143515
self.net_vars = [v for v in self.saveable_params if isinstance(v, tf.Variable)]
35153516
self.net_saveables = [v for v in self.saveable_params if not isinstance(v, tf.Variable)]
35163517
# All variables in the checkpoint:
@@ -3918,7 +3919,7 @@ def get_lazy_dict(self):
39183919
if self.ignore_missing and v_name not in var_name_map:
39193920
print(
39203921
"Warning, did not find match for var %r (%r, params_prefix %r, load_if_prefix %r) in checkpoint %r." % (
3921-
v, v_name, self.params_prefix, self.load_if_prefix, self.filename), file=log.v3)
3922+
v, v_name, self.params_prefix, self.load_if_prefix, self.filepattern), file=log.v3)
39223923
continue
39233924
variable_values[v] = self.VariableValue(value=var_name_map[v_name]())
39243925
assert variable_values, "no vars to load; saveable vars are %r. load_if_prefix %r." % (

returnn/util/basic.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,21 @@ def get_model_filename_postfix():
288288
return ""
289289

290290

291+
def get_checkpoint_filepattern(filepath):
292+
"""
293+
Removes optional .index or .meta extension
294+
295+
:param str filepath:
296+
:return: CheckpointLoader compatible filepattern
297+
:rtype: str
298+
"""
299+
if filepath.endswith(".meta"):
300+
return filepath[:-len(".meta")]
301+
elif filepath.endswith(".index"):
302+
return filepath[:-len(".index")]
303+
return filepath
304+
305+
291306
def sys_cmd_out_lines(s):
292307
"""
293308
:param str s: shell command

0 commit comments

Comments
 (0)