@@ -117,10 +117,6 @@ def load_keras_model(model_config):
117117 'load_hdf5_to_weights' ,
118118 'save_hdf5_graph' ,
119119 'load_hdf5_graph' ,
120- # 'net2static_graph',
121- 'static_graph2net' ,
122- # 'save_pkl_graph',
123- # 'load_pkl_graph',
124120 'load_and_assign_ckpt' ,
125121 'ckpt_to_npz_dict'
126122]
@@ -138,62 +134,6 @@ def str2func(s):
138134 return expr
139135
140136
141- # def net2static_graph(network):
142- # saved_file = dict()
143- # # if network._NameNone is True:
144- # # saved_file.update({"name": None})
145- # # else:
146- # # saved_file.update({"name": network.name})
147- # # if not isinstance(network.inputs, list):
148- # # saved_file.update({"inputs": network.inputs._info[0].name})
149- # # else:
150- # # saved_inputs = []
151- # # for saved_input in network.inputs:
152- # # saved_inputs.append(saved_input._info[0].name)
153- # # saved_file.update({"inputs": saved_inputs})
154- # # if not isinstance(network.outputs, list):
155- # # saved_file.update({"outputs": network.outputs._info[0].name})
156- # # else:
157- # # saved_outputs = []
158- # # for saved_output in network.outputs:
159- # # saved_outputs.append(saved_output._info[0].name)
160- # # saved_file.update({"outputs": saved_outputs})
161- # saved_file.update({"config": network.config})
162- #
163- # return saved_file
164-
165- # @keras_export('keras.model.save_model')
166- # def save_keras_model(model):
167- # # f.attrs['keras_model_config'] = json.dumps(
168- # # {
169- # # 'class_name': model.__class__.__name__,
170- # # 'config': model.get_config()
171- # # },
172- # # default=serialization.get_json_type).encode('utf8')
173- # #
174- # # f.flush()
175- #
176- # return json.dumps(
177- # {
178- # 'class_name': model.__class__.__name__,
179- # 'config': model.get_config()
180- # }, default=serialization.get_json_type
181- # ).encode('utf8')
182- #
183- #
184- # @keras_export('keras.model.load_model')
185- # def load_keras_model(model_config):
186- #
187- # custom_objects = {}
188- #
189- # if model_config is None:
190- # raise ValueError('No model found in config.')
191- # model_config = json.loads(model_config.decode('utf-8'))
192- # model = model_config_lib.model_from_config(model_config, custom_objects=custom_objects)
193- #
194- # return model
195-
196-
197137def save_hdf5_graph (network , filepath = 'model.hdf5' , save_weights = False , customized_data = None ):
198138 """Save the architecture of TL model into a hdf5 file. Support saving model weights.
199139
@@ -229,19 +169,10 @@ def save_hdf5_graph(network, filepath='model.hdf5', save_weights=False, customiz
229169 ).isoformat ()
230170 model_config_str = str (model_config )
231171 customized_data_str = str (customized_data )
232- # version_info = {
233- # "tensorlayerx_version": tlx.__version__,
234- # "backend": "tensorflow",
235- # "backend_version": tf.__version__,
236- # "training_device": "gpu",
237- # "save_date": datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc).isoformat()
238- # }
239- # version_info_str = str(version_info)
240172
241173 with h5py .File (filepath , 'w' ) as f :
242174 f .attrs ["model_config" ] = model_config_str .encode ('utf8' )
243175 f .attrs ["customized_data" ] = customized_data_str .encode ('utf8' )
244- # f.attrs["version_info"] = version_info_str.encode('utf8')
245176 if save_weights :
246177 _save_weights_to_hdf5_group (f , network .all_layers )
247178 f .flush ()
@@ -267,78 +198,6 @@ def generate_func(args):
267198 # fn = str2func(args[key])
268199 # args[key] = fn
269200
270-
def eval_layer(layer_kwargs):
    """Instantiate one layer from its serialized description.

    Parameters
    ----------
    layer_kwargs : dict
        Description of a single layer. ``layer_kwargs['class']`` is the name of
        the layer class under ``tlx.layers`` and ``layer_kwargs['args']`` holds
        the constructor keyword arguments plus a ``'layer_type'`` entry that is
        one of ``"normal"``, ``"layerlist"``, ``"modellayer"`` or
        ``"keraslayer"``.
        The dict is modified in place: ``'class'`` and ``args['layer_type']``
        are popped (and ``args['keras_input_shape']`` for keras layers), and
        nested descriptions in ``args`` are replaced by the built objects.

    Returns
    -------
    The layer object built by calling the resolved class with ``**args``.

    Raises
    ------
    RuntimeError
        If ``layer_type`` is not one of the four supported values.
    AttributeError
        If the class name cannot be resolved under ``tlx.layers``.

    """
    layer_class = layer_kwargs.pop('class')
    args = layer_kwargs['args']
    layer_type = args.pop('layer_type')

    if layer_type == "normal":
        generate_func(args)
    elif layer_type == "layerlist":
        # Each nested description is built recursively before the container.
        args['layers'] = [eval_layer(layer_graph) for layer_graph in args["layers"]]
    elif layer_type == "modellayer":
        args['model'] = static_graph2net(args['model'])
    elif layer_type == "keraslayer":
        M = load_keras_model(args['fn'])
        input_shape = args.pop('keras_input_shape')
        # One forward pass on random data before reading the variables --
        # presumably so the keras model creates its weights; TODO confirm.
        _ = M(np.random.random(input_shape).astype(np.float32))
        args['fn'] = M
        args['fn_weights'] = M.trainable_variables
    else:
        raise RuntimeError("Unknown layer type.")

    # The class name comes from a saved model file, so it is resolved by
    # attribute lookup rather than eval(): a crafted name can no longer run
    # arbitrary code. Dotted names are followed one attribute at a time.
    layer_cls = tlx.layers
    for attr_name in layer_class.split('.'):
        layer_cls = getattr(layer_cls, attr_name)
    return layer_cls(**args)
298-
299-
def static_graph2net(model_config):
    """Rebuild a static model from its saved configuration dict.

    Parameters
    ----------
    model_config : dict
        Read keys are ``"name"``, ``"inputs"``, ``"outputs"`` and
        ``"model_architecture"`` (a sequence of layer descriptions). Each layer
        description is modified in place: its ``"prev_layer"`` entry is popped
        here and ``eval_layer`` consumes further entries.

    Returns
    -------
    Model
        A ``tensorlayerx.model.Model`` built from the resolved input and output
        tensors and named after ``model_config["name"]``.

    """
    tensors_by_name = {}
    graph_name = model_config["name"]
    input_names = model_config["inputs"]
    output_names = model_config["outputs"]

    for layer_kwargs in model_config["model_architecture"]:
        layer_class = layer_kwargs["class"]  # read before eval_layer pops it
        prev_layers = layer_kwargs.pop("prev_layer")  # names of the feeding tensors
        net = eval_layer(layer_kwargs)

        # Input layers are represented by their first output tensor.
        if layer_class in tlx.nn.inputs.__all__:
            net = net._nodes[0].out_tensors[0]

        if prev_layers is None:
            tensors_by_name[net._info[0].name] = net
            continue

        for prev_layer in prev_layers:
            # A list entry means the layer is called on several tensors at once.
            if isinstance(prev_layer, list):
                feed = [tensors_by_name[layer] for layer in prev_layer]
            else:
                feed = tensors_by_name[prev_layer]
            output = net(feed)
            tensors_by_name[output._info[0].name] = output

    def _resolve(names):
        # Accept either one tensor name or a list of names.
        if isinstance(names, list):
            return [tensors_by_name[tensor_name] for tensor_name in names]
        return tensors_by_name[names]

    model_inputs = _resolve(input_names)
    model_outputs = _resolve(output_names)

    from tensorlayerx.model import Model
    M = Model(inputs=model_inputs, outputs=model_outputs, name=graph_name)
    logging.info("[*] Load graph finished")
    return M
340-
341-
342201def load_hdf5_graph (filepath = 'model.hdf5' , load_weights = False ):
343202 """Restore TL model architecture from an hdf5 file. Support loading model weights.
344203
@@ -395,58 +254,6 @@ def load_hdf5_graph(filepath='model.hdf5', load_weights=False):
395254 return M
396255
397256
398- # def load_pkl_graph(name='model.pkl'):
399- # """Restore TL model archtecture from a a pickle file. No parameters be restored.
400- #
401- # Parameters
402- # -----------
403- # name : str
404- # The name of graph file.
405- #
406- # Returns
407- # --------
408- # network : TensorLayer Model.
409- #
410- # Examples
411- # --------
412- # >>> # It is better to use load_hdf5_graph
413- # """
414- # logging.info("[*] Loading TL graph from {}".format(name))
415- # with open(name, 'rb') as file:
416- # saved_file = pickle.load(file)
417- #
418- # M = static_graph2net(saved_file)
419- #
420- # return M
421- #
422- #
423- # def save_pkl_graph(network, name='model.pkl'):
424- # """Save the architecture of TL model into a pickle file. No parameters be saved.
425- #
426- # Parameters
427- # -----------
428- # network : TensorLayer layer
429- # The network to save.
430- # name : str
431- # The name of graph file.
432- #
433- # Example
434- # --------
435- # >>> # It is better to use save_hdf5_graph
436- # """
437- # if network.outputs is None:
438- # raise AssertionError("save_graph not support dynamic mode yet")
439- #
440- # logging.info("[*] Saving TL graph into {}".format(name))
441- #
442- # saved_file = net2static_graph(network)
443- #
444- # with open(name, 'wb') as file:
445- # pickle.dump(saved_file, file, protocol=pickle.HIGHEST_PROTOCOL)
446- # logging.info("[*] Saved graph")
447-
448-
449- # Load dataset functions
450257def load_mnist_dataset (shape = (- 1 , 784 ), path = 'data' ):
451258 """Load the original mnist.
452259
@@ -2716,7 +2523,7 @@ def assign_th_variable(variable, value):
27162523 variable .data = torch .as_tensor (value )
27172524
27182525
def _save_weights_to_hdf5_group(f, save_list):
    """Save model weights into a ``model_parameters`` group of an hdf5 file.

    Parameters
    ----------
    f : hdf5 file or group
        Open, writable h5py object in which the ``model_parameters`` group is
        created.
    save_list : list or None
        Weights to save. With the torch backend every item is a
        ``(name, tensor)`` pair; with the tensorflow, mindspore and paddle
        backends every item is a variable exposing a ``name`` attribute.
        ``None`` is treated as an empty list.

    Raises
    ------
    NotImplementedError
        If ``tlx.BACKEND`` is not tensorflow, mindspore, paddle or torch.

    """
    if save_list is None:
        save_list = []

    if tlx.BACKEND == 'torch':
        save_list_names = []
        save_list_var = []
        for named, values in save_list:
            save_list_names.append(named)
            save_list_var.append(values.cpu().detach().numpy())
    else:
        save_list_names = [tensor.name for tensor in save_list]
        if tlx.BACKEND == 'tensorflow':
            save_list_var = tf_variables_to_numpy(save_list)
        elif tlx.BACKEND == 'mindspore':
            save_list_var = ms_variables_to_numpy(save_list)
        elif tlx.BACKEND == 'paddle':
            save_list_var = pd_variables_to_numpy(save_list)
        else:
            raise NotImplementedError('Not implemented')

    # Keyed by weight name, so a repeated name keeps only its last value.
    save_var_dict = dict(zip(save_list_names, save_list_var))

    g = f.create_group('model_parameters')
    for weight_name, weight_value in save_var_dict.items():
        # h5py reads '/' in a name as a path into nested groups; swapping it
        # for '.' stores every weight as a dataset directly in this group.
        g.create_dataset(weight_name.replace('/', '.'), data=weight_value)

    logging.info("[*] Model saved in hdf5.")
27592570
27602571def _load_weights_from_hdf5_group_in_order (f , layers ):
@@ -2838,7 +2649,7 @@ def _load_weights_from_hdf5_group(f, layers, skip=False):
28382649 raise Exception ("Only layer or model can be saved into hdf5." )
28392650
28402651
2841- def save_weights_to_hdf5 (filepath , network ):
2652+ def save_weights_to_hdf5 (save_list , filepath ):
28422653 """Input filepath and save weights in hdf5 format.
28432654
28442655 Parameters
@@ -2855,12 +2666,12 @@ def save_weights_to_hdf5(filepath, network):
28552666 logging .info ("[*] Saving TL weights into %s" % filepath )
28562667
28572668 with h5py .File (filepath , 'w' ) as f :
2858- _save_weights_to_hdf5_group (f , network . all_layers )
2669+ _save_weights_to_hdf5_group (f , save_list )
28592670
28602671 logging .info ("[*] Saved" )
28612672
28622673
def load_hdf5_to_weights_in_order(filepath, network, skip=False):
    """Load weights from an hdf5 file into a network, matching them by name.

    The file must contain a ``model_parameters`` group whose dataset names are
    weight names with ``/`` replaced by ``.``, which is the layout written by
    ``_save_weights_to_hdf5_group``.

    Parameters
    ----------
    filepath : str
        Path of the hdf5 file to read.
    network : Model
        Network receiving the weights. Its weights are read from
        ``named_parameters()`` with the torch backend and from ``all_weights``
        (using each weight's ``name``) with the other backends.
    skip : bool
        If False (default), a stored weight that has no counterpart in
        ``network`` raises ``RuntimeError``. If True, a warning is logged and
        that weight is skipped.

    Raises
    ------
    RuntimeError
        If a stored weight is not found in ``network`` and ``skip`` is False.
    NotImplementedError
        If ``tlx.BACKEND`` is not tensorflow, mindspore, paddle or torch.

    """
    # The context manager closes the file even when loading fails part-way.
    with h5py.File(filepath, 'r') as f:
        weights = f['model_parameters']

        if tlx.BACKEND == 'torch':
            named_weights = network.named_parameters()
        else:
            named_weights = ((w.name, w) for w in network.all_weights)

        # Index the network weights by the key they are stored under. Applying
        # the same '/' -> '.' substitution as the save path (instead of
        # reversing it on the stored key) also matches names that natively
        # contain dots, such as torch parameter names.
        weights_by_key = {}
        for weight_name, weight in named_weights:
            # setdefault keeps the first weight when two names collide.
            weights_by_key.setdefault(weight_name.replace('/', '.'), weight)

        for key in weights:
            if key not in weights_by_key:
                if skip:
                    logging.warning("Weights named '%s' not found in network. Skip it." % key)
                    continue
                raise RuntimeError(
                    "Weights named '%s' not found in network. Hint: set argument skip=True "
                    "if you want to skip redundant or mismatch weights." % key
                )

            target = weights_by_key[key]
            if tlx.BACKEND == 'tensorflow':
                assign_tf_variable(target, weights[key])
            elif tlx.BACKEND == 'mindspore':
                assign_param = Tensor(weights[key], dtype=ms.float32)
                assign_ms_variable(target, assign_param)
            elif tlx.BACKEND == 'paddle':
                assign_pd_variable(target, weights[key])
            elif tlx.BACKEND == 'torch':
                assign_th_variable(target, weights[key])
            else:
                raise NotImplementedError('Not implemented')

    logging.info("[*] Load %s SUCCESS!" % filepath)
29012728
0 commit comments