Commit 1410081

update files APIs for windows
1 parent 5eac46c

File tree

1 file changed


tensorlayer/files.py

Lines changed: 32 additions & 22 deletions
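The whole diff applies one technique: replace hand-built POSIX paths (defaults like 'data/mnist/', concatenations like path+"/"+filename) with os.path.join, which assembles paths using the host OS's separator. A minimal illustration of why this matters on Windows (my sketch, not part of the commit):

    import os

    # String concatenation hard-codes '/', which native Windows tools
    # such as cmd.exe built-ins do not accept:
    print('data' + '/' + 'mnist')           # data/mnist on every OS

    # os.path.join uses os.sep, so the same call is correct everywhere:
    print(os.path.join('data', 'mnist'))    # data/mnist on POSIX, data\mnist on Windows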
@@ -21,7 +21,7 @@


 ## Load dataset functions
-def load_mnist_dataset(shape=(-1,784), path="data/mnist/"):
+def load_mnist_dataset(shape=(-1,784), path="data"):
     """Automatically download MNIST dataset
     and return the training, validation and test set with 50000, 10000 and 10000
     digit images respectively.
@@ -38,6 +38,7 @@ def load_mnist_dataset(shape=(-1,784), path="data/mnist/"):
     >>> X_train, y_train, X_val, y_val, X_test, y_test = tl.files.load_mnist_dataset(shape=(-1,784))
     >>> X_train, y_train, X_val, y_val, X_test, y_test = tl.files.load_mnist_dataset(shape=(-1, 28, 28, 1))
     """
+    path = os.path.join(path, 'mnist')
     # We first define functions for loading MNIST images and labels.
     # For convenience, they also download the requested files if needed.
     def load_mnist_images(path, filename):
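The docstring examples above still hold; only the download location moves from a per-function default to a shared root, with the per-dataset subfolder appended inside the function. The same pattern repeats for CIFAR-10, PTB, text8, and IMDB in the hunks below. A minimal usage sketch (assumes a working TensorLayer install):

    import tensorlayer as tl

    # With the new default path='data', files land in data/mnist
    # (data\mnist on Windows) relative to the working directory.
    X_train, y_train, X_val, y_val, X_test, y_test = tl.files.load_mnist_dataset(
        shape=(-1, 784), path='data')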
@@ -84,7 +85,7 @@ def load_mnist_labels(path, filename):
     y_test = np.asarray(y_test, dtype=np.int32)
     return X_train, y_train, X_val, y_val, X_test, y_test

-def load_cifar10_dataset(shape=(-1, 32, 32, 3), path='data/cifar10/', plotable=False, second=3):
+def load_cifar10_dataset(shape=(-1, 32, 32, 3), path='data', plotable=False, second=3):
     """The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with
     6000 images per class. There are 50000 training images and 10000 test images.

@@ -115,7 +116,7 @@ def load_cifar10_dataset(shape=(-1, 32, 32, 3), path='data/cifar10/', plotable=F
     - `Data download link <https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz>`_
     - `Code references <https://teratail.com/questions/28932>`_
     """
-
+    path = os.path.join(path, 'cifar10')
     print("Load or Download cifar10 > {}".format(path))

     #Helper function to unpickle the data
@@ -201,7 +202,7 @@ def unpickle(file):

     return X_train, y_train, X_test, y_test

-def load_ptb_dataset(path='data/ptb/'):
+def load_ptb_dataset(path='data'):
     """Penn TreeBank (PTB) dataset is used in many LANGUAGE MODELING papers,
     including "Empirical Evaluation and Combination of Advanced Language
     Modeling Techniques", "Recurrent Neural Network Regularization".
@@ -226,6 +227,7 @@ def load_ptb_dataset(path='data/ptb/'):
     - ``tensorflow.models.rnn.ptb import reader``
     - `Manual download <http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz>`_
     """
+    path = os.path.join(path, 'ptb')
     print("Load or Download Penn TreeBank (PTB) dataset > {}".format(path))

     #Maybe dowload and uncompress tar, or load exsisting files
@@ -252,7 +254,7 @@ def load_ptb_dataset(path='data/ptb/'):
     # exit()
     return train_data, valid_data, test_data, vocabulary

-def load_matt_mahoney_text8_dataset(path='data/mm_test8/'):
+def load_matt_mahoney_text8_dataset(path='data'):
     """Download a text file from Matt Mahoney's website
     if not present, and make sure it's the right size.
     Extract the first file enclosed in a zip file as a list of words.
@@ -274,7 +276,7 @@ def load_matt_mahoney_text8_dataset(path='data/mm_test8/'):
     >>> words = tl.files.load_matt_mahoney_text8_dataset()
     >>> print('Data size', len(words))
     """
-
+    path = os.path.join(path, 'mm_test8')
     print("Load or Download matt_mahoney_text8 Dataset> {}".format(path))

     filename = 'text8.zip'
@@ -287,7 +289,7 @@ def load_matt_mahoney_text8_dataset(path='data/mm_test8/'):
         word_list[idx] = word_list[idx].decode()
     return word_list

-def load_imdb_dataset(path='data/imdb/', nb_words=None, skip_top=0,
+def load_imdb_dataset(path='data', nb_words=None, skip_top=0,
                       maxlen=None, test_split=0.2, seed=113,
                       start_char=1, oov_char=2, index_from=3):
     """Load IMDB dataset
"""Load IMDB dataset
@@ -310,6 +312,7 @@ def load_imdb_dataset(path='data/imdb/', nb_words=None, skip_top=0,
310312
-----------
311313
- `Modified from keras. <https://github.com/fchollet/keras/blob/master/keras/datasets/imdb.py>`_
312314
"""
315+
path = os.path.join(path, 'imdb')
313316

314317
filename = "imdb.pkl"
315318
url = 'https://s3.amazonaws.com/text-datasets/'
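With all five loaders joining their own subfolder onto the shared default root, the expected on-disk layout under the new defaults is:

    data/
        mnist/
        cifar10/
        ptb/
        mm_test8/
        imdb/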
@@ -623,18 +626,18 @@ def load_cyclegan_dataset(filename='summer2winter_yosemite', path='data/cyclegan
     """
     url = 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/'

-    if folder_exists(path+"/"+filename) is False:
+    if folder_exists(os.path.join(path, filename)) is False:
         print("[*] {} is nonexistent in {}".format(filename, path))
         maybe_download_and_extract(filename+'.zip', path, url, extract=True)
-        del_file(path+'/'+filename+'.zip')
+        del_file(os.path.join(path, filename+'.zip'))

     def load_image_from_folder(path):
         path_imgs = load_file_list(path=path, regx='\\.jpg', printable=False)
         return visualize.read_images(path_imgs, path=path, n_threads=10, printable=False)
-    im_train_A = load_image_from_folder(path+"/"+filename+"/trainA")
-    im_train_B = load_image_from_folder(path+"/"+filename+"/trainB")
-    im_test_A = load_image_from_folder(path+"/"+filename+"/testA")
-    im_test_B = load_image_from_folder(path+"/"+filename+"/testB")
+    im_train_A = load_image_from_folder(os.path.join(path, filename, "trainA"))
+    im_train_B = load_image_from_folder(os.path.join(path, filename, "trainB"))
+    im_test_A = load_image_from_folder(os.path.join(path, filename, "testA"))
+    im_test_B = load_image_from_folder(os.path.join(path, filename, "testB"))

     def if_2d_to_3d(images): # [h, w] --> [h, w, 3]
         for i in range(len(images)):
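Note that os.path.join takes any number of components, so the three-part concatenations collapse into one call, and on POSIX the old and new expressions produce identical strings; the change only alters behavior on Windows. A quick equivalence check (illustrative):

    import os

    path, filename = 'data', 'summer2winter_yosemite'
    old = path + "/" + filename + "/trainA"
    new = os.path.join(path, filename, "trainA")
    # Equal on POSIX; on Windows `new` correctly uses backslashes instead.
    assert os.name != 'posix' or old == new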
@@ -819,15 +822,22 @@ def _recursive_parse_xml_to_dict(xml):
         raise Exception("Please set the dataset aug to either 2012 or 2007.")

     ##======== download dataset
-    if folder_exists(path+"/"+extracted_filename) is False:
+    from sys import platform as _platform
+    if folder_exists(os.path.join(path, extracted_filename)) is False:
         print("[VOC] {} is nonexistent in {}".format(extracted_filename, path))
         maybe_download_and_extract(tar_filename, path, url, extract=True)
-        del_file(path+'/'+tar_filename)
+        del_file(os.path.join(path, tar_filename))
         if dataset == "2012":
-            os.system("mv {}/VOCdevkit/VOC2012 {}/VOC2012".format(path, path))
+            if _platform == "win32":
+                os.system("mv {}\VOCdevkit\VOC2012 {}\VOC2012".format(path, path))
+            else:
+                os.system("mv {}/VOCdevkit/VOC2012 {}/VOC2012".format(path, path))
         elif dataset == "2007":
-            os.system("mv {}/VOCdevkit/VOC2007 {}/VOC2007".format(path, path))
-        del_folder(path+'/VOCdevkit')
+            if _platform == "win32":
+                os.system("mv {}\VOCdevkit\VOC2007 {}\VOC2007".format(path, path))
+            else:
+                os.system("mv {}/VOCdevkit/VOC2007 {}/VOC2007".format(path, path))
+        del_folder(os.path.join(path, 'VOCdevkit'))
     ##======== object classes(labels) NOTE: YOU CAN CUSTOMIZE THIS LIST
     classes = ["aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car",
                "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike",
@@ -848,31 +858,31 @@ def _recursive_parse_xml_to_dict(xml):
     imgs_file_list = load_file_list(path=folder_imgs, regx='\\.jpg', printable=False)
     print("[VOC] {} images found".format(len(imgs_file_list)))
     imgs_file_list.sort(key=lambda s : int(s.replace('.',' ').replace('_', '').split(' ')[-2])) # 2007_000027.jpg --> 2007000027
-    imgs_file_list = [folder_imgs+s for s in imgs_file_list]
+    imgs_file_list = [os.path.join(folder_imgs, s) for s in imgs_file_list]
     # print('IM',imgs_file_list[0::3333], imgs_file_list[-1])
     ##======== 2. semantic segmentation maps path list
     # folder_semseg = path+"/"+extracted_filename+"/SegmentationClass/"
     folder_semseg = os.path.join(path, extracted_filename, "SegmentationClass")
     imgs_semseg_file_list = load_file_list(path=folder_semseg, regx='\\.png', printable=False)
     print("[VOC] {} maps for semantic segmentation found".format(len(imgs_semseg_file_list)))
     imgs_semseg_file_list.sort(key=lambda s : int(s.replace('.',' ').replace('_', '').split(' ')[-2])) # 2007_000032.png --> 2007000032
-    imgs_semseg_file_list = [folder_semseg+s for s in imgs_semseg_file_list]
+    imgs_semseg_file_list = [os.path.join(folder_semseg, s) for s in imgs_semseg_file_list]
     # print('Semantic Seg IM',imgs_semseg_file_list[0::333], imgs_semseg_file_list[-1])
     ##======== 3. instance segmentation maps path list
     # folder_insseg = path+"/"+extracted_filename+"/SegmentationObject/"
     folder_insseg = os.path.join(path, extracted_filename, "SegmentationObject")
     imgs_insseg_file_list = load_file_list(path=folder_insseg, regx='\\.png', printable=False)
     print("[VOC] {} maps for instance segmentation found".format(len(imgs_semseg_file_list)))
     imgs_insseg_file_list.sort(key=lambda s : int(s.replace('.',' ').replace('_', '').split(' ')[-2])) # 2007_000032.png --> 2007000032
-    imgs_insseg_file_list = [folder_semseg+s for s in imgs_insseg_file_list]
+    imgs_insseg_file_list = [os.path.join(folder_insseg, s) for s in imgs_insseg_file_list]
     # print('Instance Seg IM',imgs_insseg_file_list[0::333], imgs_insseg_file_list[-1])
     ##======== 4. annotations for bounding box and object class
     # folder_ann = path+"/"+extracted_filename+"/Annotations/"
     folder_ann = os.path.join(path, extracted_filename, "Annotations")
     imgs_ann_file_list = load_file_list(path=folder_ann, regx='\\.xml', printable=False)
     print("[VOC] {} XML annotation files for bounding box and object class found".format(len(imgs_ann_file_list)))
     imgs_ann_file_list.sort(key=lambda s : int(s.replace('.',' ').replace('_', '').split(' ')[-2])) # 2007_000027.xml --> 2007000027
-    imgs_ann_file_list = [folder_ann+s for s in imgs_ann_file_list]
+    imgs_ann_file_list = [os.path.join(folder_ann, s) for s in imgs_ann_file_list]
     # print('ANN',imgs_ann_file_list[0::3333], imgs_ann_file_list[-1])
     ##======== parse XML annotations
     def convert(size, box):
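Beyond portability, this last hunk quietly fixes two latent bugs: the folder variables are built with os.path.join (see the commented-out trailing-slash predecessors) and so carry no trailing separator, meaning the old folder+s concatenation glued the directory name onto the filename; and the instance-segmentation list was prefixed with folder_semseg instead of folder_insseg (old line 867). A small demonstration of the separator bug (illustrative; 'VOC2012' stands in for extracted_filename):

    import os

    folder_semseg = os.path.join('data', 'VOC2012', 'SegmentationClass')
    print(folder_semseg + '2007_000032.png')
    # -> data/VOC2012/SegmentationClass2007_000032.png  (separator missing)
    print(os.path.join(folder_semseg, '2007_000032.png'))
    # -> data/VOC2012/SegmentationClass/2007_000032.png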
