Skip to content

Commit c37ca97

Browse files
fineguyThe TensorFlow Datasets Authors
authored andcommitted
Refactor download_manager.py
PiperOrigin-RevId: 679662316
1 parent da34559 commit c37ca97

File tree

2 files changed

+115
-164
lines changed

2 files changed

+115
-164
lines changed

tensorflow_datasets/core/download/download_manager.py

Lines changed: 114 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -355,15 +355,11 @@ def _get_manually_downloaded_path(
355355
# processed once, even if passed twice to download_manager.
356356
@utils.build_synchronize_decorator()
357357
@utils.memoize()
358-
def _download(self, resource: Url) -> promise.Promise[epath.Path]:
358+
def _download_or_get_cache(
359+
self, resource: Url
360+
) -> promise.Promise[epath.Path]:
359361
"""Downloads resource or gets downloaded cache.
360362
361-
This function:
362-
363-
1. Reuse cache (`_get_cached_path`) or download the file
364-
2. Register or validate checksums (`_register_or_validate_checksums`)
365-
3. Rename download to final path (`_rename_and_get_final_dl_path`)
366-
367363
Args:
368364
resource: The URL to download.
369365
@@ -378,76 +374,79 @@ def _download(self, resource: Url) -> promise.Promise[epath.Path]:
378374

379375
expected_url_info = self._url_infos.get(url)
380376

381-
# 3 possible destinations for the path:
382-
# * In `manual_dir` (manually downloaded data)
383-
# * In `downloads/url_path` (checksum unknown)
384-
# * In `downloads/checksum_path` (checksum registered)
385-
manually_downloaded_path = self._get_manually_downloaded_path(
386-
expected_url_info=expected_url_info
387-
)
388-
url_path = self._get_dl_path(resource)
389-
checksum_path = (
390-
self._get_dl_path(resource, expected_url_info.checksum)
391-
if expected_url_info
392-
else None
393-
)
394-
395-
# Get the cached path and url_info (if they exists)
396-
dl_result = downloader.get_cached_path(
397-
manually_downloaded_path=manually_downloaded_path,
398-
checksum_path=checksum_path,
399-
url_path=url_path,
400-
expected_url_info=expected_url_info,
401-
)
402-
if dl_result and not self._force_download: # Download was cached
403-
logging.info(
404-
f'Skipping download of {url}: File cached in {dl_result.path}'
377+
# User has manually downloaded the file.
378+
if manually_downloaded_path := self._get_manually_downloaded_path(
379+
expected_url_info
380+
):
381+
computed_url_info = checksums.compute_url_info(manually_downloaded_path)
382+
self._register_or_validate_checksums(
383+
resource=resource,
384+
path=manually_downloaded_path,
385+
computed_url_info=computed_url_info,
405386
)
406-
# Still update the progression bar to indicate the file was downloaded
407-
self._downloader.increase_tqdm(dl_result.url_info)
408-
future = promise.Promise.resolve(dl_result)
409-
else:
410-
# Download in a tmp directory next to url_path (to avoid name collisions)
411-
# `download_tmp_dir` is cleaned-up in `_rename_and_get_final_dl_path`
412-
download_tmp_dir = (
413-
url_path.parent / f'{url_path.name}.tmp.{uuid.uuid4().hex}'
387+
self._log_skip_download(
388+
url=url, url_info=computed_url_info, path=manually_downloaded_path
414389
)
415-
download_tmp_dir.mkdir()
416-
logging.info(f'Downloading {url} into {download_tmp_dir}...')
417-
future = self._downloader.download(
418-
url, download_tmp_dir, verify=self._verify_ssl
390+
return promise.Promise.resolve(manually_downloaded_path)
391+
392+
# Force download
393+
elif self._force_download:
394+
return self._download(resource)
395+
396+
# Download has been cached (checksum known)
397+
elif expected_url_info and resource_lib.Resource.exists_locally(
398+
checksum_path := self._get_dl_path(resource, expected_url_info.checksum)
399+
):
400+
self._register_or_validate_checksums(
401+
resource=resource,
402+
path=checksum_path,
403+
computed_url_info=expected_url_info,
419404
)
405+
self._log_skip_download(
406+
url=url, url_info=expected_url_info, path=checksum_path
407+
)
408+
return promise.Promise.resolve(checksum_path)
409+
410+
# Download has been cached (checksum unknown)
411+
elif resource_lib.Resource.exists_locally(
412+
url_path := self._get_dl_path(resource)
413+
):
414+
computed_url_info = downloader.read_url_info(url_path)
415+
if expected_url_info and expected_url_info != computed_url_info:
416+
# If checksums are registered but do not match, trigger a new
417+
# download (e.g. previous file corrupted, checksums updated)
418+
return self._download(resource)
419+
if checksum_path := self._register_or_validate_checksums(
420+
resource=resource, path=url_path, computed_url_info=computed_url_info
421+
):
422+
# Checksums were registered: Rename -> checksum_path
423+
resource_lib.replace_info_file(url_path, checksum_path)
424+
path = url_path.replace(checksum_path)
425+
else:
426+
# Checksums not registered: -> do nothing
427+
path = url_path
428+
self._log_skip_download(url=url, url_info=computed_url_info, path=path)
429+
return promise.Promise.resolve(path)
420430

421-
# Post-process the result
422-
return future.then(
423-
lambda dl_result: self._register_or_validate_checksums( # pylint: disable=g-long-lambda
424-
resource=resource,
425-
path=dl_result.path,
426-
computed_url_info=dl_result.url_info,
427-
expected_url_info=expected_url_info,
428-
checksum_path=checksum_path,
429-
url_path=url_path,
430-
)
431-
)
431+
# Cache not found
432+
else:
433+
return self._download(resource)
434+
435+
def _log_skip_download(
436+
self, url: str, url_info: checksums.UrlInfo, path: epath.Path
437+
) -> None:
438+
logging.info(f'Skipping download of {url}: File cached in {path}')
439+
# Still update the progression bar to indicate the file was downloaded
440+
self._downloader.increase_tqdm(url_info)
432441

433442
def _register_or_validate_checksums(
434443
self,
435444
resource: resource_lib.Resource,
436445
path: epath.Path,
437-
expected_url_info: checksums.UrlInfo | None,
438446
computed_url_info: checksums.UrlInfo,
439-
checksum_path: epath.Path | None,
440-
url_path: epath.Path,
441-
) -> epath.Path:
442-
"""Validates/records checksums and renames final downloaded path."""
443-
# `path` can be:
444-
# * Manually downloaded
445-
# * (cached) checksum_path
446-
# * (cached) url_path
447-
# * `tmp_dir/file` (downloaded path)
448-
447+
) -> epath.Path | None:
448+
"""Validates/records checksums and returns checksum path if registered."""
449449
url: str = resource.url # pytype: disable=annotation-type-mismatch
450-
# Used both in `.downloaded_size` and `_record_url_infos()`
451450
self._recorded_url_infos[url] = computed_url_info
452451

453452
if self._register_checksums:
@@ -457,12 +456,9 @@ def _register_or_validate_checksums(
457456
# * `register_checksums_path` was validated in `__init__` so this
458457
# shouldn't fail.
459458
self._record_url_infos()
460-
461-
# Checksum path should now match the new registered checksum (even if
462-
# checksums were previously registered)
463-
expected_url_info = computed_url_info
464-
checksum_path = self._get_dl_path(resource, computed_url_info.checksum)
459+
return self._get_dl_path(resource, computed_url_info.checksum)
465460
else:
461+
expected_url_info = self._url_infos.get(url)
466462
# Eventually validate checksums
467463
# Note:
468464
# * If path is cached at `url_path` but cached
@@ -478,15 +474,8 @@ def _register_or_validate_checksums(
478474
computed_url_info=computed_url_info,
479475
path=path,
480476
)
481-
482-
return self._rename_and_get_final_dl_path(
483-
url=url,
484-
path=path,
485-
expected_url_info=expected_url_info,
486-
computed_url_info=computed_url_info,
487-
checksum_path=checksum_path,
488-
url_path=url_path,
489-
)
477+
if expected_url_info:
478+
return self._get_dl_path(resource, expected_url_info.checksum)
490479

491480
def _validate_checksums(
492481
self,
@@ -517,47 +506,56 @@ def _validate_checksums(
517506
)
518507
raise NonMatchingChecksumError(msg)
519508

520-
def _rename_and_get_final_dl_path(
521-
self,
522-
url: str,
523-
path: epath.Path,
524-
expected_url_info: checksums.UrlInfo | None,
525-
computed_url_info: checksums.UrlInfo | None,
526-
checksum_path: epath.Path | None,
527-
url_path: epath.Path,
528-
) -> epath.Path:
529-
"""Eventually rename the downloaded file if checksums were recorded."""
530-
# `path` can be:
531-
# * Manually downloaded
532-
# * (cached) checksum_path
533-
# * (cached) url_path
534-
# * `tmp_dir/file` (downloaded path)
535-
if self._manual_dir and path.is_relative_to(self._manual_dir):
536-
return path # Manually downloaded data
537-
elif path == checksum_path: # Path already at final destination
538-
assert computed_url_info == expected_url_info # Sanity check
539-
return checksum_path # pytype: disable=bad-return-type
540-
elif path == url_path:
541-
if checksum_path:
542-
# Checksums were registered: Rename -> checksums_path
543-
resource_lib.replace_info_file(path, checksum_path)
544-
return path.replace(checksum_path)
545-
else:
546-
# Checksums not registered: -> do nothing
547-
return path
548-
else: # Path was downloaded in tmp dir
549-
dst_path = checksum_path or url_path
509+
def _download(
510+
self, resource: resource_lib.Resource
511+
) -> promise.Promise[epath.Path]:
512+
"""Downloads resource.
513+
514+
Args:
515+
resource: The resource to download.
516+
517+
Returns:
518+
Promise of the path to the downloaded url.
519+
"""
520+
url_path = self._get_dl_path(resource)
521+
url: str = resource.url # pytype: disable=annotation-type-mismatch
522+
523+
# Download in a tmp directory next to url_path (to avoid name collisions)
524+
# `download_tmp_dir` is cleaned-up in `callback`
525+
download_tmp_dir = (
526+
url_path.parent / f'{url_path.name}.tmp.{uuid.uuid4().hex}'
527+
)
528+
download_tmp_dir.mkdir()
529+
logging.info(f'Downloading {url} into {download_tmp_dir}...')
530+
future = self._downloader.download(
531+
url, download_tmp_dir, verify=self._verify_ssl
532+
)
533+
534+
def callback(dl_result: downloader.DownloadResult) -> epath.Path:
535+
"""Post-process the download result."""
536+
dl_path = dl_result.path
537+
dl_url_info = dl_result.url_info
538+
539+
dst_path = self._register_or_validate_checksums(
540+
resource=resource, computed_url_info=dl_url_info, path=dl_path
541+
)
542+
if not dst_path:
543+
dst_path = url_path
544+
550545
resource_lib.write_info_file(
551546
url=url,
552547
path=dst_path,
553548
dataset_name=self._dataset_name,
554-
original_fname=path.name,
555-
url_info=computed_url_info,
549+
original_fname=dl_path.name,
550+
url_info=dl_url_info,
556551
)
557-
path.replace(dst_path)
558-
path.parent.rmdir() # Cleanup tmp dir (will fail if dir not empty)
552+
dl_path.replace(dst_path)
553+
dl_path.parent.rmdir() # Cleanup tmp dir (will fail if dir not empty)
554+
559555
return dst_path
560556

557+
return future.then(callback)
558+
561559
@utils.build_synchronize_decorator()
562560
@utils.memoize()
563561
def _extract(self, resource: ExtractPath) -> promise.Promise[epath.Path]:
@@ -587,7 +585,7 @@ def callback(path):
587585
resource.path = path
588586
return self._extract(resource)
589587

590-
return self._download(resource).then(callback)
588+
return self._download_or_get_cache(resource).then(callback)
591589

592590
def download_checksums(self, checksums_url):
593591
"""Downloads checksum file from the given URL and adds it to registry."""
@@ -636,7 +634,7 @@ def download(self, url_or_urls):
636634
"""
637635
# Add progress bar to follow the download state
638636
with self._downloader.tqdm():
639-
return _map_promise(self._download, url_or_urls)
637+
return _map_promise(self._download_or_get_cache, url_or_urls)
640638

641639
def iter_archive(
642640
self,

tensorflow_datasets/core/download/downloader.py

Lines changed: 1 addition & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def get_downloader(*args: Any, **kwargs: Any) -> '_Downloader':
6161
return _Downloader(*args, **kwargs)
6262

6363

64-
def _read_url_info(url_path: epath.Path) -> checksums_lib.UrlInfo:
64+
def read_url_info(url_path: epath.Path) -> checksums_lib.UrlInfo:
6565
"""Loads the `UrlInfo` from the `.INFO` file."""
6666
file_info = resource_lib.read_info_file(url_path)
6767
if 'url_info' not in file_info:
@@ -75,53 +75,6 @@ def _read_url_info(url_path: epath.Path) -> checksums_lib.UrlInfo:
7575
return checksums_lib.UrlInfo(**url_info)
7676

7777

78-
def get_cached_path(
79-
manually_downloaded_path: epath.Path | None,
80-
checksum_path: epath.Path | None,
81-
url_path: epath.Path,
82-
expected_url_info: checksums_lib.UrlInfo | None,
83-
) -> DownloadResult | None:
84-
"""Returns the downloaded path and computed url-info.
85-
86-
If the path is not cached, or that `url_path` does not match checksums,
87-
the file will be downloaded again.
88-
89-
Path can be cached at three different locations:
90-
91-
Args:
92-
manually_downloaded_path: Manually downloaded in `dl_manager.manual_dir`
93-
checksum_path: Cached in the final destination (if checksum known)
94-
url_path: Cached in the tmp destination (if checksum unknown).
95-
expected_url_info: Registered checksum (if known)
96-
"""
97-
# User has manually downloaded the file.
98-
if manually_downloaded_path and manually_downloaded_path.exists():
99-
computed_url_info = checksums_lib.compute_url_info(manually_downloaded_path)
100-
return DownloadResult(
101-
path=manually_downloaded_path, url_info=computed_url_info
102-
)
103-
104-
# Download has been cached (checksum known)
105-
elif checksum_path and resource_lib.Resource.exists_locally(checksum_path):
106-
# `path = f(checksum)` was found, so url_info match
107-
return DownloadResult(checksum_path, url_info=expected_url_info)
108-
109-
# Download has been cached (checksum unknown)
110-
elif resource_lib.Resource.exists_locally(url_path):
111-
# Info restored from `.INFO` file
112-
computed_url_info = _read_url_info(url_path)
113-
# If checksums are now registered but do not match, trigger a new
114-
# download (e.g. previous file corrupted, checksums updated)
115-
if expected_url_info and computed_url_info != expected_url_info:
116-
return None
117-
else:
118-
return DownloadResult(path=url_path, url_info=computed_url_info)
119-
120-
# Else file not found (or has bad checksums). (re)download.
121-
else:
122-
return None
123-
124-
12578
def _filename_from_content_disposition(
12679
content_disposition: str,
12780
) -> str | None:

0 commit comments

Comments
 (0)