Skip to content

Commit 9add5c7

Browse files
authored
Add parameters save to HDF5 (#22)
1 parent 33cee36 commit 9add5c7

File tree

2 files changed

+76
-254
lines changed

2 files changed

+76
-254
lines changed

tensorlayerx/files/utils.py

Lines changed: 65 additions & 238 deletions
Original file line numberDiff line numberDiff line change
@@ -117,10 +117,6 @@ def load_keras_model(model_config):
117117
'load_hdf5_to_weights',
118118
'save_hdf5_graph',
119119
'load_hdf5_graph',
120-
# 'net2static_graph',
121-
'static_graph2net',
122-
# 'save_pkl_graph',
123-
# 'load_pkl_graph',
124120
'load_and_assign_ckpt',
125121
'ckpt_to_npz_dict'
126122
]
@@ -138,62 +134,6 @@ def str2func(s):
138134
return expr
139135

140136

141-
# def net2static_graph(network):
142-
# saved_file = dict()
143-
# # if network._NameNone is True:
144-
# # saved_file.update({"name": None})
145-
# # else:
146-
# # saved_file.update({"name": network.name})
147-
# # if not isinstance(network.inputs, list):
148-
# # saved_file.update({"inputs": network.inputs._info[0].name})
149-
# # else:
150-
# # saved_inputs = []
151-
# # for saved_input in network.inputs:
152-
# # saved_inputs.append(saved_input._info[0].name)
153-
# # saved_file.update({"inputs": saved_inputs})
154-
# # if not isinstance(network.outputs, list):
155-
# # saved_file.update({"outputs": network.outputs._info[0].name})
156-
# # else:
157-
# # saved_outputs = []
158-
# # for saved_output in network.outputs:
159-
# # saved_outputs.append(saved_output._info[0].name)
160-
# # saved_file.update({"outputs": saved_outputs})
161-
# saved_file.update({"config": network.config})
162-
#
163-
# return saved_file
164-
165-
# @keras_export('keras.model.save_model')
166-
# def save_keras_model(model):
167-
# # f.attrs['keras_model_config'] = json.dumps(
168-
# # {
169-
# # 'class_name': model.__class__.__name__,
170-
# # 'config': model.get_config()
171-
# # },
172-
# # default=serialization.get_json_type).encode('utf8')
173-
# #
174-
# # f.flush()
175-
#
176-
# return json.dumps(
177-
# {
178-
# 'class_name': model.__class__.__name__,
179-
# 'config': model.get_config()
180-
# }, default=serialization.get_json_type
181-
# ).encode('utf8')
182-
#
183-
#
184-
# @keras_export('keras.model.load_model')
185-
# def load_keras_model(model_config):
186-
#
187-
# custom_objects = {}
188-
#
189-
# if model_config is None:
190-
# raise ValueError('No model found in config.')
191-
# model_config = json.loads(model_config.decode('utf-8'))
192-
# model = model_config_lib.model_from_config(model_config, custom_objects=custom_objects)
193-
#
194-
# return model
195-
196-
197137
def save_hdf5_graph(network, filepath='model.hdf5', save_weights=False, customized_data=None):
198138
"""Save the architecture of TL model into a hdf5 file. Support saving model weights.
199139
@@ -229,19 +169,10 @@ def save_hdf5_graph(network, filepath='model.hdf5', save_weights=False, customiz
229169
).isoformat()
230170
model_config_str = str(model_config)
231171
customized_data_str = str(customized_data)
232-
# version_info = {
233-
# "tensorlayerx_version": tlx.__version__,
234-
# "backend": "tensorflow",
235-
# "backend_version": tf.__version__,
236-
# "training_device": "gpu",
237-
# "save_date": datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc).isoformat()
238-
# }
239-
# version_info_str = str(version_info)
240172

241173
with h5py.File(filepath, 'w') as f:
242174
f.attrs["model_config"] = model_config_str.encode('utf8')
243175
f.attrs["customized_data"] = customized_data_str.encode('utf8')
244-
# f.attrs["version_info"] = version_info_str.encode('utf8')
245176
if save_weights:
246177
_save_weights_to_hdf5_group(f, network.all_layers)
247178
f.flush()
@@ -267,78 +198,6 @@ def generate_func(args):
267198
# fn = str2func(args[key])
268199
# args[key] = fn
269200

270-
271-
def eval_layer(layer_kwargs):
272-
layer_class = layer_kwargs.pop('class')
273-
args = layer_kwargs['args']
274-
layer_type = args.pop('layer_type')
275-
if layer_type == "normal":
276-
generate_func(args)
277-
return eval('tlx.layers.' + layer_class)(**args)
278-
elif layer_type == "layerlist":
279-
ret_layer = []
280-
layers = args["layers"]
281-
for layer_graph in layers:
282-
ret_layer.append(eval_layer(layer_graph))
283-
args['layers'] = ret_layer
284-
return eval('tlx.layers.' + layer_class)(**args)
285-
elif layer_type == "modellayer":
286-
M = static_graph2net(args['model'])
287-
args['model'] = M
288-
return eval('tlx.layers.' + layer_class)(**args)
289-
elif layer_type == "keraslayer":
290-
M = load_keras_model(args['fn'])
291-
input_shape = args.pop('keras_input_shape')
292-
_ = M(np.random.random(input_shape).astype(np.float32))
293-
args['fn'] = M
294-
args['fn_weights'] = M.trainable_variables
295-
return eval('tlx.layers.' + layer_class)(**args)
296-
else:
297-
raise RuntimeError("Unknown layer type.")
298-
299-
300-
def static_graph2net(model_config):
301-
layer_dict = {}
302-
model_name = model_config["name"]
303-
inputs_tensors = model_config["inputs"]
304-
outputs_tensors = model_config["outputs"]
305-
all_args = model_config["model_architecture"]
306-
for idx, layer_kwargs in enumerate(all_args):
307-
layer_class = layer_kwargs["class"] # class of current layer
308-
prev_layers = layer_kwargs.pop("prev_layer") # name of previous layers
309-
net = eval_layer(layer_kwargs)
310-
if layer_class in tlx.nn.inputs.__all__:
311-
net = net._nodes[0].out_tensors[0]
312-
if prev_layers is not None:
313-
for prev_layer in prev_layers:
314-
if not isinstance(prev_layer, list):
315-
output = net(layer_dict[prev_layer])
316-
layer_dict[output._info[0].name] = output
317-
else:
318-
list_layers = [layer_dict[layer] for layer in prev_layer]
319-
output = net(list_layers)
320-
layer_dict[output._info[0].name] = output
321-
else:
322-
layer_dict[net._info[0].name] = net
323-
324-
if not isinstance(inputs_tensors, list):
325-
model_inputs = layer_dict[inputs_tensors]
326-
else:
327-
model_inputs = []
328-
for inputs_tensor in inputs_tensors:
329-
model_inputs.append(layer_dict[inputs_tensor])
330-
if not isinstance(outputs_tensors, list):
331-
model_outputs = layer_dict[outputs_tensors]
332-
else:
333-
model_outputs = []
334-
for outputs_tensor in outputs_tensors:
335-
model_outputs.append(layer_dict[outputs_tensor])
336-
from tensorlayerx.model import Model
337-
M = Model(inputs=model_inputs, outputs=model_outputs, name=model_name)
338-
logging.info("[*] Load graph finished")
339-
return M
340-
341-
342201
def load_hdf5_graph(filepath='model.hdf5', load_weights=False):
343202
"""Restore TL model architecture from an hdf5 file. Support loading model weights.
344203
@@ -395,58 +254,6 @@ def load_hdf5_graph(filepath='model.hdf5', load_weights=False):
395254
return M
396255

