Skip to content

Commit 29169ac

Browse files
committed
release load celebA
1 parent a686af5 commit 29169ac

File tree

2 files changed

+64
-0
lines changed

2 files changed

+64
-0
lines changed

docs/modules/files.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ sake of cross-platform.
4444
load_flickr25k_dataset
4545
load_flickr1M_dataset
4646
load_cyclegan_dataset
47+
load_celebA_dataset
48+
download_file_from_google_drive
4749

4850
save_npz
4951
load_npz
@@ -115,7 +117,13 @@ CycleGAN
115117
^^^^^^^^^^^^^^^^^^^^^^^^^
116118
.. autofunction:: load_cyclegan_dataset
117119

120+
CelebA
121+
^^^^^^^^^
122+
.. autofunction:: load_celebA_dataset
118123

124+
Google Drive
125+
^^^^^^^^^^^^^^^^
126+
.. autofunction:: download_file_from_google_drive
119127

120128
Load and save network
121129
----------------------

tensorlayer/files.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -649,6 +649,62 @@ def if_2d_to_3d(images): # [h, w] --> [h, w, 3]
649649

650650
return im_train_A, im_train_B, im_test_A, im_test_B
651651

652+
def download_file_from_google_drive(id, destination):
653+
""" Download file from Google Driver, see ``load_celeba_dataset`` for example.
654+
655+
Parameters
656+
--------------
657+
id : driver ID
658+
destination : string, save path.
659+
"""
660+
from tqdm import tqdm
661+
import requests
662+
def save_response_content(response, destination, chunk_size=32*1024):
663+
total_size = int(response.headers.get('content-length', 0))
664+
with open(destination, "wb") as f:
665+
for chunk in tqdm(response.iter_content(chunk_size), total=total_size,
666+
unit='B', unit_scale=True, desc=destination):
667+
if chunk: # filter out keep-alive new chunks
668+
f.write(chunk)
669+
def get_confirm_token(response):
670+
for key, value in response.cookies.items():
671+
if key.startswith('download_warning'):
672+
return value
673+
return None
674+
URL = "https://docs.google.com/uc?export=download"
675+
session = requests.Session()
676+
677+
response = session.get(URL, params={ 'id': id }, stream=True)
678+
token = get_confirm_token(response)
679+
680+
if token:
681+
params = { 'id' : id, 'confirm' : token }
682+
response = session.get(URL, params=params, stream=True)
683+
save_response_content(response, destination)
684+
685+
def load_celebA_dataset(dirpath='data'):
686+
""" Automatically download celebA dataset, and return a list of image path. """
687+
import zipfile, os
688+
data_dir = 'celebA'
689+
filename, drive_id = "img_align_celeba.zip", "0B7EVK8r0v71pZjFTYXZWM3FlRnM"
690+
save_path = os.path.join(dirpath, filename)
691+
image_path = os.path.join(dirpath, data_dir)
692+
if os.path.exists(image_path):
693+
print('[*] {} already exists'.format(save_path))
694+
else:
695+
exists_or_mkdir(dirpath)
696+
download_file_from_google_drive(drive_id, save_path)
697+
zip_dir = ''
698+
with zipfile.ZipFile(save_path) as zf:
699+
zip_dir = zf.namelist()[0]
700+
zf.extractall(dirpath)
701+
os.remove(save_path)
702+
os.rename(os.path.join(dirpath, zip_dir), image_path)
703+
704+
data_files = load_file_list(path=image_path, regx='\\.jpg', printable=False)
705+
for i in range(len(data_files)):
706+
data_files[i] = os.path.join(image_path, data_files[i])
707+
return data_files
652708

653709
## Load and save network list npz
654710
def save_npz(save_list=[], name='model.npz', sess=None):

0 commit comments

Comments
 (0)