Skip to content

Commit 299046e

Browse files
warshallrhozsdonghao
authored andcommitted
change network config format, and update layer act (#980)
* (non)trainable weights, layer all_layers * weights -> all_weights * weights -> all_weights, trainable weights, nontrainable_weights * fix bugs, yapf * fix bugs * fix bugs * fix bugs * alpha version, update network config * fix bug * add files * Update CHANGELOG.md * fix bugs * yapf * update act in base layer and related layers * fix bugs * fix bug * fix bugs * parse float in lrelu * yapf
1 parent 96b520e commit 299046e

31 files changed

+389
-190
lines changed

CHANGELOG.md

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,21 @@ To release a new version, please update the changelog as followed:
6767

6868
<!-- YOU CAN EDIT FROM HERE -->
6969

70+
## [2.1.0] - 2019-5-25
71+
72+
### Changed
73+
- change the format of network config, change related code and files; change layer act (PR #980)
74+
75+
### Added
76+
77+
### Dependencies Update
78+
79+
### Fixed
80+
81+
### Contributors
82+
- @warshallrho: #PR980
83+
84+
7085
## [2.0.1] - 2019-5-17
7186

7287

@@ -460,4 +475,4 @@ To many PR for this update, please check [here](https://github.com/tensorlayer/t
460475
[1.10.0]: https://github.com/tensorlayer/tensorlayer/compare/1.9.1...1.10.0
461476
[1.9.1]: https://github.com/tensorlayer/tensorlayer/compare/1.9.0...1.9.1
462477
[1.9.0]: https://github.com/tensorlayer/tensorlayer/compare/1.8.5...1.9.0
463-
[1.8.5]: https://github.com/tensorlayer/tensorlayer/compare/1.8.4...1.8.5
478+
[1.8.5]: https://github.com/tensorlayer/tensorlayer/compare/1.8.4...1.8.5

docs/user/get_start_model.rst

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -211,10 +211,9 @@ z = f(x*W+b)
211211
212212
class Dense(Layer):
213213
def __init__(self, n_units, act=None, in_channels=None, name=None):
214-
super(Dense, self).__init__(name)
214+
super(Dense, self).__init__(name, act=act)
215215
216216
self.n_units = n_units
217-
self.act = act
218217
self.in_channels = in_channels
219218
220219
# for dynamic model, it needs the input shape to get the shape of W

examples/text_generation/tutorial_generate_text.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -329,4 +329,4 @@ def main_lstm_generate_text():
329329
# main_restore_embedding_layer()
330330

331331
# How to generate text from a given context
332-
main_lstm_generate_text()
332+
main_lstm_generate_text()

tensorlayer/db.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import tensorflow as tf
1414

1515
from tensorlayer import logging
16-
from tensorlayer.files import net2static_graph, static_graph2net, assign_weights
16+
from tensorlayer.files import static_graph2net, assign_weights
1717
from tensorlayer.files import save_weights_to_hdf5, load_hdf5_to_weights
1818
from tensorlayer.files import del_folder, exists_or_mkdir
1919

@@ -153,7 +153,7 @@ def save_model(self, network=None, model_name='model', **kwargs):
153153
s = time.time()
154154

155155
# kwargs.update({'architecture': network.all_graphs, 'time': datetime.utcnow()})
156-
kwargs.update({'architecture': net2static_graph(network), 'time': datetime.utcnow()})
156+
kwargs.update({'architecture': network.config, 'time': datetime.utcnow()})
157157

158158
try:
159159
params_id = self.model_fs.put(self._serialization(params))

tensorlayer/files/utils.py

Lines changed: 120 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from tensorflow.python.util.tf_export import keras_export
3232
from tensorflow.python.util import serialization
3333
import json
34+
import datetime
3435

3536
# from six.moves import zip
3637

@@ -73,7 +74,7 @@
7374
'load_hdf5_to_weights',
7475
'save_hdf5_graph',
7576
'load_hdf5_graph',
76-
'net2static_graph',
77+
# 'net2static_graph',
7778
'static_graph2net',
7879
# 'save_pkl_graph',
7980
# 'load_pkl_graph',
@@ -92,29 +93,29 @@ def str2func(s):
9293
return expr
9394

9495

95-
def net2static_graph(network):
96-
saved_file = dict()
97-
if network._NameNone is True:
98-
saved_file.update({"name": None})
99-
else:
100-
saved_file.update({"name": network.name})
101-
if not isinstance(network.inputs, list):
102-
saved_file.update({"inputs": network.inputs._info[0].name})
103-
else:
104-
saved_inputs = []
105-
for saved_input in network.inputs:
106-
saved_inputs.append(saved_input._info[0].name)
107-
saved_file.update({"inputs": saved_inputs})
108-
if not isinstance(network.outputs, list):
109-
saved_file.update({"outputs": network.outputs._info[0].name})
110-
else:
111-
saved_outputs = []
112-
for saved_output in network.outputs:
113-
saved_outputs.append(saved_output._info[0].name)
114-
saved_file.update({"outputs": saved_outputs})
115-
saved_file.update({"config": network.config})
116-
117-
return saved_file
96+
# def net2static_graph(network):
97+
# saved_file = dict()
98+
# # if network._NameNone is True:
99+
# # saved_file.update({"name": None})
100+
# # else:
101+
# # saved_file.update({"name": network.name})
102+
# # if not isinstance(network.inputs, list):
103+
# # saved_file.update({"inputs": network.inputs._info[0].name})
104+
# # else:
105+
# # saved_inputs = []
106+
# # for saved_input in network.inputs:
107+
# # saved_inputs.append(saved_input._info[0].name)
108+
# # saved_file.update({"inputs": saved_inputs})
109+
# # if not isinstance(network.outputs, list):
110+
# # saved_file.update({"outputs": network.outputs._info[0].name})
111+
# # else:
112+
# # saved_outputs = []
113+
# # for saved_output in network.outputs:
114+
# # saved_outputs.append(saved_output._info[0].name)
115+
# # saved_file.update({"outputs": saved_outputs})
116+
# saved_file.update({"config": network.config})
117+
#
118+
# return saved_file
118119

119120

120121
@keras_export('keras.models.save_model')
@@ -149,7 +150,7 @@ def load_keras_model(model_config):
149150
return model
150151

151152

152-
def save_hdf5_graph(network, filepath='model.hdf5', save_weights=False):
153+
def save_hdf5_graph(network, filepath='model.hdf5', save_weights=False, customized_data=None):
153154
"""Save the architecture of TL model into a hdf5 file. Support saving model weights.
154155
155156
Parameters
@@ -160,6 +161,8 @@ def save_hdf5_graph(network, filepath='model.hdf5', save_weights=False):
160161
The name of model file.
161162
save_weights : bool
162163
Whether to save model weights.
164+
customized_data : dict
165+
The user customized meta data.
163166
164167
Examples
165168
--------
@@ -177,11 +180,22 @@ def save_hdf5_graph(network, filepath='model.hdf5', save_weights=False):
177180

178181
logging.info("[*] Saving TL model into {}, saving weights={}".format(filepath, save_weights))
179182

180-
saved_file = net2static_graph(network)
181-
saved_file_str = str(saved_file)
183+
model_config = network.config # net2static_graph(network)
184+
model_config_str = str(model_config)
185+
customized_data_str = str(customized_data)
186+
version_info = {
187+
"tensorlayer_version": tl.__version__,
188+
"backend": "tensorflow",
189+
"backend_version": tf.__version__,
190+
"training_device": "gpu",
191+
"save_date": datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc).isoformat()
192+
}
193+
version_info_str = str(version_info)
182194

183195
with h5py.File(filepath, 'w') as f:
184-
f.attrs["model_structure"] = saved_file_str.encode('utf8')
196+
f.attrs["model_config"] = model_config_str.encode('utf8')
197+
f.attrs["customized_data"] = customized_data_str.encode('utf8')
198+
f.attrs["version_info"] = version_info_str.encode('utf8')
185199
if save_weights:
186200
_save_weights_to_hdf5_group(f, network.all_layers)
187201
f.flush()
@@ -237,29 +251,15 @@ def eval_layer(layer_kwargs):
237251
raise RuntimeError("Unknown layer type.")
238252

239253

240-
def static_graph2net(saved_file):
254+
def static_graph2net(model_config):
241255
layer_dict = {}
242-
model_name = saved_file['name']
243-
inputs_tensors = saved_file['inputs']
244-
outputs_tensors = saved_file['outputs']
245-
all_args = saved_file['config']
246-
tf_version = saved_file['config'].pop(0)['tf_version']
247-
tl_version = saved_file['config'].pop(0)['tl_version']
248-
if tf_version != tf.__version__:
249-
logging.warning(
250-
"Saved model uses tensorflow version {}, but now you are using tensorflow version {}".format(
251-
tf_version, tf.__version__
252-
)
253-
)
254-
if tl_version != tl.__version__:
255-
logging.warning(
256-
"Saved model uses tensorlayer version {}, but now you are using tensorlayer version {}".format(
257-
tl_version, tl.__version__
258-
)
259-
)
256+
model_name = model_config["name"]
257+
inputs_tensors = model_config["inputs"]
258+
outputs_tensors = model_config["outputs"]
259+
all_args = model_config["model_architecture"]
260260
for idx, layer_kwargs in enumerate(all_args):
261-
layer_class = layer_kwargs['class'] # class of current layer
262-
prev_layers = layer_kwargs.pop('prev_layer') # name of previous layers
261+
layer_class = layer_kwargs["class"] # class of current layer
262+
prev_layers = layer_kwargs.pop("prev_layer") # name of previous layers
263263
net = eval_layer(layer_kwargs)
264264
if layer_class in tl.layers.inputs.__all__:
265265
net = net._nodes[0].out_tensors[0]
@@ -312,11 +312,30 @@ def load_hdf5_graph(filepath='model.hdf5', load_weights=False):
312312
- see ``tl.files.save_hdf5_graph``
313313
"""
314314
logging.info("[*] Loading TL model from {}, loading weights={}".format(filepath, load_weights))
315+
315316
f = h5py.File(filepath, 'r')
316-
saved_file_str = f.attrs["model_structure"].decode('utf8')
317-
saved_file = eval(saved_file_str)
318317

319-
M = static_graph2net(saved_file)
318+
version_info_str = f.attrs["version_info"].decode('utf8')
319+
version_info = eval(version_info_str)
320+
backend_version = version_info["backend_version"]
321+
tensorlayer_version = version_info["tensorlayer_version"]
322+
if backend_version != tf.__version__:
323+
logging.warning(
324+
"Saved model uses tensorflow version {}, but now you are using tensorflow version {}".format(
325+
backend_version, tf.__version__
326+
)
327+
)
328+
if tensorlayer_version != tl.__version__:
329+
logging.warning(
330+
"Saved model uses tensorlayer version {}, but now you are using tensorlayer version {}".format(
331+
tensorlayer_version, tl.__version__
332+
)
333+
)
334+
335+
model_config_str = f.attrs["model_config"].decode('utf8')
336+
model_config = eval(model_config_str)
337+
338+
M = static_graph2net(model_config)
320339
if load_weights:
321340
if not ('layer_names' in f.attrs.keys()):
322341
raise RuntimeError("Saved model does not contain weights.")
@@ -329,55 +348,55 @@ def load_hdf5_graph(filepath='model.hdf5', load_weights=False):
329348
return M
330349

331350

332-
def load_pkl_graph(name='model.pkl'):
333-
"""Restore TL model archtecture from a a pickle file. No parameters be restored.
334-
335-
Parameters
336-
-----------
337-
name : str
338-
The name of graph file.
339-
340-
Returns
341-
--------
342-
network : TensorLayer Model.
343-
344-
Examples
345-
--------
346-
>>> # It is better to use load_hdf5_graph
347-
"""
348-
logging.info("[*] Loading TL graph from {}".format(name))
349-
with open(name, 'rb') as file:
350-
saved_file = pickle.load(file)
351-
352-
M = static_graph2net(saved_file)
353-
354-
return M
355-
356-
357-
def save_pkl_graph(network, name='model.pkl'):
358-
"""Save the architecture of TL model into a pickle file. No parameters be saved.
359-
360-
Parameters
361-
-----------
362-
network : TensorLayer layer
363-
The network to save.
364-
name : str
365-
The name of graph file.
366-
367-
Example
368-
--------
369-
>>> # It is better to use save_hdf5_graph
370-
"""
371-
if network.outputs is None:
372-
raise AssertionError("save_graph not support dynamic mode yet")
373-
374-
logging.info("[*] Saving TL graph into {}".format(name))
375-
376-
saved_file = net2static_graph(network)
377-
378-
with open(name, 'wb') as file:
379-
pickle.dump(saved_file, file, protocol=pickle.HIGHEST_PROTOCOL)
380-
logging.info("[*] Saved graph")
351+
# def load_pkl_graph(name='model.pkl'):
352+
# """Restore TL model archtecture from a a pickle file. No parameters be restored.
353+
#
354+
# Parameters
355+
# -----------
356+
# name : str
357+
# The name of graph file.
358+
#
359+
# Returns
360+
# --------
361+
# network : TensorLayer Model.
362+
#
363+
# Examples
364+
# --------
365+
# >>> # It is better to use load_hdf5_graph
366+
# """
367+
# logging.info("[*] Loading TL graph from {}".format(name))
368+
# with open(name, 'rb') as file:
369+
# saved_file = pickle.load(file)
370+
#
371+
# M = static_graph2net(saved_file)
372+
#
373+
# return M
374+
#
375+
#
376+
# def save_pkl_graph(network, name='model.pkl'):
377+
# """Save the architecture of TL model into a pickle file. No parameters be saved.
378+
#
379+
# Parameters
380+
# -----------
381+
# network : TensorLayer layer
382+
# The network to save.
383+
# name : str
384+
# The name of graph file.
385+
#
386+
# Example
387+
# --------
388+
# >>> # It is better to use save_hdf5_graph
389+
# """
390+
# if network.outputs is None:
391+
# raise AssertionError("save_graph not support dynamic mode yet")
392+
#
393+
# logging.info("[*] Saving TL graph into {}".format(name))
394+
#
395+
# saved_file = net2static_graph(network)
396+
#
397+
# with open(name, 'wb') as file:
398+
# pickle.dump(saved_file, file, protocol=pickle.HIGHEST_PROTOCOL)
399+
# logging.info("[*] Saved graph")
381400

382401

383402
# Load dataset functions

tensorlayer/layers/convolution/binary_conv.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,11 +75,10 @@ def __init__(
7575
in_channels=None,
7676
name=None # 'binary_cnn2d',
7777
):
78-
super().__init__(name)
78+
super().__init__(name, act=act)
7979
self.n_filter = n_filter
8080
self.filter_size = filter_size
8181
self.strides = self._strides = strides
82-
self.act = act
8382
self.padding = padding
8483
self.use_gemm = use_gemm
8584
self.data_format = data_format

tensorlayer/layers/convolution/deformable_conv.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,12 +83,11 @@ def __init__(
8383
in_channels=None,
8484
name=None # 'deformable_conv_2d',
8585
):
86-
super().__init__(name)
86+
super().__init__(name, act=act)
8787

8888
self.offset_layer = offset_layer
8989
self.n_filter = n_filter
9090
self.filter_size = filter_size
91-
self.act = act
9291
self.padding = padding
9392
self.W_init = W_init
9493
self.b_init = b_init

tensorlayer/layers/convolution/depthwise_conv.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,9 @@ def __init__(
8282
in_channels=None,
8383
name=None # 'depthwise_conv2d'
8484
):
85-
super().__init__(name)
85+
super().__init__(name, act=act)
8686
self.filter_size = filter_size
8787
self.strides = self._strides = strides
88-
self.act = act
8988
self.padding = padding
9089
self.dilation_rate = self._dilation_rate = dilation_rate
9190
self.data_format = data_format

0 commit comments

Comments
 (0)