397256

398-
# def load_pkl_graph(name='model.pkl'):
399-
# """Restore TL model architecture from a pickle file. No parameters are restored.
400-
#
401-
# Parameters
402-
# -----------
403-
# name : str
404-
# The name of graph file.
405-
#
406-
# Returns
407-
# --------
408-
# network : TensorLayer Model.
409-
#
410-
# Examples
411-
# --------
412-
# >>> # It is better to use load_hdf5_graph
413-
# """
414-
# logging.info("[*] Loading TL graph from {}".format(name))
415-
# with open(name, 'rb') as file:
416-
# saved_file = pickle.load(file)
417-
#
418-
# M = static_graph2net(saved_file)
419-
#
420-
# return M
421-
#
422-
#
423-
# def save_pkl_graph(network, name='model.pkl'):
424-
# """Save the architecture of TL model into a pickle file. No parameters are saved.
425-
#
426-
# Parameters
427-
# -----------
428-
# network : TensorLayer layer
429-
# The network to save.
430-
# name : str
431-
# The name of graph file.
432-
#
433-
# Example
434-
# --------
435-
# >>> # It is better to use save_hdf5_graph
436-
# """
437-
# if network.outputs is None:
438-
# raise AssertionError("save_graph not support dynamic mode yet")
439-
#
440-
# logging.info("[*] Saving TL graph into {}".format(name))
441-
#
442-
# saved_file = net2static_graph(network)
443-
#
444-
# with open(name, 'wb') as file:
445-
# pickle.dump(saved_file, file, protocol=pickle.HIGHEST_PROTOCOL)
446-
# logging.info("[*] Saved graph")
447-
448-
449-
# Load dataset functions
450257
def load_mnist_dataset(shape=(-1, 784), path='data'):
451258
"""Load the original mnist.
452259
@@ -2716,7 +2523,7 @@ def assign_th_variable(variable, value):
27162523
variable.data = torch.as_tensor(value)
27172524

27182525

2719-
def _save_weights_to_hdf5_group(f, layers):
2526+
def _save_weights_to_hdf5_group(f, save_list):
27202527
"""
27212528
Save layer/model weights into hdf5 group recursively.
27222529
@@ -2728,33 +2535,37 @@ def _save_weights_to_hdf5_group(f, layers):
27282535
A list of layers to save weights.
27292536
27302537
"""
2731-
f.attrs['layer_names'] = [layer.name.encode('utf8') for layer in layers]
27322538

2733-
for layer in layers:
2734-
g = f.create_group(layer.name)
2735-
if isinstance(layer, tlx.model.Model):
2736-
_save_weights_to_hdf5_group(g, layer.all_layers)
2737-
elif isinstance(layer, tlx.nn.ModelLayer):
2738-
_save_weights_to_hdf5_group(g, layer.model.all_layers)
2739-
elif isinstance(layer, tlx.nn.ModuleList):
2740-
_save_weights_to_hdf5_group(g, layer.layers)
2741-
elif isinstance(layer, tlx.nn.Layer):
2742-
if layer.all_weights is not None:
2743-
weight_values = tf_variables_to_numpy(layer.all_weights)
2744-
weight_names = [w.name.encode('utf8') for w in layer.all_weights]
2745-
else:
2746-
weight_values = []
2747-
weight_names = []
2748-
g.attrs['weight_names'] = weight_names
2749-
for name, val in zip(weight_names, weight_values):
2750-
val_dataset = g.create_dataset(name, val.shape, dtype=val.dtype)
2751-
if not val.shape:
2752-
# scalar
2753-
val_dataset[()] = val
2754-
else:
2755-
val_dataset[:] = val
2756-
else:
2757-
raise Exception("Only layer or model can be saved into hdf5.")
2539+
if save_list is None:
2540+
save_list = []
2541+
if tlx.BACKEND != 'torch':
2542+
save_list_names = [tensor.name for tensor in save_list]
2543+
2544+
if tlx.BACKEND == 'tensorflow':
2545+
save_list_var = tf_variables_to_numpy(save_list)
2546+
elif tlx.BACKEND == 'mindspore':
2547+
save_list_var = ms_variables_to_numpy(save_list)
2548+
elif tlx.BACKEND == 'paddle':
2549+
save_list_var = pd_variables_to_numpy(save_list)
2550+
elif tlx.BACKEND == 'torch':
2551+
save_list_names = []
2552+
save_list_var = []
2553+
for named, values in save_list:
2554+
save_list_names.append(named)
2555+
save_list_var.append(values.cpu().detach().numpy())
2556+
else:
2557+
raise NotImplementedError('Not implemented')
2558+
save_var_dict = {save_list_names[idx]: val for idx, val in enumerate(save_list_var)}
2559+
2560+
g = f.create_group('model_parameters')
2561+
for k in save_var_dict.keys():
2562+
val_dataset = g.create_dataset('.'.join(k.split('/')), data=save_var_dict[k])
2563+
2564+
save_list_var = None
2565+
save_var_dict = None
2566+
del save_list_var
2567+
del save_var_dict
2568+
logging.info("[*] Model saved in hdf5.")
27582569

27592570

27602571
def _load_weights_from_hdf5_group_in_order(f, layers):
@@ -2838,7 +2649,7 @@ def _load_weights_from_hdf5_group(f, layers, skip=False):
28382649
raise Exception("Only layer or model can be saved into hdf5.")
28392650

28402651

2841-
def save_weights_to_hdf5(filepath, network):
2652+
def save_weights_to_hdf5(save_list, filepath):
28422653
"""Input filepath and save weights in hdf5 format.
28432654
28442655
Parameters
@@ -2855,12 +2666,12 @@ def save_weights_to_hdf5(filepath, network):
28552666
logging.info("[*] Saving TL weights into %s" % filepath)
28562667

