|
9 | 9 |
|
10 | 10 | # matplotlib.use('Agg') |
11 | 11 |
|
12 | | -import matplotlib.pyplot as plt |
13 | 12 | import numpy as np |
14 | 13 | import os |
15 | 14 | from . import prepro |
@@ -114,6 +113,7 @@ def W(W=None, second=10, saveable=True, shape=[28,28], name='mnist', fig_idx=239 |
114 | 113 | -------- |
115 | 114 | >>> tl.visualize.W(network.all_params[0].eval(), second=10, saveable=True, name='weight_of_1st_layer', fig_idx=2012) |
116 | 115 | """ |
| 116 | + import matplotlib.pyplot as plt |
117 | 117 | if saveable is False: |
118 | 118 | plt.ion() |
119 | 119 | fig = plt.figure(fig_idx) # show all feature images |
@@ -177,6 +177,7 @@ def frame(I=None, second=5, saveable=True, name='frame', cmap=None, fig_idx=1283 |
177 | 177 | >>> observation = env.reset() |
178 | 178 | >>> tl.visualize.frame(observation) |
179 | 179 | """ |
| 180 | + import matplotlib.pyplot as plt |
180 | 181 | if saveable is False: |
181 | 182 | plt.ion() |
182 | 183 | fig = plt.figure(fig_idx) # show all feature images |
@@ -215,6 +216,7 @@ def CNN2d(CNN=None, second=10, saveable=True, name='cnn', fig_idx=3119362): |
215 | 216 | -------- |
216 | 217 | >>> tl.visualize.CNN2d(network.all_params[0].eval(), second=10, saveable=True, name='cnn1_mnist', fig_idx=2012) |
217 | 218 | """ |
| 219 | + import matplotlib.pyplot as plt |
218 | 220 | # print(CNN.shape) # (5, 5, 3, 64) |
219 | 221 | # exit() |
220 | 222 | n_mask = CNN.shape[3] |
@@ -280,6 +282,7 @@ def images2d(images=None, second=10, saveable=True, name='images', dtype=None, |
280 | 282 | >>> X_train, y_train, X_test, y_test = tl.files.load_cifar10_dataset(shape=(-1, 32, 32, 3), plotable=False) |
281 | 283 | >>> tl.visualize.images2d(X_train[0:100,:,:,:], second=10, saveable=False, name='cifar10', dtype=np.uint8, fig_idx=20212) |
282 | 284 | """ |
| 285 | + import matplotlib.pyplot as plt |
283 | 286 | # print(images.shape) # (50000, 32, 32, 3) |
284 | 287 | # exit() |
285 | 288 | if dtype: |
@@ -350,6 +353,7 @@ def tsne_embedding(embeddings, reverse_dictionary, plot_only=500, |
350 | 353 | >>> tl.visualize.tsne_embedding(final_embeddings, labels, reverse_dictionary, |
351 | 354 | ... plot_only=500, second=5, saveable=False, name='tsne') |
352 | 355 | """ |
| 356 | + import matplotlib.pyplot as plt |
353 | 357 | def plot_with_labels(low_dim_embs, labels, figsize=(18, 18), second=5, |
354 | 358 | saveable=True, name='tsne', fig_idx=9862): |
355 | 359 | assert low_dim_embs.shape[0] >= len(labels), "More labels than embeddings" |
|
0 commit comments