Skip to content

Commit 3b0dab2

Browse files
fineguy and The TensorFlow Datasets Authors
authored and committed
Refactor download_manager.py
PiperOrigin-RevId: 676819433
1 parent 2c16950 commit 3b0dab2

File tree

5 files changed

+216 / -222 lines changed

5 files changed

+216 / -222 lines changed

tensorflow_datasets/core/download/checksums.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -36,6 +36,10 @@ def _default_checksum_dirs() -> list[epath.Path]:
3636
]
3737

3838

39+
def sha256(str_: str) -> str:
40+
return hashlib.sha256(str_.encode()).hexdigest()
41+
42+
3943
@dataclasses.dataclass(eq=True)
4044
class UrlInfo:
4145
"""Small wrapper around the url metadata (checksum, size).

tensorflow_datasets/core/download/download_manager.py

Lines changed: 8 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,6 @@
2121
import concurrent.futures
2222
import dataclasses
2323
import functools
24-
import hashlib
2524
import typing
2625
from typing import Any
2726
import uuid
@@ -316,8 +315,8 @@ def downloaded_size(self):
316315
"""Returns the total size of downloaded files."""
317316
return sum(url_info.size for url_info in self._recorded_url_infos.values())
318317

319-
def _get_dl_path(self, url: str, sha256: str) -> epath.Path:
320-
return self._download_dir / resource_lib.get_dl_fname(url, sha256)
318+
def _get_dl_path(self, url: str, checksum: str | None = None) -> epath.Path:
319+
return self._download_dir / resource_lib.get_dl_fname(url, checksum)
321320

322321
@property
323322
def register_checksums(self):
@@ -368,11 +367,9 @@ def _download(self, resource: Url) -> promise.Promise[epath.Path]:
368367
manual_dir=self._manual_dir,
369368
expected_url_info=expected_url_info,
370369
)
371-
url_path = self._get_dl_path(
372-
url, sha256=hashlib.sha256(url.encode('utf-8')).hexdigest()
373-
)
370+
url_path = self._get_dl_path(url)
374371
checksum_path = (
375-
self._get_dl_path(url, sha256=expected_url_info.checksum)
372+
self._get_dl_path(url, expected_url_info.checksum)
376373
if expected_url_info
377374
else None
378375
)
@@ -392,10 +389,11 @@ def _download(self, resource: Url) -> promise.Promise[epath.Path]:
392389
self._downloader.increase_tqdm(dl_result)
393390
future = promise.Promise.resolve(dl_result)
394391
else:
395-
# Download in an empty tmp directory (to avoid name collisions)
392+
# Download in a tmp directory next to url_path (to avoid name collisions)
396393
# `download_tmp_dir` is cleaned-up in `_rename_and_get_final_dl_path`
397-
dirname = f'{resource_lib.get_dl_dirname(url)}.tmp.{uuid.uuid4().hex}'
398-
download_tmp_dir = self._download_dir / dirname
394+
download_tmp_dir = (
395+
url_path.parent / f'{url_path.name}.tmp.{uuid.uuid4().hex}'
396+
)
399397
download_tmp_dir.mkdir()
400398
logging.info(f'Downloading {url} into {download_tmp_dir}...')
401399
future = self._downloader.download(

0 commit comments

Comments
 (0)