@@ -31,6 +31,7 @@
 from tensorflow.python.util.tf_export import keras_export
 from tensorflow.python.util import serialization
 import json
+import datetime

 # from six.moves import zip

@@ -73,7 +73,7 @@
     'load_hdf5_to_weights',
     'save_hdf5_graph',
     'load_hdf5_graph',
-    'net2static_graph',
+    # 'net2static_graph',
     'static_graph2net',
     # 'save_pkl_graph',
     # 'load_pkl_graph',
@@ -92,29 +93,29 @@ def str2func(s):
     return expr


-def net2static_graph(network):
-    saved_file = dict()
-    if network._NameNone is True:
-        saved_file.update({"name": None})
-    else:
-        saved_file.update({"name": network.name})
-    if not isinstance(network.inputs, list):
-        saved_file.update({"inputs": network.inputs._info[0].name})
-    else:
-        saved_inputs = []
-        for saved_input in network.inputs:
-            saved_inputs.append(saved_input._info[0].name)
-        saved_file.update({"inputs": saved_inputs})
-    if not isinstance(network.outputs, list):
-        saved_file.update({"outputs": network.outputs._info[0].name})
-    else:
-        saved_outputs = []
-        for saved_output in network.outputs:
-            saved_outputs.append(saved_output._info[0].name)
-        saved_file.update({"outputs": saved_outputs})
-    saved_file.update({"config": network.config})
-
-    return saved_file
+# def net2static_graph(network):
+#     saved_file = dict()
+#     # if network._NameNone is True:
+#     #     saved_file.update({"name": None})
+#     # else:
+#     #     saved_file.update({"name": network.name})
+#     # if not isinstance(network.inputs, list):
+#     #     saved_file.update({"inputs": network.inputs._info[0].name})
+#     # else:
+#     #     saved_inputs = []
+#     #     for saved_input in network.inputs:
+#     #         saved_inputs.append(saved_input._info[0].name)
+#     #     saved_file.update({"inputs": saved_inputs})
+#     # if not isinstance(network.outputs, list):
+#     #     saved_file.update({"outputs": network.outputs._info[0].name})
+#     # else:
+#     #     saved_outputs = []
+#     #     for saved_output in network.outputs:
+#     #         saved_outputs.append(saved_output._info[0].name)
+#     #     saved_file.update({"outputs": saved_outputs})
+#     saved_file.update({"config": network.config})
+#
+#     return saved_file


 @keras_export('keras.models.save_model')
@@ -149,7 +150,7 @@ def load_keras_model(model_config):
     return model


-def save_hdf5_graph(network, filepath='model.hdf5', save_weights=False):
+def save_hdf5_graph(network, filepath='model.hdf5', save_weights=False, customized_data=None):
     """Save the architecture of TL model into a hdf5 file. Support saving model weights.

     Parameters
@@ -160,6 +161,8 @@ def save_hdf5_graph(network, filepath='model.hdf5', save_weights=False):
         The name of model file.
     save_weights : bool
         Whether to save model weights.
+    customized_data : dict
+        The user-customized meta data.

     Examples
     --------
@@ -177,11 +180,22 @@ def save_hdf5_graph(network, filepath='model.hdf5', save_weights=False):

     logging.info("[*] Saving TL model into {}, saving weights={}".format(filepath, save_weights))

-    saved_file = net2static_graph(network)
-    saved_file_str = str(saved_file)
+    model_config = network.config  # net2static_graph(network)
+    model_config_str = str(model_config)
+    customized_data_str = str(customized_data)
+    version_info = {
+        "tensorlayer_version": tl.__version__,
+        "backend": "tensorflow",
+        "backend_version": tf.__version__,
+        "training_device": "gpu",
+        "save_date": datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc).isoformat()
+    }
+    version_info_str = str(version_info)

     with h5py.File(filepath, 'w') as f:
-        f.attrs["model_structure"] = saved_file_str.encode('utf8')
+        f.attrs["model_config"] = model_config_str.encode('utf8')
+        f.attrs["customized_data"] = customized_data_str.encode('utf8')
+        f.attrs["version_info"] = version_info_str.encode('utf8')
        if save_weights:
            _save_weights_to_hdf5_group(f, network.all_layers)
        f.flush()
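
As a usage sketch for the new `customized_data` argument (the model object `net` and the metadata values below are illustrative, not part of this commit), any user dict passed in ends up stringified in the `customized_data` HDF5 attribute alongside the model config and version info:

```python
# Hypothetical call to the extended save_hdf5_graph API; `net` is assumed to be
# a static TensorLayer Model that has already been built.
import tensorlayer as tl

meta = {"author": "example-user", "dataset": "mnist", "note": "baseline run"}
tl.files.save_hdf5_graph(net, filepath='model.hdf5', save_weights=True, customized_data=meta)
```
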
@@ -237,29 +251,15 @@ def eval_layer(layer_kwargs):
     raise RuntimeError("Unknown layer type.")


-def static_graph2net(saved_file):
+def static_graph2net(model_config):
     layer_dict = {}
-    model_name = saved_file['name']
-    inputs_tensors = saved_file['inputs']
-    outputs_tensors = saved_file['outputs']
-    all_args = saved_file['config']
-    tf_version = saved_file['config'].pop(0)['tf_version']
-    tl_version = saved_file['config'].pop(0)['tl_version']
-    if tf_version != tf.__version__:
-        logging.warning(
-            "Saved model uses tensorflow version {}, but now you are using tensorflow version {}".format(
-                tf_version, tf.__version__
-            )
-        )
-    if tl_version != tl.__version__:
-        logging.warning(
-            "Saved model uses tensorlayer version {}, but now you are using tensorlayer version {}".format(
-                tl_version, tl.__version__
-            )
-        )
+    model_name = model_config["name"]
+    inputs_tensors = model_config["inputs"]
+    outputs_tensors = model_config["outputs"]
+    all_args = model_config["model_architecture"]
     for idx, layer_kwargs in enumerate(all_args):
-        layer_class = layer_kwargs['class']  # class of current layer
-        prev_layers = layer_kwargs.pop('prev_layer')  # name of previous layers
+        layer_class = layer_kwargs["class"]  # class of current layer
+        prev_layers = layer_kwargs.pop("prev_layer")  # name of previous layers
         net = eval_layer(layer_kwargs)
         if layer_class in tl.layers.inputs.__all__:
             net = net._nodes[0].out_tensors[0]
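
For orientation, here is a minimal sketch of the dict shape that `static_graph2net` now consumes from `network.config`; only the key names (`name`, `inputs`, `outputs`, `model_architecture`, `class`, `prev_layer`) come from the code above, and every value is invented:

```python
# Illustrative shape of the `model_config` dict consumed by static_graph2net.
# Key names are taken from the diff above; the concrete values are made up,
# and real layer entries carry additional layer-specific kwargs for eval_layer().
model_config = {
    "name": "model",                       # model name
    "inputs": ["_inputlayer_node_0_0"],    # name(s) of the input tensor(s)
    "outputs": ["dense_node_0_0"],         # name(s) of the output tensor(s)
    "model_architecture": [                # one kwargs dict per layer
        {"class": "Input", "prev_layer": None},
        {"class": "Dense", "prev_layer": ["_inputlayer_node_0_0"]},
    ],
}
```
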
@@ -312,11 +312,30 @@ def load_hdf5_graph(filepath='model.hdf5', load_weights=False):
     - see ``tl.files.save_hdf5_graph``
     """
     logging.info("[*] Loading TL model from {}, loading weights={}".format(filepath, load_weights))
+
     f = h5py.File(filepath, 'r')
-    saved_file_str = f.attrs["model_structure"].decode('utf8')
-    saved_file = eval(saved_file_str)

-    M = static_graph2net(saved_file)
+    version_info_str = f.attrs["version_info"].decode('utf8')
+    version_info = eval(version_info_str)
+    backend_version = version_info["backend_version"]
+    tensorlayer_version = version_info["tensorlayer_version"]
+    if backend_version != tf.__version__:
+        logging.warning(
+            "Saved model uses tensorflow version {}, but now you are using tensorflow version {}".format(
+                backend_version, tf.__version__
+            )
+        )
+    if tensorlayer_version != tl.__version__:
+        logging.warning(
+            "Saved model uses tensorlayer version {}, but now you are using tensorlayer version {}".format(
+                tensorlayer_version, tl.__version__
+            )
+        )
+
+    model_config_str = f.attrs["model_config"].decode('utf8')
+    model_config = eval(model_config_str)
+
+    M = static_graph2net(model_config)
     if load_weights:
         if not ('layer_names' in f.attrs.keys()):
             raise RuntimeError("Saved model does not contain weights.")
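
As a quick way to see what the new metadata attributes hold, a small h5py inspection sketch (assuming a file written by the updated `save_hdf5_graph`; the attribute names are the ones used above):

```python
# Peek at the metadata attributes written by save_hdf5_graph.
# Assumes 'model.hdf5' was produced by the code above.
import h5py

with h5py.File('model.hdf5', 'r') as f:
    for key in ("version_info", "customized_data", "model_config"):
        print(key, "->", f.attrs[key].decode('utf8'))
```
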
@@ -329,55 +348,55 @@ def load_hdf5_graph(filepath='model.hdf5', load_weights=False):
     return M


-def load_pkl_graph(name='model.pkl'):
-    """Restore TL model archtecture from a a pickle file. No parameters be restored.
-
-    Parameters
-    -----------
-    name : str
-        The name of graph file.
-
-    Returns
-    --------
-    network : TensorLayer Model.
-
-    Examples
-    --------
-    >>> # It is better to use load_hdf5_graph
-    """
-    logging.info("[*] Loading TL graph from {}".format(name))
-    with open(name, 'rb') as file:
-        saved_file = pickle.load(file)
-
-    M = static_graph2net(saved_file)
-
-    return M
-
-
-def save_pkl_graph(network, name='model.pkl'):
-    """Save the architecture of TL model into a pickle file. No parameters be saved.
-
-    Parameters
-    -----------
-    network : TensorLayer layer
-        The network to save.
-    name : str
-        The name of graph file.
-
-    Example
-    --------
-    >>> # It is better to use save_hdf5_graph
-    """
-    if network.outputs is None:
-        raise AssertionError("save_graph not support dynamic mode yet")
-
-    logging.info("[*] Saving TL graph into {}".format(name))
-
-    saved_file = net2static_graph(network)
-
-    with open(name, 'wb') as file:
-        pickle.dump(saved_file, file, protocol=pickle.HIGHEST_PROTOCOL)
-    logging.info("[*] Saved graph")
+# def load_pkl_graph(name='model.pkl'):
+#     """Restore TL model archtecture from a a pickle file. No parameters be restored.
+#
+#     Parameters
+#     -----------
+#     name : str
+#         The name of graph file.
+#
+#     Returns
+#     --------
+#     network : TensorLayer Model.
+#
+#     Examples
+#     --------
+#     >>> # It is better to use load_hdf5_graph
+#     """
+#     logging.info("[*] Loading TL graph from {}".format(name))
+#     with open(name, 'rb') as file:
+#         saved_file = pickle.load(file)
+#
+#     M = static_graph2net(saved_file)
+#
+#     return M
+#
+#
+# def save_pkl_graph(network, name='model.pkl'):
+#     """Save the architecture of TL model into a pickle file. No parameters be saved.
+#
+#     Parameters
+#     -----------
+#     network : TensorLayer layer
+#         The network to save.
+#     name : str
+#         The name of graph file.
+#
+#     Example
+#     --------
+#     >>> # It is better to use save_hdf5_graph
+#     """
+#     if network.outputs is None:
+#         raise AssertionError("save_graph not support dynamic mode yet")
+#
+#     logging.info("[*] Saving TL graph into {}".format(name))
+#
+#     saved_file = net2static_graph(network)
+#
+#     with open(name, 'wb') as file:
+#         pickle.dump(saved_file, file, protocol=pickle.HIGHEST_PROTOCOL)
+#     logging.info("[*] Saved graph")


 # Load dataset functions