Skip to content

Commit 147db84

Browse files
committed
release vis.real_images
1 parent bcdab79 commit 147db84

File tree

3 files changed

+55
-21
lines changed

3 files changed

+55
-21
lines changed

docs/modules/visualize.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ to visualize the model, activations etc. Here we provide more functions for data
99
.. autosummary::
1010

1111
read_image
12+
read_images
1213
save_image
1314
save_images
1415
W
@@ -25,6 +26,10 @@ Read one image
2526
^^^^^^^^^^^^^^^^^
2627
.. autofunction:: read_image
2728

29+
Read multiple images
30+
^^^^^^^^^^^^^^^^^^^^^^^^^^
31+
.. autofunction:: read_images
32+
2833
Save one image
2934
^^^^^^^^^^^^^^^^^^^^^^^^^^
3035
.. autofunction:: save_image

tensorlayer/files.py

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@ def load_mnist_dataset(shape=(-1,784), path="data/mnist/"):
2828
Parameters
2929
----------
3030
shape : tuple
31-
The shape of digit images, defaults to (-1,784)
31+
The shape of digit images, defaults is (-1,784)
3232
path : string
33-
Path to download data to, defaults to data/mnist/
33+
The path that the data is downloaded to, defaults is data/mnist/
3434
3535
Examples
3636
--------
@@ -102,7 +102,7 @@ def load_cifar10_dataset(shape=(-1, 32, 32, 3), path='data/cifar10/', plotable=F
102102
second : int
103103
If ``plotable`` is True, ``second`` is the display time.
104104
path : string
105-
Path to download data to, defaults to data/cifar10/
105+
The path that the data is downloaded to, defaults is data/cifar10/
106106
107107
Examples
108108
--------
@@ -249,7 +249,7 @@ def load_ptb_dataset(path='data/ptb/'):
249249
Parameters
250250
----------
251251
path : : string
252-
Path to download data to, defaults to data/ptb/
252+
The path that the data is downloaded to, defaults is data/ptb/
253253
254254
Returns
255255
--------
@@ -302,7 +302,7 @@ def load_matt_mahoney_text8_dataset(path='data/mm_test8/'):
302302
Parameters
303303
----------
304304
path : : string
305-
Path to download data to, defaults to data/mm_test8/
305+
The path that the data is downloaded to, defaults is data/mm_test8/
306306
307307
Returns
308308
--------
@@ -336,7 +336,7 @@ def load_imdb_dataset(path='data/imdb/', nb_words=None, skip_top=0,
336336
Parameters
337337
----------
338338
path : : string
339-
Path to download data to, defaults to data/imdb/
339+
The path that the data is downloaded to, defaults is data/imdb/
340340
341341
Examples
342342
--------
@@ -419,7 +419,7 @@ def load_nietzsche_dataset(path='data/nietzsche/'):
419419
Parameters
420420
----------
421421
path : string
422-
Path to download data to, defaults to data/nietzsche/
422+
The path that the data is downloaded to, defaults is data/nietzsche/
423423
424424
Examples
425425
--------
@@ -447,7 +447,7 @@ def load_wmt_en_fr_dataset(path='data/wmt_en_fr/'):
447447
Parameters
448448
----------
449449
path : string
450-
Path to download data to, defaults to data/wmt_en_fr/
450+
The path that the data is downloaded to, defaults is data/wmt_en_fr/
451451
452452
References
453453
----------
@@ -502,16 +502,20 @@ def get_wmt_enfr_dev_set(path):
502502

503503
return train_path, dev_path
504504

505-
def load_flickr25k_dataset(tag='sky', path="data/flickr25k"):
505+
def load_flickr25k_dataset(tag='sky', path="data/flickr25k", n_threads=50, printable=False):
506506
"""Returns a list of images by a given tag from Flick25k dataset,
507507
it will download Flickr25k from `the official website <http://press.liacs.nl/mirflickr/mirdownload.html>`_
508508
at the first time you use it.
509509
510510
Parameters
511511
------------
512-
tag : string like 'dog', 'red' see `Flickr Search <https://www.flickr.com/search/>`_.
512+
tag : string or None
513+
If you want to get images with tag, use string like 'dog', 'red', see `Flickr Search <https://www.flickr.com/search/>`_.
514+
If you want to get all images, set to ``None``.
513515
path : string
514-
Path to download data to, defaults to ``data/flickr25k/``
516+
The path that the data is downloaded to, defaults is ``data/flickr25k/``
517+
n_threads : int, number of thread to read image.
518+
printable : bool, print infomation when reading images, default is False.
515519
516520
Examples
517521
-----------
@@ -536,14 +540,18 @@ def load_flickr25k_dataset(tag='sky', path="data/flickr25k"):
536540
path_tags.sort(key=natural_keys)
537541
# print(path_tags[0:10])
538542
# 3. select images
539-
images = []
543+
if tag is None:
544+
print("[Flickr25k] reading all images")
545+
else:
546+
print("[Flickr25k] reading images with tag: {}".format(tag))
547+
images_list = []
540548
for idx in range(0, len(path_tags)):
541549
tags = read_file(folder_tags+'/'+path_tags[idx]).split('\n')
542550
# print(idx+1, tags)
543-
if tag in tags:
544-
images.append(visualize.read_image(path_imgs[idx], folder_imgs))
545-
# print(idx+1, tags)
546-
# exit()
551+
if tag is None or tag in tags:
552+
images_list.append(path_imgs[idx])
553+
554+
images = visualize.read_images(images_list, folder_imgs, n_threads=50, printable=False)
547555
return images
548556

549557
## Load and save network
@@ -903,11 +911,11 @@ def maybe_download_and_extract(filename, working_directory, url_source, extract=
903911
A folder path to search for the file in and dowload the file to
904912
url : string
905913
The URL to download the file from
906-
extract : bool, defaults to False
914+
extract : bool, defaults is False
907915
If True, tries to uncompress the dowloaded file is ".tar.gz/.tar.bz2" or ".zip" file
908916
expected_bytes : int/None
909917
If set tries to verify that the downloaded file is of the specified size, otherwise raises an Exception,
910-
defaults to None which corresponds to no check being performed
918+
defaults is None which corresponds to no check being performed
911919
912920
Returns
913921
----------

tensorlayer/visualize.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,20 +12,41 @@
1212
import matplotlib.pyplot as plt
1313
import numpy as np
1414
import os
15+
from . import prepro
1516

1617

1718
## Save images
1819
import scipy.misc
1920

20-
def read_image(image, image_path=''):
21+
def read_image(image, path=''):
2122
""" Read one image.
2223
2324
Parameters
2425
-----------
2526
images : string, file name.
26-
image_path : string, path.
27+
path : string, path.
2728
"""
28-
return scipy.misc.imread(os.path.join(image_path, image))
29+
return scipy.misc.imread(os.path.join(path, image))
30+
31+
def read_images(img_list, path='', n_threads=10, printable=True):
32+
""" Returns all images in list by given path and name of each image file.
33+
34+
Parameters
35+
-------------
36+
img_list : list of string, the image file names.
37+
path : string, image folder path.
38+
n_threads : int, number of thread to read image.
39+
printable : bool, print infomation when reading images, default is True.
40+
"""
41+
imgs = []
42+
for idx in range(0, len(img_list), n_threads):
43+
b_imgs_list = img_list[idx : idx + n_threads]
44+
b_imgs = prepro.threading_data(b_imgs_list, fn=read_image, path=path)
45+
# print(b_imgs.shape)
46+
imgs.extend(b_imgs)
47+
if printable:
48+
print('read %d from %s' % (len(imgs), path))
49+
return imgs
2950

3051
def save_image(image, image_path=''):
3152
"""Save one image.

0 commit comments

Comments
 (0)