Skip to content

Commit 8b2c382

Browse files
committed
move pyplot into functions
1 parent 0e6bf29 commit 8b2c382

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

tensorlayer/visualize.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
# matplotlib.use('Agg')
1111

12-
import matplotlib.pyplot as plt
1312
import numpy as np
1413
import os
1514
from . import prepro
@@ -114,6 +113,7 @@ def W(W=None, second=10, saveable=True, shape=[28,28], name='mnist', fig_idx=239
114113
--------
115114
>>> tl.visualize.W(network.all_params[0].eval(), second=10, saveable=True, name='weight_of_1st_layer', fig_idx=2012)
116115
"""
116+
import matplotlib.pyplot as plt
117117
if saveable is False:
118118
plt.ion()
119119
fig = plt.figure(fig_idx) # show all feature images
@@ -177,6 +177,7 @@ def frame(I=None, second=5, saveable=True, name='frame', cmap=None, fig_idx=1283
177177
>>> observation = env.reset()
178178
>>> tl.visualize.frame(observation)
179179
"""
180+
import matplotlib.pyplot as plt
180181
if saveable is False:
181182
plt.ion()
182183
fig = plt.figure(fig_idx) # show all feature images
@@ -215,6 +216,7 @@ def CNN2d(CNN=None, second=10, saveable=True, name='cnn', fig_idx=3119362):
215216
--------
216217
>>> tl.visualize.CNN2d(network.all_params[0].eval(), second=10, saveable=True, name='cnn1_mnist', fig_idx=2012)
217218
"""
219+
import matplotlib.pyplot as plt
218220
# print(CNN.shape) # (5, 5, 3, 64)
219221
# exit()
220222
n_mask = CNN.shape[3]
@@ -280,6 +282,7 @@ def images2d(images=None, second=10, saveable=True, name='images', dtype=None,
280282
>>> X_train, y_train, X_test, y_test = tl.files.load_cifar10_dataset(shape=(-1, 32, 32, 3), plotable=False)
281283
>>> tl.visualize.images2d(X_train[0:100,:,:,:], second=10, saveable=False, name='cifar10', dtype=np.uint8, fig_idx=20212)
282284
"""
285+
import matplotlib.pyplot as plt
283286
# print(images.shape) # (50000, 32, 32, 3)
284287
# exit()
285288
if dtype:
@@ -350,6 +353,7 @@ def tsne_embedding(embeddings, reverse_dictionary, plot_only=500,
350353
>>> tl.visualize.tsne_embedding(final_embeddings, labels, reverse_dictionary,
351354
... plot_only=500, second=5, saveable=False, name='tsne')
352355
"""
356+
import matplotlib.pyplot as plt
353357
def plot_with_labels(low_dim_embs, labels, figsize=(18, 18), second=5,
354358
saveable=True, name='tsne', fig_idx=9862):
355359
assert low_dim_embs.shape[0] >= len(labels), "More labels than embeddings"

0 commit comments

Comments
 (0)