Skip to content

Commit e857b29

Browse files
fineguyThe TensorFlow Datasets Authors
authored andcommitted
Refactor download_manager
PiperOrigin-RevId: 681458845
1 parent c931fd0 commit e857b29

File tree

1 file changed

+58
-75
lines changed

1 file changed

+58
-75
lines changed

tensorflow_datasets/core/download/download_manager.py

Lines changed: 58 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -390,56 +390,55 @@ def _download_or_get_cache(
390390
expected_url_info
391391
):
392392
computed_url_info = checksums.compute_url_info(manually_downloaded_path)
393-
self._register_or_validate_checksums(
394-
resource=resource,
395-
path=manually_downloaded_path,
396-
computed_url_info=computed_url_info,
397-
)
398-
self._log_skip_download(
399-
url=url, url_info=computed_url_info, path=manually_downloaded_path
393+
dl_result = downloader.DownloadResult(
394+
path=manually_downloaded_path, url_info=computed_url_info
400395
)
401-
return promise.Promise.resolve(manually_downloaded_path)
402396

403397
# Force download
404398
elif self._force_download:
405-
return self._download(resource)
399+
dl_result = None
406400

407401
# Download has been cached (checksum known)
408402
elif expected_url_info and resource_lib.is_locally_cached(
409403
checksum_path := self._get_dl_path(resource, expected_url_info.checksum)
410404
):
411-
self._register_or_validate_checksums(
412-
resource=resource,
413-
path=checksum_path,
414-
computed_url_info=expected_url_info,
405+
dl_result = downloader.DownloadResult(
406+
path=checksum_path, url_info=expected_url_info
415407
)
416-
self._log_skip_download(
417-
url=url, url_info=expected_url_info, path=checksum_path
418-
)
419-
return promise.Promise.resolve(checksum_path)
420408

421409
# Download has been cached (checksum unknown)
422410
elif resource_lib.is_locally_cached(
423411
url_path := self._get_dl_path(resource)
424412
):
425-
computed_url_info = downloader.read_url_info(url_path)
426-
if expected_url_info and expected_url_info != computed_url_info:
413+
url_info = downloader.read_url_info(url_path)
414+
415+
if expected_url_info and expected_url_info != url_info:
427416
# If checksums are registered but do not match, trigger a new
428417
# download (e.g. previous file corrupted, checksums updated)
429-
return self._download(resource)
430-
if checksum_path := self._register_or_validate_checksums(
431-
resource=resource, path=url_path, computed_url_info=computed_url_info
432-
):
418+
dl_result = None
419+
elif self._is_checksum_registered(url=url):
433420
# Checksums were registered: Rename -> checksum_path
434-
resource_lib.replace_info_file(url_path, checksum_path)
435-
path = url_path.replace(checksum_path)
421+
path = self._get_dl_path(resource, url_info.checksum)
422+
resource_lib.replace_info_file(url_path, path)
423+
url_path.replace(path)
424+
dl_result = downloader.DownloadResult(path=path, url_info=url_info)
436425
else:
437426
# Checksums not registered: -> do nothing
438-
path = url_path
439-
self._log_skip_download(url=url, url_info=computed_url_info, path=path)
440-
return promise.Promise.resolve(path)
427+
dl_result = downloader.DownloadResult(path=url_path, url_info=url_info)
441428

442429
# Cache not found
430+
else:
431+
dl_result = None
432+
433+
if dl_result:
434+
path = dl_result.path
435+
url_info = dl_result.url_info
436+
437+
self._log_skip_download(url=url, url_info=url_info, path=path)
438+
self._register_or_validate_checksums(
439+
url=url, url_info=url_info, path=path
440+
)
441+
return promise.Promise.resolve(path)
443442
else:
444443
return self._download(resource)
445444

@@ -451,58 +450,40 @@ def _log_skip_download(
451450
self._downloader.increase_tqdm(url_info)
452451

453452
def _register_or_validate_checksums(
454-
self,
455-
resource: resource_lib.Resource,
456-
path: epath.Path,
457-
computed_url_info: checksums.UrlInfo,
458-
) -> epath.Path | None:
459-
"""Validates/records checksums and returns checksum path if registered."""
460-
url = resource.url
461-
453+
self, url: str, url_info: checksums.UrlInfo, path: epath.Path
454+
) -> None:
455+
"""Registers or validates checksums depending on `self._register_checksums`."""
462456
if self._register_checksums:
463457
# Note:
464-
# * We save even if `expected_url_info == computed_url_info` as
458+
# * We save even if `expected_url_info == url_info` as
465459
# `expected_url_info` might have been loaded from another dataset.
466460
# * `register_checksums_path` was validated in `__init__` so this
467461
# shouldn't fail.
468-
self._recorded_url_infos[url] = computed_url_info
462+
self._recorded_url_infos[url] = url_info
469463
self._record_url_infos()
470-
return self._get_dl_path(resource, computed_url_info.checksum)
471464
elif expected_url_info := self._url_infos.get(url):
472465
# Eventually validate checksums
473-
# Note:
474-
# * If path is cached at `url_path` but cached
475-
# `computed_url_info != expected_url_info`, a new download has
476-
# been triggered (as _get_cached_path returns None)
477-
# * If path was downloaded but checksums don't match expected, then
478-
# the download isn't cached (re-running build will retrigger a new
479-
# download). This is expected as it might mean the downloaded file
480-
# was corrupted. Note: The tmp file isn't deleted to allow inspection.
481-
self._validate_checksums(
482-
url=url,
483-
expected_url_info=expected_url_info,
484-
computed_url_info=computed_url_info,
485-
path=path,
486-
)
487-
return self._get_dl_path(resource, expected_url_info.checksum)
488-
489-
def _validate_checksums(
490-
self,
491-
url: str,
492-
expected_url_info: checksums.UrlInfo,
493-
computed_url_info: checksums.UrlInfo,
494-
path: epath.Path,
495-
) -> None:
496-
"""Validate computed_url_info match expected_url_info."""
497-
if expected_url_info != computed_url_info:
498-
msg = (
499-
f'Artifact {url}, downloaded to {path}, has wrong checksum:\n'
500-
f'* Expected: {expected_url_info}\n'
501-
f'* Got: {computed_url_info}\n'
502-
'To debug, see: '
503-
'https://www.tensorflow.org/datasets/overview#fixing_nonmatchingchecksumerror'
504-
)
505-
raise NonMatchingChecksumError(msg)
466+
if expected_url_info != url_info:
467+
msg = (
468+
f'Artifact {url}, downloaded to {path}, has wrong checksum:\n'
469+
f'* Expected: {expected_url_info}\n'
470+
f'* Got: {url_info}\n'
471+
'To debug, see: '
472+
'https://www.tensorflow.org/datasets/overview#fixing_nonmatchingchecksumerror'
473+
)
474+
raise NonMatchingChecksumError(msg)
475+
476+
def _is_checksum_registered(self, url: str) -> bool:
477+
"""Returns whether checksums are registered for the given url."""
478+
if url in self._url_infos:
479+
# Checksum is already registered
480+
return True
481+
elif self._register_checksums:
482+
# Checksum is being registered
483+
return True
484+
else:
485+
# Checksum is not registered
486+
return False
506487

507488
def _download(
508489
self, resource: resource_lib.Resource
@@ -534,10 +515,12 @@ def callback(dl_result: downloader.DownloadResult) -> epath.Path:
534515
dl_path = dl_result.path
535516
dl_url_info = dl_result.url_info
536517

537-
dst_path = self._register_or_validate_checksums(
538-
resource=resource, computed_url_info=dl_url_info, path=dl_path
518+
self._register_or_validate_checksums(
519+
url=url, url_info=dl_url_info, path=dl_path
539520
)
540-
if not dst_path:
521+
if self._is_checksum_registered(url=url):
522+
dst_path = self._get_dl_path(resource, dl_url_info.checksum)
523+
else:
541524
dst_path = url_path
542525

543526
resource_lib.write_info_file(

0 commit comments

Comments
 (0)