Skip to content

Commit a686af5

Browse files
committed
update mnist example for vis
1 parent 05fe3e3 commit a686af5

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

example/tutorial_mnist.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -151,14 +151,13 @@ def main_test_layers(model='relu'):
151151
print(" val acc: %f" % (val_acc/ n_batch))
152152
try:
153153
# You can visualize the weight of 1st hidden layer as follow.
154-
tl.visualize.W(network.all_params[0].eval(), second=10,
154+
tl.vis.W(network.all_params[0].eval(), second=10,
155155
saveable=True, shape=[28, 28],
156156
name='w1_'+str(epoch+1), fig_idx=2012)
157157
# You can also save the weight of 1st hidden layer to .npz file.
158158
# tl.files.save_npz([network.all_params[0]] , name='w1'+str(epoch+1)+'.npz')
159159
except:
160-
raise Exception("You should change visualize_W(), if you want \
161-
to save the feature images for different dataset")
160+
print("You should change vis.W(), if you want to save the feature images for different dataset")
162161

163162
print('Evaluation')
164163
test_loss, test_acc, n_batch = 0, 0, 0
@@ -390,11 +389,11 @@ def main_test_stacked_denoise_AE(model='relu'):
390389
print(" val acc: %f" % (val_acc/ n_batch))
391390
try:
392391
# visualize the 1st hidden layer during fine-tune
393-
tl.visualize.W(network.all_params[0].eval(), second=10,
392+
tl.vis.W(network.all_params[0].eval(), second=10,
394393
saveable=True, shape=[28, 28],
395394
name='w1_'+str(epoch+1), fig_idx=2012)
396395
except:
397-
raise Exception("# You should change visualize_W(), if you want to save the feature images for different dataset")
396+
print("You should change vis.W(), if you want to save the feature images for different dataset")
398397

399398
print('Evaluation')
400399
test_loss, test_acc, n_batch = 0, 0, 0
@@ -558,11 +557,11 @@ def main_test_cnn_layer():
558557
print(" val loss: %f" % (val_loss/ n_batch))
559558
print(" val acc: %f" % (val_acc/ n_batch))
560559
try:
561-
tl.visualize.CNN2d(network.all_params[0].eval(),
560+
tl.vis.CNN2d(network.all_params[0].eval(),
562561
second=10, saveable=True,
563562
name='cnn1_'+str(epoch+1), fig_idx=2012)
564563
except:
565-
raise Exception("# You should change visualize.CNN(), if you want to save the feature images for different dataset")
564+
print("You should change vis.CNN(), if you want to save the feature images for different dataset")
566565

567566
print('Evaluation')
568567
test_loss, test_acc, n_batch = 0, 0, 0

0 commit comments

Comments
 (0)