 import sys
 import tarfile
 import zipfile
+import time

 import numpy as np
 import tensorflow as tf
@@ -320,6 +321,106 @@ def unpickle(file):
     return X_train, y_train, X_test, y_test


+def load_cropped_svhn(path='data', include_extra=True):
+    """Load Cropped SVHN.
+
+    The Cropped Street View House Numbers (SVHN) dataset contains 32x32x3 RGB images.
+    Digit '1' has label 1 and '9' has label 9, while '0' has label 0 (the original
+    dataset uses 10 to represent '0'); see the `ufldl website <http://ufldl.stanford.edu/housenumbers/>`__.
+
+    Parameters
+    ----------
+    path : str
+        The path that the data is downloaded to.
+    include_extra : boolean
+        If True (default), add the extra images to the training set.
+
+    Returns
+    -------
+    X_train, y_train, X_test, y_test : tuple
+        The split training and test sets.
+
+    Examples
+    ---------
+    >>> X_train, y_train, X_test, y_test = tl.files.load_cropped_svhn(include_extra=False)
+    >>> tl.vis.save_images(X_train[0:100], [10, 10], 'svhn.png')
+
+    """
+
+    import scipy.io  # imported lazily so SciPy is only required by this loader
+
+    start_time = time.time()
+
+    path = os.path.join(path, 'cropped_svhn')
+    logging.info("Load or Download Cropped SVHN > {} | include extra images: {}".format(path, include_extra))
+    url = "http://ufldl.stanford.edu/housenumbers/"
+
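+    # Each split is converted once: the .mat file is downloaded, parsed, and
+    # cached as .npz, so later calls only need np.load.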
+    np_file = os.path.join(path, "train_32x32.npz")
+    if not file_exists(np_file):
+        filename = "train_32x32.mat"
+        filepath = maybe_download_and_extract(filename, path, url)
+        mat = scipy.io.loadmat(filepath)
+        X_train = mat['X'] / 255.0  # scale to [0, 1]
+        X_train = np.transpose(X_train, (3, 0, 1, 2))  # (H, W, C, N) -> (N, H, W, C)
+        y_train = np.squeeze(mat['y'], axis=1)
+        y_train[y_train == 10] = 0  # remap label 10 to digit '0'
+        np.savez(np_file, X=X_train, y=y_train)
+        del_file(filepath)  # the raw .mat is no longer needed
+    else:
+        v = np.load(np_file)
+        X_train = v['X']
+        y_train = v['y']
+    logging.info("    n_train: {}".format(len(y_train)))
+
+    np_file = os.path.join(path, "test_32x32.npz")
+    if not file_exists(np_file):
+        filename = "test_32x32.mat"
+        filepath = maybe_download_and_extract(filename, path, url)
+        mat = scipy.io.loadmat(filepath)
+        X_test = mat['X'] / 255.0
+        X_test = np.transpose(X_test, (3, 0, 1, 2))
+        y_test = np.squeeze(mat['y'], axis=1)
+        y_test[y_test == 10] = 0
+        np.savez(np_file, X=X_test, y=y_test)
+        del_file(filepath)
+    else:
+        v = np.load(np_file)
+        X_test = v['X']
+        y_test = v['y']
+    logging.info("    n_test: {}".format(len(y_test)))
+
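+    # The optional extra split (531,131 images) follows the same
+    # download/convert/cache pattern but takes much longer to prepare.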
+    if include_extra:
+        logging.info("    getting the 531,131 extra images, please wait ...")
+        np_file = os.path.join(path, "extra_32x32.npz")
+        if not file_exists(np_file):
+            logging.info("    the first load of the extra images takes a long time to convert the file format ...")
+            filename = "extra_32x32.mat"
+            filepath = maybe_download_and_extract(filename, path, url)
+            mat = scipy.io.loadmat(filepath)
+            X_extra = mat['X'] / 255.0
+            X_extra = np.transpose(X_extra, (3, 0, 1, 2))
+            y_extra = np.squeeze(mat['y'], axis=1)
+            y_extra[y_extra == 10] = 0
+            np.savez(np_file, X=X_extra, y=y_extra)
+            del_file(filepath)
+        else:
+            v = np.load(np_file)
+            X_extra = v['X']
+            y_extra = v['y']
+        logging.info("    adding n_extra {} to n_train {}".format(len(y_extra), len(y_train)))
+        t = time.time()
+        X_train = np.concatenate((X_train, X_extra), 0)
+        y_train = np.concatenate((y_train, y_extra), 0)
+        logging.info("    added n_extra {} to n_train {}, took {}s".format(len(y_extra), len(y_train), time.time() - t))
+    else:
+        logging.info("    no extra images are included")
+    logging.info("    image size: %s n_train: %d n_test: %d" % (str(X_train.shape[1:4]), len(y_train), len(y_test)))
+    logging.info("    took: {}s".format(int(time.time() - start_time)))
+    return X_train, y_train, X_test, y_test
+
+
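+# Example (assuming TensorLayer is imported as `tl`): with the default
+# include_extra=True the returned X_train has shape (604388, 32, 32, 3),
+# i.e. the 73,257 training images plus the 531,131 extra images.
+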
 def load_ptb_dataset(path='data'):
     """Load Penn TreeBank (PTB) dataset.

@@ -656,19 +757,19 @@ def load_flickr25k_dataset(tag='sky', path="data", n_threads=50, printable=False
     url = 'http://press.liacs.nl/mirflickr/mirflickr25k/'

     # download dataset
-    if folder_exists(path + "/mirflickr") is False:
+    if not folder_exists(os.path.join(path, "mirflickr")):
         logging.info("[*] Flickr25k is nonexistent in {}".format(path))
         maybe_download_and_extract(filename, path, url, extract=True)
-        del_file(path + '/' + filename)
+        del_file(os.path.join(path, filename))

     # return images by the given tag.
     # 1. image path list
-    folder_imgs = path + "/mirflickr"
+    folder_imgs = os.path.join(path, "mirflickr")
     path_imgs = load_file_list(path=folder_imgs, regx='\\.jpg', printable=False)
     path_imgs.sort(key=natural_keys)

     # 2. tag path list
-    folder_tags = path + "/mirflickr/meta/tags"
+    folder_tags = os.path.join(path, "mirflickr", "meta", "tags")
     path_tags = load_file_list(path=folder_tags, regx='\\.txt', printable=False)
     path_tags.sort(key=natural_keys)

@@ -679,7 +780,7 @@ def load_flickr25k_dataset(tag='sky', path="data", n_threads=50, printable=False
     logging.info("[Flickr25k] reading images with tag: {}".format(tag))
     images_list = []
     for idx, _v in enumerate(path_tags):
-        tags = read_file(folder_tags + '/' + path_tags[idx]).split('\n')
+        tags = read_file(os.path.join(folder_tags, path_tags[idx])).split('\n')
         # logging.info(idx+1, tags)
         if tag is None or tag in tags:
             images_list.append(path_imgs[idx])
@@ -722,6 +823,8 @@ def load_flickr1M_dataset(tag='sky', size=10, path="data", n_threads=50, printab
     >>> images = tl.files.load_flickr1M_dataset(tag='zebra')

     """
+    import shutil
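+    # shutil.move (used below) replaces the previous os.system("mv ...") call,
+    # which assumed a Unix shell.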
+
     path = os.path.join(path, 'flickr1M')
     logging.info("[Flickr1M] using {}% of images = {}".format(size * 10, size * 100000))
     images_zip = [
@@ -734,20 +837,21 @@ def load_flickr1M_dataset(tag='sky', size=10, path="data", n_threads=50, printab
     for image_zip in images_zip[0:size]:
         image_folder = image_zip.split(".")[0]
         # logging.info(path+"/"+image_folder)
-        if folder_exists(path + "/" + image_folder) is False:
+        if not folder_exists(os.path.join(path, image_folder)):
             # logging.info(image_zip)
             logging.info("[Flickr1M] {} is missing in {}".format(image_folder, path))
             maybe_download_and_extract(image_zip, path, url, extract=True)
-            del_file(path + '/' + image_zip)
-            os.system("mv {} {}".format(path + '/images', path + '/' + image_folder))
+            del_file(os.path.join(path, image_zip))
+            shutil.move(os.path.join(path, 'images'), os.path.join(path, image_folder))
         else:
             logging.info("[Flickr1M] {} exists in {}".format(image_folder, path))

     # download tag
-    if folder_exists(path + "/tags") is False:
+    if not folder_exists(os.path.join(path, "tags")):
         logging.info("[Flickr1M] tag files is nonexistent in {}".format(path))
         maybe_download_and_extract(tag_zip, path, url, extract=True)
-        del_file(path + '/' + tag_zip)
+        del_file(os.path.join(path, tag_zip))
     else:
         logging.info("[Flickr1M] tags exists in {}".format(path))

@@ -761,17 +865,19 @@ def load_flickr1M_dataset(tag='sky', size=10, path="data", n_threads=50, printab
     for folder in images_folder_list[0:size * 10]:
         tmp = load_file_list(path=folder, regx='\\.jpg', printable=False)
         tmp.sort(key=lambda s: int(s.split('.')[-2]))  # ddd.jpg
-        images_list.extend([folder + '/' + x for x in tmp])
+        images_list.extend([os.path.join(folder, x) for x in tmp])

     # 2. tag path list
     tag_list = []
-    tag_folder_list = load_folder_list(path + "/tags")
-    tag_folder_list.sort(key=lambda s: int(s.split('/')[-1]))  # folder/images/ddd
+    tag_folder_list = load_folder_list(os.path.join(path, "tags"))
+    tag_folder_list.sort(key=lambda s: int(os.path.basename(s)))  # folders are named by number

     for folder in tag_folder_list[0:size * 10]:
         tmp = load_file_list(path=folder, regx='\\.txt', printable=False)
         tmp.sort(key=lambda s: int(s.split('.')[-2]))  # ddd.txt
-        tmp = [folder + '/' + s for s in tmp]
+        tmp = [os.path.join(folder, s) for s in tmp]
         tag_list += tmp

     # 3. select images