Skip to content

Commit 29aaec5

Browse files
Merge branch 'master' into patch-1
2 parents 598c13a + 61ef08c commit 29aaec5

File tree

4 files changed

+34
-40
lines changed

4 files changed

+34
-40
lines changed

docs/modules/files.rst

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
API - Files
22
===================================
33

4+
A collections of helper functions to work with dataset.
5+
Load benchmark dataset, save and restore model, save and load variables.
46

57
.. automodule:: tensorlayer.files
68

@@ -50,6 +52,7 @@ API - Files
5052

5153
npz_to_W_pdf
5254

55+
5356
Load dataset functions
5457
------------------------
5558

@@ -109,7 +112,7 @@ VOC 2007/2012
109112
^^^^^^^^^^^^^^^^
110113
.. autofunction:: load_voc_dataset
111114

112-
MPII
115+
MPII
113116
^^^^^^^^^^^^^^^^
114117
.. autofunction:: load_mpii_pose_dataset
115118

@@ -159,6 +162,34 @@ Load network from ckpt
159162
Load and save variables
160163
------------------------
161164

165+
TensorFlow provides ``.ckpt`` file format to save and restore the models, while
166+
we suggest to use standard python file format ``.npz`` to save models for the
167+
sake of cross-platform.
168+
169+
.. code-block:: python
170+
171+
## save model as .ckpt
172+
saver = tf.train.Saver()
173+
save_path = saver.save(sess, "model.ckpt")
174+
# restore model from .ckpt
175+
saver = tf.train.Saver()
176+
saver.restore(sess, "model.ckpt")
177+
178+
## save model as .npz
179+
tl.files.save_npz(network.all_params , name='model.npz')
180+
# restore model from .npz (method 1)
181+
load_params = tl.files.load_npz(name='model.npz')
182+
tl.files.assign_params(sess, load_params, network)
183+
# restore model from .npz (method 2)
184+
tl.files.load_and_assign_npz(sess=sess, name='model.npz', network=network)
185+
186+
## you can assign the pre-trained parameters as follow
187+
# 1st parameter
188+
tl.files.assign_params(sess, [load_params[0]], network)
189+
# the first three parameters
190+
tl.files.assign_params(sess, load_params[:3], network)
191+
192+
162193
Save variables as .npy
163194
^^^^^^^^^^^^^^^^^^^^^^^^^
164195
.. autofunction:: save_any_to_npy

tensorlayer/db.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,5 @@
11
#! /usr/bin/python
22
# -*- coding: utf-8 -*-
3-
"""
4-
Experimental Database Management System.
5-
6-
Latest Version
7-
"""
83

94
import inspect
105
import pickle

tensorlayer/files.py

Lines changed: 1 addition & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,4 @@
11
# -*- coding: utf-8 -*-
2-
"""
3-
A collections of helper functions to work with dataset.
4-
5-
Load benchmark dataset, save and restore model, save and load variables.
6-
TensorFlow provides ``.ckpt`` file format to save and restore the models, while
7-
we suggest to use standard python file format ``.npz`` to save models for the
8-
sake of cross-platform.
9-
10-
.. code-block:: python
11-
12-
## save model as .ckpt
13-
saver = tf.train.Saver()
14-
save_path = saver.save(sess, "model.ckpt")
15-
# restore model from .ckpt
16-
saver = tf.train.Saver()
17-
saver.restore(sess, "model.ckpt")
18-
19-
## save model as .npz
20-
tl.files.save_npz(network.all_params , name='model.npz')
21-
# restore model from .npz (method 1)
22-
load_params = tl.files.load_npz(name='model.npz')
23-
tl.files.assign_params(sess, load_params, network)
24-
# restore model from .npz (method 2)
25-
tl.files.load_and_assign_npz(sess=sess, name='model.npz', network=network)
26-
27-
## you can assign the pre-trained parameters as follow
28-
# 1st parameter
29-
tl.files.assign_params(sess, [load_params[0]], network)
30-
# the first three parameters
31-
tl.files.assign_params(sess, load_params[:3], network)
32-
33-
"""
342

353
import gzip
364
import math
@@ -1345,7 +1313,7 @@ def load_mpii_pose_dataset(path='data', is_16_pos_only=False):
13451313
>>> import tensorlayer as tl
13461314
>>> img_train_list, ann_train_list, img_test_list, ann_test_list = tl.files.load_mpii_pose_dataset()
13471315
>>> image = tl.vis.read_image(img_train_list[0])
1348-
>>> tl.vis.draw_mpii_people_to_image(image, ann_train_list[0], 'image.png')
1316+
>>> tl.vis.draw_mpii_pose_to_image(image, ann_train_list[0], 'image.png')
13491317
>>> pprint.pprint(ann_train_list[0])
13501318
13511319
References

tensorlayer/visualize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ def draw_mpii_pose_to_image(image, poses, save_name='image.png'):
247247
>>> import tensorlayer as tl
248248
>>> img_train_list, ann_train_list, img_test_list, ann_test_list = tl.files.load_mpii_pose_dataset()
249249
>>> image = tl.vis.read_image(img_train_list[0])
250-
>>> tl.vis.draw_mpii_people_to_image(image, ann_train_list[0], 'image.png')
250+
>>> tl.vis.draw_mpii_pose_to_image(image, ann_train_list[0], 'image.png')
251251
>>> pprint.pprint(ann_train_list[0])
252252
253253
References

0 commit comments

Comments
 (0)