Skip to content

Commit 813762b

Browse files
zsdonghaoluomai
authored andcommitted
update docs save model; (#490)
1 parent a4a501c commit 813762b

File tree

1 file changed

+36
-28
lines changed

1 file changed

+36
-28
lines changed

docs/modules/files.rst

Lines changed: 36 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -120,9 +120,40 @@ Google Drive
120120
^^^^^^^^^^^^^^^^
121121
.. autofunction:: download_file_from_google_drive
122122

123+
124+
125+
126+
123127
Load and save network
124128
----------------------
125129

130+
TensorFlow provides ``.ckpt`` file format to save and restore the models, while
131+
we suggest to use standard python file format ``.npz`` to save models for the
132+
sake of cross-platform.
133+
134+
.. code-block:: python
135+
136+
## save model as .ckpt
137+
saver = tf.train.Saver()
138+
save_path = saver.save(sess, "model.ckpt")
139+
# restore model from .ckpt
140+
saver = tf.train.Saver()
141+
saver.restore(sess, "model.ckpt")
142+
143+
## save model as .npz
144+
tl.files.save_npz(network.all_params , name='model.npz')
145+
# restore model from .npz (method 1)
146+
load_params = tl.files.load_npz(name='model.npz')
147+
tl.files.assign_params(sess, load_params, network)
148+
# restore model from .npz (method 2)
149+
tl.files.load_and_assign_npz(sess=sess, name='model.npz', network=network)
150+
151+
## you can assign the pre-trained parameters as follow
152+
# 1st parameter
153+
tl.files.assign_params(sess, [load_params[0]], network)
154+
# the first three parameters
155+
tl.files.assign_params(sess, load_params[:3], network)
156+
126157
Save network into list (npz)
127158
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
128159
.. autofunction:: save_npz
@@ -159,37 +190,10 @@ Load network from ckpt
159190

160191

161192

193+
162194
Load and save variables
163195
------------------------
164196

165-
TensorFlow provides ``.ckpt`` file format to save and restore the models, while
166-
we suggest to use standard python file format ``.npz`` to save models for the
167-
sake of cross-platform.
168-
169-
.. code-block:: python
170-
171-
## save model as .ckpt
172-
saver = tf.train.Saver()
173-
save_path = saver.save(sess, "model.ckpt")
174-
# restore model from .ckpt
175-
saver = tf.train.Saver()
176-
saver.restore(sess, "model.ckpt")
177-
178-
## save model as .npz
179-
tl.files.save_npz(network.all_params , name='model.npz')
180-
# restore model from .npz (method 1)
181-
load_params = tl.files.load_npz(name='model.npz')
182-
tl.files.assign_params(sess, load_params, network)
183-
# restore model from .npz (method 2)
184-
tl.files.load_and_assign_npz(sess=sess, name='model.npz', network=network)
185-
186-
## you can assign the pre-trained parameters as follow
187-
# 1st parameter
188-
tl.files.assign_params(sess, [load_params[0]], network)
189-
# the first three parameters
190-
tl.files.assign_params(sess, load_params[:3], network)
191-
192-
193197
Save variables as .npy
194198
^^^^^^^^^^^^^^^^^^^^^^^^^
195199
.. autofunction:: save_any_to_npy
@@ -199,6 +203,8 @@ Load variables from .npy
199203
.. autofunction:: load_npy_to_any
200204

201205

206+
207+
202208
Folder/File functions
203209
------------------------
204210

@@ -238,6 +244,8 @@ Download or extract
238244
^^^^^^^^^^^^^^^^^^^^^^^^^
239245
.. autofunction:: maybe_download_and_extract
240246

247+
248+
241249
Sort
242250
-------
243251

0 commit comments

Comments
 (0)