@@ -649,6 +649,62 @@ def if_2d_to_3d(images): # [h, w] --> [h, w, 3]
649649
650650 return im_train_A , im_train_B , im_test_A , im_test_B
651651
652+ def download_file_from_google_drive (id , destination ):
653+ """ Download file from Google Driver, see ``load_celeba_dataset`` for example.
654+
655+ Parameters
656+ --------------
657+ id : driver ID
658+ destination : string, save path.
659+ """
660+ from tqdm import tqdm
661+ import requests
662+ def save_response_content (response , destination , chunk_size = 32 * 1024 ):
663+ total_size = int (response .headers .get ('content-length' , 0 ))
664+ with open (destination , "wb" ) as f :
665+ for chunk in tqdm (response .iter_content (chunk_size ), total = total_size ,
666+ unit = 'B' , unit_scale = True , desc = destination ):
667+ if chunk : # filter out keep-alive new chunks
668+ f .write (chunk )
669+ def get_confirm_token (response ):
670+ for key , value in response .cookies .items ():
671+ if key .startswith ('download_warning' ):
672+ return value
673+ return None
674+ URL = "https://docs.google.com/uc?export=download"
675+ session = requests .Session ()
676+
677+ response = session .get (URL , params = { 'id' : id }, stream = True )
678+ token = get_confirm_token (response )
679+
680+ if token :
681+ params = { 'id' : id , 'confirm' : token }
682+ response = session .get (URL , params = params , stream = True )
683+ save_response_content (response , destination )
684+
685+ def load_celebA_dataset (dirpath = 'data' ):
686+ """ Automatically download celebA dataset, and return a list of image path. """
687+ import zipfile , os
688+ data_dir = 'celebA'
689+ filename , drive_id = "img_align_celeba.zip" , "0B7EVK8r0v71pZjFTYXZWM3FlRnM"
690+ save_path = os .path .join (dirpath , filename )
691+ image_path = os .path .join (dirpath , data_dir )
692+ if os .path .exists (image_path ):
693+ print ('[*] {} already exists' .format (save_path ))
694+ else :
695+ exists_or_mkdir (dirpath )
696+ download_file_from_google_drive (drive_id , save_path )
697+ zip_dir = ''
698+ with zipfile .ZipFile (save_path ) as zf :
699+ zip_dir = zf .namelist ()[0 ]
700+ zf .extractall (dirpath )
701+ os .remove (save_path )
702+ os .rename (os .path .join (dirpath , zip_dir ), image_path )
703+
704+ data_files = load_file_list (path = image_path , regx = '\\ .jpg' , printable = False )
705+ for i in range (len (data_files )):
706+ data_files [i ] = os .path .join (image_path , data_files [i ])
707+ return data_files
652708
653709## Load and save network list npz
654710def save_npz (save_list = [], name = 'model.npz' , sess = None ):
0 commit comments