Skip to content

Commit 7dabe49

Browse files
authored
fix celebA dataset load (#66)
1 parent 73d4872 commit 7dabe49

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

tensorlayerx/files/dataset_loaders/celebA_dataset.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,19 +27,21 @@ def load_celebA_dataset(path='data'):
2727

2828
data_dir = 'celebA'
2929
filename, drive_id = "img_align_celeba.zip", "0B7EVK8r0v71pZjFTYXZWM3FlRnM"
30-
save_path = os.path.join(path, filename)
31-
image_path = os.path.join(path, data_dir)
30+
file_path = os.path.join(path, data_dir)
31+
image_path = os.path.join(path, data_dir, "img_align_celeba")
32+
save_path = os.path.join(path, data_dir, filename)
3233
if os.path.exists(image_path):
33-
logging.info('[*] {} already exists'.format(save_path))
34+
logging.info('[*] {} already exists'.format(image_path))
3435
else:
35-
exists_or_mkdir(path)
36-
download_file_from_google_drive(drive_id, save_path)
37-
zip_dir = ''
36+
if not os.path.exists(save_path):
37+
exists_or_mkdir(file_path)
38+
download_file_from_google_drive(drive_id, save_path)
39+
zip_dir = ''
3840
with zipfile.ZipFile(save_path) as zf:
3941
zip_dir = zf.namelist()[0]
40-
zf.extractall(path)
41-
os.remove(save_path)
42-
os.rename(os.path.join(path, zip_dir), image_path)
42+
zf.extractall(file_path)
43+
# os.remove(save_path)
44+
# os.rename(os.path.join(path, zip_dir), image_path)
4345

4446
data_files = load_file_list(path=image_path, regx='\\.jpg', printable=False)
4547
for i, _v in enumerate(data_files):

0 commit comments

Comments
 (0)