Skip to content

Commit 09970c7

Browse files
committed
[Files] load and assign model from npz
1 parent aee9291 commit 09970c7

File tree

2 files changed

+29
-0
lines changed

2 files changed

+29
-0
lines changed

docs/modules/files.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ sake of cross-platform.
4444
save_npz
4545
load_npz
4646
assign_params
47+
load_and_assign_npz
4748

4849
save_any_to_npy
4950
load_npy_to_any
@@ -103,6 +104,10 @@ Assign parameters to network
103104
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
104105
.. autofunction:: assign_params
105106

107+
Load and assign parameters to network
108+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
109+
.. autofunction:: load_and_assign_npz
110+
106111
Load and save variables
107112
------------------------
108113

tensorlayer/files.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -634,7 +634,31 @@ def assign_params(sess, params, network):
634634
ops.append(network.all_params[idx].assign(param))
635635
sess.run(ops)
636636

637+
def load_and_assign_npz(sess=None, name=None, network=None):
638+
"""Load model from npz and assign to a network.
637639
640+
Parameters
641+
-------------
642+
sess : TensorFlow Session
643+
name : string
644+
Model path.
645+
network : a :class:`Layer` class
646+
The network to be assigned
647+
648+
Examples
649+
---------
650+
>>> tl.files.load_and_assign_npz(sess=sess, name='net.npz', network=net)
651+
"""
652+
assert network is not None
653+
assert sess is not None
654+
if not os.path.exists(name):
655+
print("[!] Load {} failed!".format(name))
656+
return False
657+
else:
658+
params = tl.files.load_npz(name=name)
659+
tl.files.assign_params(sess, params, network)
660+
print("[*] Load {} SUCCESS!".format(name))
661+
return network
638662

639663
# Load and save variables
640664
def save_any_to_npy(save_dict={}, name='any.npy'):

0 commit comments

Comments
 (0)