@@ -3470,7 +3470,7 @@ class CustomCheckpointLoader:
34703470  """ 
34713471
34723472  def  __init__ (self , filename , saveable_params , params_prefix = "" , load_if_prefix = "" , ignore_missing = False ,
3473-                ignore_params = (), ignore_params_prefixes = (),
3473+                ignore_params = (), ignore_params_prefixes = (),  var_name_mapping = None , 
34743474               network = None ):
34753475    """ 
34763476    :param str filename: filepattern for NewCheckpointReader 
@@ -3482,13 +3482,16 @@ def __init__(self, filename, saveable_params, params_prefix="", load_if_prefix="
34823482      however, if there is no single var in the checkpoint, this is still an error. 
34833483    :param typing.Container[str] ignore_params: these param (by name) will not be loaded 
34843484    :param typing.Iterable[str] ignore_params_prefixes: these param (by prefix name) will not be loaded 
3485+     :param dict[str,str] var_name_mapping: defines a custom mapping (new_name -> name_in_checkpoint) for 
3486+       renamed vars in the checkpoint 
34853487    :param TFNetwork network: 
34863488    """ 
34873489    self .filename  =  filename 
34883490    self .network  =  network 
34893491    self .ignore_missing  =  ignore_missing 
34903492    self .params_prefix  =  params_prefix 
34913493    self .load_if_prefix  =  load_if_prefix 
3494+     self .var_name_mapping  =  var_name_mapping  or  {}
34923495    self .saveable_params  =  []
34933496    count  =  0 
34943497    for  param  in  saveable_params :
@@ -3541,6 +3544,7 @@ def __init__(self, layer, checkpoint_loader):
35413544      self .layer  =  layer 
35423545      self .prefix_param_name  =  layer .get_absolute_name_scope_prefix ()
35433546      self .checkpoint_param_names  =  []
3547+       self .var_name_mapping  =  checkpoint_loader .var_name_mapping 
35443548      prefix  =  self .prefix_param_name 
35453549      # Collect checkpoint params, and remove them from the lists. 
35463550      for  name  in  list (checkpoint_loader .var_ckpt_names ):
@@ -3576,7 +3580,7 @@ def assign_var(self, var, session):
35763580        return 
35773581      self .assigned  =  True 
35783582      values_dict  =  {
3579-         name : self .reader .get_tensor (self .prefix_param_name  +  name )
3583+         name : self .reader .get_tensor (self .var_name_mapping . get ( name ,  self . prefix_param_name  +  name ) )
35803584        for  name  in  self .checkpoint_param_names }
35813585      self .reader  =  None   # Allow GC now, we do not need it anymore. 
35823586      print ("Custom param import of layer %r with original params %r."  %  (
@@ -3894,6 +3898,7 @@ def get_lazy_dict(self):
38943898      if  v .endswith (MakeLoadCudnnRnn .cudnn_postfix ):
38953899        var_name_map .update (
38963900          MakeLoadCudnnRnn (prefix = v [:- len (MakeLoadCudnnRnn .cudnn_postfix ) +  1 ]).get_lazy_dict ())
3901+     var_name_map .update ({name : make_load_renamed (old_name ) for  name , old_name  in  self .var_name_mapping .items ()})
38973902
38983903    could_not_find_map_list  =  [v  for  v  in  missing_var_names  if  v  not  in   var_name_map ]
38993904    if  self .ignore_missing  or  not  could_not_find_map_list :
0 commit comments