28572668
with h5py.File(filepath, 'w') as f:
2858-
_save_weights_to_hdf5_group(f, network.all_layers)
2669+
_save_weights_to_hdf5_group(f, save_list)
28592670

28602671
logging.info("[*] Saved")
28612672

28622673

2863-
def load_hdf5_to_weights_in_order(filepath, network):
2674+
def load_hdf5_to_weights_in_order(filepath, network, skip=False):
28642675
"""Load weights sequentially from a given file of hdf5 format
28652676
28662677
Parameters
@@ -2879,23 +2690,39 @@ def load_hdf5_to_weights_in_order(filepath, network):
28792690
28802691
"""
28812692
f = h5py.File(filepath, 'r')
2882-
try:
2883-
layer_names = [n.decode('utf8') for n in f.attrs["layer_names"]]
2884-
except Exception:
2885-
raise NameError(
2886-
"The loaded hdf5 file needs to have 'layer_names' as attributes. "
2887-
"Please check whether this hdf5 file is saved from TL."
2888-
)
2889-
2890-
if len(network.all_layers) != len(layer_names):
2891-
logging.warning(
2892-
"Number of weights mismatch."
2893-
"Trying to load a saved file with " + str(len(layer_names)) + " layers into a model with " +
2894-
str(len(network.all_layers)) + " layers."
2895-
)
2693+
weights = f['model_parameters']
2694+
print(weights.keys())
2695+
if len(weights.keys()) != len(set(weights.keys())):
2696+
raise Exception("Duplication in model npz_dict %s" % name)
28962697

