@@ -27,19 +27,21 @@ def load_celebA_dataset(path='data'):
2727
2828 data_dir = 'celebA'
2929 filename , drive_id = "img_align_celeba.zip" , "0B7EVK8r0v71pZjFTYXZWM3FlRnM"
30- save_path = os .path .join (path , filename )
31- image_path = os .path .join (path , data_dir )
30+ file_path = os .path .join (path , data_dir )
31+ image_path = os .path .join (path , data_dir , "img_align_celeba" )
32+ save_path = os .path .join (path , data_dir , filename )
3233 if os .path .exists (image_path ):
33- logging .info ('[*] {} already exists' .format (save_path ))
34+ logging .info ('[*] {} already exists' .format (image_path ))
3435 else :
35- exists_or_mkdir (path )
36- download_file_from_google_drive (drive_id , save_path )
37- zip_dir = ''
36+ if not os .path .exists (save_path ):
37+ exists_or_mkdir (file_path )
38+ download_file_from_google_drive (drive_id , save_path )
39+ zip_dir = ''
3840 with zipfile .ZipFile (save_path ) as zf :
3941 zip_dir = zf .namelist ()[0 ]
40- zf .extractall (path )
41- os .remove (save_path )
42- os .rename (os .path .join (path , zip_dir ), image_path )
42+ zf .extractall (file_path )
43+ # os.remove(save_path)
44+ # os.rename(os.path.join(path, zip_dir), image_path)
4345
4446 data_files = load_file_list (path = image_path , regx = '\\ .jpg' , printable = False )
4547 for i , _v in enumerate (data_files ):
0 commit comments