Skip to content

Commit b2b5cb5

Browse files
authored
allow passing a custom variable name mapping for preloading of parameters (#667)
1 parent 21b8934 commit b2b5cb5

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

returnn/tf/engine.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1135,7 +1135,8 @@ def init_network_from_config(self, config=None, net_dict_post_proc=None):
11351135
params_prefix=self_prefix, load_if_prefix=load_if_prefix,
11361136
ignore_missing=opts.get("ignore_missing", False),
11371137
ignore_params=opts.get("ignore_params", ()),
1138-
ignore_params_prefixes=opts.get("ignore_params_prefixes", ()))
1138+
ignore_params_prefixes=opts.get("ignore_params_prefixes", ()),
1139+
var_name_mapping=opts.get("var_name_mapping", {}))
11391140
# `set_as_custom_init` is also a marker for the vars, that they are preloaded,
11401141
# such that further checkpoint loaders will not load them again.
11411142
loader.set_as_custom_init()

returnn/tf/network.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -3470,7 +3470,7 @@ class CustomCheckpointLoader:
34703470
"""
34713471

34723472
def __init__(self, filename, saveable_params, params_prefix="", load_if_prefix="", ignore_missing=False,
3473-
ignore_params=(), ignore_params_prefixes=(),
3473+
ignore_params=(), ignore_params_prefixes=(), var_name_mapping=None,
34743474
network=None):
34753475
"""
34763476
:param str filename: filepattern for NewCheckpointReader
@@ -3482,13 +3482,16 @@ def __init__(self, filename, saveable_params, params_prefix="", load_if_prefix="
34823482
however, if there is no single var in the checkpoint, this is still an error.
34833483
:param typing.Container[str] ignore_params: these param (by name) will not be loaded
34843484
:param typing.Iterable[str] ignore_params_prefixes: these param (by prefix name) will not be loaded
3485+
:param dict[str,str] var_name_mapping: defines a custom mapping (new_name -> name_in_checkpoint) for
3486+
renamed vars in the checkpoint
34853487
:param TFNetwork network:
34863488
"""
34873489
self.filename = filename
34883490
self.network = network
34893491
self.ignore_missing = ignore_missing
34903492
self.params_prefix = params_prefix
34913493
self.load_if_prefix = load_if_prefix
3494+
self.var_name_mapping = var_name_mapping or {}
34923495
self.saveable_params = []
34933496
count = 0
34943497
for param in saveable_params:
@@ -3541,6 +3544,7 @@ def __init__(self, layer, checkpoint_loader):
35413544
self.layer = layer
35423545
self.prefix_param_name = layer.get_absolute_name_scope_prefix()
35433546
self.checkpoint_param_names = []
3547+
self.var_name_mapping = checkpoint_loader.var_name_mapping
35443548
prefix = self.prefix_param_name
35453549
# Collect checkpoint params, and remove them from the lists.
35463550
for name in list(checkpoint_loader.var_ckpt_names):
@@ -3576,7 +3580,7 @@ def assign_var(self, var, session):
35763580
return
35773581
self.assigned = True
35783582
values_dict = {
3579-
name: self.reader.get_tensor(self.prefix_param_name + name)
3583+
name: self.reader.get_tensor(self.var_name_mapping.get(name, self.prefix_param_name + name))
35803584
for name in self.checkpoint_param_names}
35813585
self.reader = None # Allow GC now, we do not need it anymore.
35823586
print("Custom param import of layer %r with original params %r." % (
@@ -3894,6 +3898,7 @@ def get_lazy_dict(self):
38943898
if v.endswith(MakeLoadCudnnRnn.cudnn_postfix):
38953899
var_name_map.update(
38963900
MakeLoadCudnnRnn(prefix=v[:-len(MakeLoadCudnnRnn.cudnn_postfix) + 1]).get_lazy_dict())
3901+
var_name_map.update({name: make_load_renamed(old_name) for name, old_name in self.var_name_mapping.items()})
38973902

38983903
could_not_find_map_list = [v for v in missing_var_names if v not in var_name_map]
38993904
if self.ignore_missing or not could_not_find_map_list:

0 commit comments

Comments
 (0)