2121
2222
2323## Load dataset functions
24- def load_mnist_dataset (shape = (- 1 ,784 ), path = "data/mnist/ " ):
24+ def load_mnist_dataset (shape = (- 1 ,784 ), path = "data" ):
2525 """Automatically download MNIST dataset
2626 and return the training, validation and test set with 50000, 10000 and 10000
2727 digit images respectively.
@@ -38,6 +38,7 @@ def load_mnist_dataset(shape=(-1,784), path="data/mnist/"):
3838 >>> X_train, y_train, X_val, y_val, X_test, y_test = tl.files.load_mnist_dataset(shape=(-1,784))
3939 >>> X_train, y_train, X_val, y_val, X_test, y_test = tl.files.load_mnist_dataset(shape=(-1, 28, 28, 1))
4040 """
41+ path = os .path .join (path , 'mnist' )
4142 # We first define functions for loading MNIST images and labels.
4243 # For convenience, they also download the requested files if needed.
4344 def load_mnist_images (path , filename ):
@@ -84,7 +85,7 @@ def load_mnist_labels(path, filename):
8485 y_test = np .asarray (y_test , dtype = np .int32 )
8586 return X_train , y_train , X_val , y_val , X_test , y_test
8687
87- def load_cifar10_dataset (shape = (- 1 , 32 , 32 , 3 ), path = 'data/cifar10/ ' , plotable = False , second = 3 ):
88+ def load_cifar10_dataset (shape = (- 1 , 32 , 32 , 3 ), path = 'data' , plotable = False , second = 3 ):
8889 """The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with
8990 6000 images per class. There are 50000 training images and 10000 test images.
9091
@@ -115,7 +116,7 @@ def load_cifar10_dataset(shape=(-1, 32, 32, 3), path='data/cifar10/', plotable=F
115116 - `Data download link <https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz>`_
116117 - `Code references <https://teratail.com/questions/28932>`_
117118 """
118-
119+ path = os . path . join ( path , 'cifar10' )
119120 print ("Load or Download cifar10 > {}" .format (path ))
120121
121122 #Helper function to unpickle the data
@@ -201,7 +202,7 @@ def unpickle(file):
201202
202203 return X_train , y_train , X_test , y_test
203204
204- def load_ptb_dataset (path = 'data/ptb/ ' ):
205+ def load_ptb_dataset (path = 'data' ):
205206 """Penn TreeBank (PTB) dataset is used in many LANGUAGE MODELING papers,
206207 including "Empirical Evaluation and Combination of Advanced Language
207208 Modeling Techniques", "Recurrent Neural Network Regularization".
@@ -226,6 +227,7 @@ def load_ptb_dataset(path='data/ptb/'):
226227 - ``tensorflow.models.rnn.ptb import reader``
227228 - `Manual download <http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz>`_
228229 """
230+ path = os .path .join (path , 'ptb' )
229231 print ("Load or Download Penn TreeBank (PTB) dataset > {}" .format (path ))
230232
231233 #Maybe dowload and uncompress tar, or load exsisting files
@@ -252,7 +254,7 @@ def load_ptb_dataset(path='data/ptb/'):
252254 # exit()
253255 return train_data , valid_data , test_data , vocabulary
254256
255- def load_matt_mahoney_text8_dataset (path = 'data/mm_test8/ ' ):
257+ def load_matt_mahoney_text8_dataset (path = 'data' ):
256258 """Download a text file from Matt Mahoney's website
257259 if not present, and make sure it's the right size.
258260 Extract the first file enclosed in a zip file as a list of words.
@@ -274,7 +276,7 @@ def load_matt_mahoney_text8_dataset(path='data/mm_test8/'):
274276 >>> words = tl.files.load_matt_mahoney_text8_dataset()
275277 >>> print('Data size', len(words))
276278 """
277-
279+ path = os . path . join ( path , 'mm_test8' )
278280 print ("Load or Download matt_mahoney_text8 Dataset> {}" .format (path ))
279281
280282 filename = 'text8.zip'
@@ -287,7 +289,7 @@ def load_matt_mahoney_text8_dataset(path='data/mm_test8/'):
287289 word_list [idx ] = word_list [idx ].decode ()
288290 return word_list
289291
290- def load_imdb_dataset (path = 'data/imdb/ ' , nb_words = None , skip_top = 0 ,
292+ def load_imdb_dataset (path = 'data' , nb_words = None , skip_top = 0 ,
291293 maxlen = None , test_split = 0.2 , seed = 113 ,
292294 start_char = 1 , oov_char = 2 , index_from = 3 ):
293295 """Load IMDB dataset
@@ -310,6 +312,7 @@ def load_imdb_dataset(path='data/imdb/', nb_words=None, skip_top=0,
310312 -----------
311313 - `Modified from keras. <https://github.com/fchollet/keras/blob/master/keras/datasets/imdb.py>`_
312314 """
315+ path = os .path .join (path , 'imdb' )
313316
314317 filename = "imdb.pkl"
315318 url = 'https://s3.amazonaws.com/text-datasets/'
@@ -623,18 +626,18 @@ def load_cyclegan_dataset(filename='summer2winter_yosemite', path='data/cyclegan
623626 """
624627 url = 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/'
625628
626- if folder_exists (path + "/" + filename ) is False :
629+ if folder_exists (os . path . join ( path , filename ) ) is False :
627630 print ("[*] {} is nonexistent in {}" .format (filename , path ))
628631 maybe_download_and_extract (filename + '.zip' , path , url , extract = True )
629- del_file (path + '/' + filename + '.zip' )
632+ del_file (os . path . join ( path , filename + '.zip' ) )
630633
631634 def load_image_from_folder (path ):
632635 path_imgs = load_file_list (path = path , regx = '\\ .jpg' , printable = False )
633636 return visualize .read_images (path_imgs , path = path , n_threads = 10 , printable = False )
634- im_train_A = load_image_from_folder (path + "/" + filename + "/ trainA" )
635- im_train_B = load_image_from_folder (path + "/" + filename + "/ trainB" )
636- im_test_A = load_image_from_folder (path + "/" + filename + "/ testA" )
637- im_test_B = load_image_from_folder (path + "/" + filename + "/ testB" )
637+ im_train_A = load_image_from_folder (os . path . join ( path , filename , " trainA") )
638+ im_train_B = load_image_from_folder (os . path . join ( path , filename , " trainB") )
639+ im_test_A = load_image_from_folder (os . path . join ( path , filename , " testA") )
640+ im_test_B = load_image_from_folder (os . path . join ( path , filename , " testB") )
638641
639642 def if_2d_to_3d (images ): # [h, w] --> [h, w, 3]
640643 for i in range (len (images )):
@@ -819,15 +822,22 @@ def _recursive_parse_xml_to_dict(xml):
819822 raise Exception ("Please set the dataset aug to either 2012 or 2007." )
820823
821824 ##======== download dataset
822- if folder_exists (path + "/" + extracted_filename ) is False :
825+ from sys import platform as _platform
826+ if folder_exists (os .path .join (path , extracted_filename )) is False :
823827 print ("[VOC] {} is nonexistent in {}" .format (extracted_filename , path ))
824828 maybe_download_and_extract (tar_filename , path , url , extract = True )
825- del_file (path + '/' + tar_filename )
829+ del_file (os . path . join ( path , tar_filename ) )
826830 if dataset == "2012" :
827- os .system ("mv {}/VOCdevkit/VOC2012 {}/VOC2012" .format (path , path ))
831+ if _platform == "win32" :
832+ os .system ("move {}\\VOCdevkit\\VOC2012 {}\\VOC2012" .format (path , path ))
833+ else :
834+ os .system ("mv {}/VOCdevkit/VOC2012 {}/VOC2012" .format (path , path ))
828835 elif dataset == "2007" :
829- os .system ("mv {}/VOCdevkit/VOC2007 {}/VOC2007" .format (path , path ))
830- del_folder (path + '/VOCdevkit' )
836+ if _platform == "win32" :
837+ os .system ("move {}\\VOCdevkit\\VOC2007 {}\\VOC2007" .format (path , path ))
838+ else :
839+ os .system ("mv {}/VOCdevkit/VOC2007 {}/VOC2007" .format (path , path ))
840+ del_folder (os .path .join (path , 'VOCdevkit' ))
831841 ##======== object classes(labels) NOTE: YOU CAN CUSTOMIZE THIS LIST
832842 classes = ["aeroplane" , "bicycle" , "bird" , "boat" , "bottle" , "bus" , "car" ,
833843 "cat" , "chair" , "cow" , "diningtable" , "dog" , "horse" , "motorbike" ,
@@ -848,31 +858,31 @@ def _recursive_parse_xml_to_dict(xml):
848858 imgs_file_list = load_file_list (path = folder_imgs , regx = '\\ .jpg' , printable = False )
849859 print ("[VOC] {} images found" .format (len (imgs_file_list )))
850860 imgs_file_list .sort (key = lambda s : int (s .replace ('.' ,' ' ).replace ('_' , '' ).split (' ' )[- 2 ])) # 2007_000027.jpg --> 2007000027
851- imgs_file_list = [folder_imgs + s for s in imgs_file_list ]
861+ imgs_file_list = [os . path . join ( folder_imgs , s ) for s in imgs_file_list ]
852862 # print('IM',imgs_file_list[0::3333], imgs_file_list[-1])
853863 ##======== 2. semantic segmentation maps path list
854864 # folder_semseg = path+"/"+extracted_filename+"/SegmentationClass/"
855865 folder_semseg = os .path .join (path , extracted_filename , "SegmentationClass" )
856866 imgs_semseg_file_list = load_file_list (path = folder_semseg , regx = '\\ .png' , printable = False )
857867 print ("[VOC] {} maps for semantic segmentation found" .format (len (imgs_semseg_file_list )))
858868 imgs_semseg_file_list .sort (key = lambda s : int (s .replace ('.' ,' ' ).replace ('_' , '' ).split (' ' )[- 2 ])) # 2007_000032.png --> 2007000032
859- imgs_semseg_file_list = [folder_semseg + s for s in imgs_semseg_file_list ]
869+ imgs_semseg_file_list = [os . path . join ( folder_semseg , s ) for s in imgs_semseg_file_list ]
860870 # print('Semantic Seg IM',imgs_semseg_file_list[0::333], imgs_semseg_file_list[-1])
861871 ##======== 3. instance segmentation maps path list
862872 # folder_insseg = path+"/"+extracted_filename+"/SegmentationObject/"
863873 folder_insseg = os .path .join (path , extracted_filename , "SegmentationObject" )
864874 imgs_insseg_file_list = load_file_list (path = folder_insseg , regx = '\\ .png' , printable = False )
865875 print ("[VOC] {} maps for instance segmentation found" .format (len (imgs_semseg_file_list )))
866876 imgs_insseg_file_list .sort (key = lambda s : int (s .replace ('.' ,' ' ).replace ('_' , '' ).split (' ' )[- 2 ])) # 2007_000032.png --> 2007000032
867- imgs_insseg_file_list = [folder_semseg + s for s in imgs_insseg_file_list ]
877+ imgs_insseg_file_list = [os . path . join ( folder_insseg , s ) for s in imgs_insseg_file_list ]
868878 # print('Instance Seg IM',imgs_insseg_file_list[0::333], imgs_insseg_file_list[-1])
869879 ##======== 4. annotations for bounding box and object class
870880 # folder_ann = path+"/"+extracted_filename+"/Annotations/"
871881 folder_ann = os .path .join (path , extracted_filename , "Annotations" )
872882 imgs_ann_file_list = load_file_list (path = folder_ann , regx = '\\ .xml' , printable = False )
873883 print ("[VOC] {} XML annotation files for bounding box and object class found" .format (len (imgs_ann_file_list )))
874884 imgs_ann_file_list .sort (key = lambda s : int (s .replace ('.' ,' ' ).replace ('_' , '' ).split (' ' )[- 2 ])) # 2007_000027.xml --> 2007000027
875- imgs_ann_file_list = [folder_ann + s for s in imgs_ann_file_list ]
885+ imgs_ann_file_list = [os . path . join ( folder_ann , s ) for s in imgs_ann_file_list ]
876886 # print('ANN',imgs_ann_file_list[0::3333], imgs_ann_file_list[-1])
877887 ##======== parse XML annotations
878888 def convert (size , box ):
0 commit comments