2897-
_load_weights_from_hdf5_group_in_order(f, network.all_layers)
2698+
if tlx.BACKEND == 'torch':
2699+
net_weights_name = [n for n, v in network.named_parameters()]
2700+
torch_weights_dict = {n: v for n, v in network.named_parameters()}
2701+
else:
2702+
net_weights_name = [w.name for w in network.all_weights]
28982703

2704+
for key in weights.keys():
2705+
key_t = '/'.join(key.split('.'))
2706+
if key_t not in net_weights_name:
2707+
if skip:
2708+
logging.warning("Weights named '%s' not found in network. Skip it." % key)
2709+
else:
2710+
raise RuntimeError(
2711+
"Weights named '%s' not found in network. Hint: set argument skip=True "
2712+
"if you want to skip redundant or mismatch weights." % key
2713+
)
2714+
else:
2715+
if tlx.BACKEND == 'tensorflow':
2716+
assign_tf_variable(network.all_weights[net_weights_name.index(key_t)], weights[key])
2717+
elif tlx.BACKEND == 'mindspore':
2718+
assign_param = Tensor(weights[key], dtype=ms.float32)
2719+
assign_ms_variable(network.all_weights[net_weights_name.index(key_t)], assign_param)
2720+
elif tlx.BACKEND == 'paddle':
2721+
assign_pd_variable(network.all_weights[net_weights_name.index(key_t)], weights[key])
2722+
elif tlx.BACKEND == 'torch':
2723+
assign_th_variable(torch_weights_dict[key_t], weights[key])
2724+
else:
2725+
raise NotImplementedError('Not implemented')
28992726
f.close()
29002727
logging.info("[*] Load %s SUCCESS!" % filepath)
29012728

0 commit comments

Comments
 (0)