1616import  returnn .tf .compat  as  tf_compat 
1717import  returnn .tf .util .basic  as  tf_util 
1818from  returnn .tf .util .basic  import  Data , DimensionTag , reuse_name_scope , VariableAssigner 
19+ from  returnn .util  import  basic  as  util 
1920
2021
2122class  DataNotFound (Exception ):
@@ -3473,7 +3474,7 @@ def __init__(self, filename, saveable_params, params_prefix="", load_if_prefix="
34733474               ignore_params = (), ignore_params_prefixes = (), var_name_mapping = None ,
34743475               network = None ):
34753476    """ 
3476-     :param str filename: filepattern for NewCheckpointReader 
3477+     :param str filename: filepattern for NewCheckpointReader or .index/.meta file path  
34773478    :param list[tf.Variable|tensorflow.python.training.saver.BaseSaverBuilder.SaveableObject] saveable_params: 
34783479    :param str params_prefix: expect that all vars in saveable_params have this prefix, and remove it 
34793480    :param str load_if_prefix: if given, only load variables with a name containing this string. 
@@ -3486,7 +3487,7 @@ def __init__(self, filename, saveable_params, params_prefix="", load_if_prefix="
34863487      renamed vars in the checkpoint 
34873488    :param TFNetwork network: 
34883489    """ 
3489-     self .filename  =  filename 
3490+     self .filepattern  =  util . get_checkpoint_filepattern ( filename ) 
34903491    self .network  =  network 
34913492    self .ignore_missing  =  ignore_missing 
34923493    self .params_prefix  =  params_prefix 
@@ -3510,7 +3511,7 @@ def __init__(self, filename, saveable_params, params_prefix="", load_if_prefix="
35103511        continue 
35113512      self .saveable_params .append (param )
35123513    assert  count  >  0 , "%s: no saveable vars"  %  self 
3513-     self .reader  =  tf_compat .v1 .train .NewCheckpointReader (filename )
3514+     self .reader  =  tf_compat .v1 .train .NewCheckpointReader (self . filepattern )
35143515    self .net_vars  =  [v  for  v  in  self .saveable_params  if  isinstance (v , tf .Variable )]
35153516    self .net_saveables  =  [v  for  v  in  self .saveable_params  if  not  isinstance (v , tf .Variable )]
35163517    # All variables in the checkpoint: 
@@ -3918,7 +3919,7 @@ def get_lazy_dict(self):
39183919          if  self .ignore_missing  and  v_name  not  in var_name_map :
39193920            print (
39203921              "Warning, did not find match for var %r (%r, params_prefix %r, load_if_prefix %r) in checkpoint %r."  %  (
3921-                 v , v_name , self .params_prefix , self .load_if_prefix , self .filename ), file = log .v3 )
3922+                 v , v_name , self .params_prefix , self .load_if_prefix , self .filepattern ), file = log .v3 )
39223923            continue 
39233924          variable_values [v ] =  self .VariableValue (value = var_name_map [v_name ]())
39243925      assert  variable_values , "no vars to load; saveable vars are %r. load_if_prefix %r."  %  (
0 commit comments