Skip to content

Commit 0015e96

Browse files
fineguy and The TensorFlow Datasets Authors
authored and committed
Support register checksums for manually downloaded files.
PiperOrigin-RevId: 678584247
1 parent 6df5fc6 commit 0015e96

File tree

4 files changed

+61
-65
lines changed

4 files changed

+61
-65
lines changed

tensorflow_datasets/core/download/download_manager.py

Lines changed: 26 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -315,8 +315,12 @@ def downloaded_size(self) -> int:
315315
"""Returns the total size of downloaded files."""
316316
return sum(url_info.size for url_info in self._recorded_url_infos.values())
317317

318-
def _get_dl_path(self, url: str, checksum: str | None = None) -> epath.Path:
319-
return self._download_dir / resource_lib.get_dl_fname(url, checksum)
318+
def _get_dl_path(
319+
self, resource: resource_lib.Resource, checksum: str | None = None
320+
) -> epath.Path:
321+
return self._download_dir / resource_lib.get_dl_fname(
322+
resource.url, checksum
323+
)
320324

321325
@property
322326
def register_checksums(self):
@@ -352,7 +356,7 @@ def _get_manually_downloaded_path(
352356
@utils.build_synchronize_decorator()
353357
@utils.memoize()
354358
def _download(self, resource: Url) -> promise.Promise[epath.Path]:
355-
"""Download resource, returns Promise->path to downloaded file.
359+
"""Downloads resource or gets downloaded cache.
356360
357361
This function:
358362
@@ -364,13 +368,12 @@ def _download(self, resource: Url) -> promise.Promise[epath.Path]:
364368
resource: The URL to download.
365369
366370
Returns:
367-
path: The path to the downloaded resource.
371+
Promise of the path to the downloaded resource.
368372
"""
369373
# Normalize the input
370-
if isinstance(resource, str):
371-
url = resource
372-
else:
373-
url = resource.url
374+
if not isinstance(resource, resource_lib.Resource):
375+
resource = resource_lib.Resource(url=resource)
376+
url = resource.url
374377
assert url is not None, 'URL is undefined from resource.'
375378

376379
expected_url_info = self._url_infos.get(url)
@@ -382,9 +385,9 @@ def _download(self, resource: Url) -> promise.Promise[epath.Path]:
382385
manually_downloaded_path = self._get_manually_downloaded_path(
383386
expected_url_info=expected_url_info
384387
)
385-
url_path = self._get_dl_path(url)
388+
url_path = self._get_dl_path(resource)
386389
checksum_path = (
387-
self._get_dl_path(url, expected_url_info.checksum)
390+
self._get_dl_path(resource, expected_url_info.checksum)
388391
if expected_url_info
389392
else None
390393
)
@@ -396,12 +399,12 @@ def _download(self, resource: Url) -> promise.Promise[epath.Path]:
396399
url_path=url_path,
397400
expected_url_info=expected_url_info,
398401
)
399-
if dl_result.path and not self._force_download: # Download was cached
402+
if dl_result and not self._force_download: # Download was cached
400403
logging.info(
401404
f'Skipping download of {url}: File cached in {dl_result.path}'
402405
)
403406
# Still update the progression bar to indicate the file was downloaded
404-
self._downloader.increase_tqdm(dl_result)
407+
self._downloader.increase_tqdm(dl_result.url_info)
405408
future = promise.Promise.resolve(dl_result)
406409
else:
407410
# Download in a tmp directory next to url_path (to avoid name collisions)
@@ -418,7 +421,7 @@ def _download(self, resource: Url) -> promise.Promise[epath.Path]:
418421
# Post-process the result
419422
return future.then(
420423
lambda dl_result: self._register_or_validate_checksums( # pylint: disable=g-long-lambda
421-
url=url,
424+
resource=resource,
422425
path=dl_result.path,
423426
computed_url_info=dl_result.url_info,
424427
expected_url_info=expected_url_info,
@@ -429,10 +432,10 @@ def _download(self, resource: Url) -> promise.Promise[epath.Path]:
429432

430433
def _register_or_validate_checksums(
431434
self,
435+
resource: resource_lib.Resource,
432436
path: epath.Path,
433-
url: str,
434437
expected_url_info: checksums.UrlInfo | None,
435-
computed_url_info: checksums.UrlInfo | None,
438+
computed_url_info: checksums.UrlInfo,
436439
checksum_path: epath.Path | None,
437440
url_path: epath.Path,
438441
) -> epath.Path:
@@ -443,16 +446,11 @@ def _register_or_validate_checksums(
443446
# * (cached) url_path
444447
# * `tmp_dir/file` (downloaded path)
445448

446-
if computed_url_info:
447-
# Used both in `.downloaded_size` and `_record_url_infos()`
448-
self._recorded_url_infos[url] = computed_url_info
449+
url: str = resource.url # pytype: disable=annotation-type-mismatch
450+
# Used both in `.downloaded_size` and `_record_url_infos()`
451+
self._recorded_url_infos[url] = computed_url_info
449452

450453
if self._register_checksums:
451-
if not computed_url_info:
452-
raise ValueError(
453-
f'Cannot register checksums for {url}: no computed checksum. '
454-
'--register_checksums with manually downloaded data not supported.'
455-
)
456454
# Note:
457455
# * We save even if `expected_url_info == computed_url_info` as
458456
# `expected_url_info` might have been loaded from another dataset.
@@ -463,7 +461,7 @@ def _register_or_validate_checksums(
463461
# Checksum path should now match the new registered checksum (even if
464462
# checksums were previously registered)
465463
expected_url_info = computed_url_info
466-
checksum_path = self._get_dl_path(url, computed_url_info.checksum)
464+
checksum_path = self._get_dl_path(resource, computed_url_info.checksum)
467465
else:
468466
# Eventually validate checksums
469467
# Note:
@@ -476,9 +474,9 @@ def _register_or_validate_checksums(
476474
# was corrupted. Note: The tmp file isn't deleted to allow inspection.
477475
self._validate_checksums(
478476
url=url,
479-
path=path,
480477
expected_url_info=expected_url_info,
481478
computed_url_info=computed_url_info,
479+
path=path,
482480
)
483481

484482
return self._rename_and_get_final_dl_path(
@@ -493,17 +491,14 @@ def _register_or_validate_checksums(
493491
def _validate_checksums(
494492
self,
495493
url: str,
496-
path: epath.Path,
497-
computed_url_info: checksums.UrlInfo | None,
498494
expected_url_info: checksums.UrlInfo | None,
495+
computed_url_info: checksums.UrlInfo,
496+
path: epath.Path,
499497
) -> None:
500498
"""Validate computed_url_info match expected_url_info."""
501499
# If force-checksums validations, both expected and computed url_info
502500
# should exists
503501
if self._force_checksums_validation:
504-
# Checksum of the downloaded file unknown (for manually downloaded file)
505-
if not computed_url_info:
506-
computed_url_info = checksums.compute_url_info(path)
507502
# Checksums have not been registered
508503
if not expected_url_info:
509504
raise ValueError(
@@ -512,11 +507,7 @@ def _validate_checksums(
512507
'Did you forget to register checksums?'
513508
)
514509

515-
if (
516-
expected_url_info
517-
and computed_url_info
518-
and expected_url_info != computed_url_info
519-
):
510+
if expected_url_info and expected_url_info != computed_url_info:
520511
msg = (
521512
f'Artifact {url}, downloaded to {path}, has wrong checksum:\n'
522513
f'* Expected: {expected_url_info}\n'

tensorflow_datasets/core/download/download_manager_test.py

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -46,20 +46,23 @@
4646
class Artifact:
4747
# For testing only.
4848

49-
def __init__(self, name, url=None, content=None):
50-
url = url or f'http://foo-bar.ch/{name}'
51-
content = content or f'content of {name}'
52-
self.url = url
49+
def __init__(
50+
self, name: str, url: str | None = None, content: str | None = None
51+
):
52+
self.name = name
53+
self.url = url or f'http://foo-bar.ch/{self.name}'
54+
self.content = content or f'content of {self.name}'
55+
5356
self.url_info = checksums_lib.UrlInfo(
54-
size=len(content),
55-
checksum=checksums_lib.sha256(content),
56-
filename=name,
57+
size=len(self.content),
58+
checksum=checksums_lib.sha256(self.content),
59+
filename=self.name,
5760
)
5861

59-
self.file_name = resource_lib.get_dl_fname(url, self.url_info.checksum)
62+
self.file_name = resource_lib.get_dl_fname(self.url, self.url_info.checksum)
6063
self.file_path = _DOWNLOAD_DIR / self.file_name
6164

62-
self.url_name = resource_lib.get_dl_fname(url)
65+
self.url_name = resource_lib.get_dl_fname(self.url)
6366
self.url_path = _DOWNLOAD_DIR / self.url_name
6467

6568
self.manual_path = _MANUAL_DIR / name
@@ -91,17 +94,17 @@ class DownloadManagerTest(testing.TestCase, parameterized.TestCase):
9194
def _make_downloader_mock(self):
9295
"""`downloader.download` patch which creates the returns the path."""
9396

94-
def _download(url, tmpdir_path, verify):
97+
def _download(url: str, tmpdir_path: epath.Path, verify: bool):
9598
del verify
9699
self.downloaded_urls.append(url) # Record downloader.download() calls
97100
# If the name isn't explicitly provided, then it is extracted from the
98101
# url.
99102
filename = self.dl_fnames.get(url, os.path.basename(url))
100103
# Save the file in the tmp_dir
101-
path = os.path.join(tmpdir_path, filename)
104+
path = tmpdir_path / filename
102105
self.fs.add_file(path)
103106
dl_result = downloader.DownloadResult(
104-
path=epath.Path(path),
107+
path=path,
105108
url_info=self.dl_results[url],
106109
)
107110
return promise.Promise.resolve(dl_result)
@@ -224,7 +227,7 @@ def test_manually_downloaded(self):
224227
a, b = [Artifact(i) for i in 'ab']
225228

226229
# File a is manually downloaded
227-
self.fs.add_file(a.manual_path)
230+
self.fs.add_file(a.manual_path, content=a.content)
228231
self.fs.add_file(b.file_path)
229232

230233
self.dl_results[b.url] = b.url_info
@@ -298,8 +301,8 @@ def test_download_and_extract(self):
298301
b.url: b.url_info,
299302
}
300303
)
301-
res = manager.download_and_extract({'a': a.url, 'b': b.url})
302-
self.assertEqual(res, {'a': a.extract_path, 'b': b.file_path})
304+
res = manager.download_and_extract({a.name: a.url, b.name: b.url})
305+
self.assertEqual(res, {a.name: a.extract_path, b.name: b.file_path})
303306

304307
def test_download_and_extract_no_manual_dir(self):
305308
a, b = Artifact('a.zip'), Artifact('b')
@@ -316,8 +319,8 @@ def test_download_and_extract_no_manual_dir(self):
316319
b.url: b.url_info,
317320
},
318321
)
319-
res = manager.download_and_extract({'a': a.url, 'b': b.url})
320-
self.assertEqual(res, {'a': a.extract_path, 'b': b.file_path})
322+
res = manager.download_and_extract({a.name: a.url, b.name: b.url})
323+
self.assertEqual(res, {a.name: a.extract_path, b.name: b.file_path})
321324

322325
def test_download_and_extract_archive_ext_in_fname(self):
323326
# Make sure extraction method is properly deduced from original fname, and
@@ -582,7 +585,7 @@ def test_register_checksums_url_info_already_exists(self):
582585

583586
def test_download_cached_url_path_checksum_updated(self):
584587
old_a = Artifact('a.tar.gz')
585-
new_a = Artifact('a.tar.gz', content='New a content') # New checksums
588+
new_a = Artifact('a.tar.gz', content='New content') # New checksums
586589

587590
# Urls are equals, but not checksums
588591
self.assertEqual(old_a.url, new_a.url)

tensorflow_datasets/core/download/downloader.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@
5252

5353
@dataclasses.dataclass(eq=False, frozen=True)
5454
class DownloadResult:
55-
path: epath.Path | None
56-
url_info: checksums_lib.UrlInfo | None
55+
path: epath.Path
56+
url_info: checksums_lib.UrlInfo
5757

5858

5959
@utils.memoize()
@@ -80,7 +80,7 @@ def get_cached_path(
8080
checksum_path: epath.Path | None,
8181
url_path: epath.Path,
8282
expected_url_info: checksums_lib.UrlInfo | None,
83-
) -> DownloadResult:
83+
) -> DownloadResult | None:
8484
"""Returns the downloaded path and computed url-info.
8585
8686
If the path is not cached, or that `url_path` does not match checksums,
@@ -96,7 +96,10 @@ def get_cached_path(
9696
"""
9797
# User has manually downloaded the file.
9898
if manually_downloaded_path and manually_downloaded_path.exists():
99-
return DownloadResult(path=manually_downloaded_path, url_info=None)
99+
computed_url_info = checksums_lib.compute_url_info(manually_downloaded_path)
100+
return DownloadResult(
101+
path=manually_downloaded_path, url_info=computed_url_info
102+
)
100103

101104
# Download has been cached (checksum known)
102105
elif checksum_path and resource_lib.Resource.exists_locally(checksum_path):
@@ -110,13 +113,13 @@ def get_cached_path(
110113
# If checksums are now registered but do not match, trigger a new
111114
# download (e.g. previous file corrupted, checksums updated)
112115
if expected_url_info and computed_url_info != expected_url_info:
113-
return DownloadResult(path=None, url_info=None)
116+
return None
114117
else:
115118
return DownloadResult(path=url_path, url_info=computed_url_info)
116119

117120
# Else file not found (or has bad checksums). (re)download.
118121
else:
119-
return DownloadResult(path=None, url_info=None)
122+
return None
120123

121124

122125
def _filename_from_content_disposition(
@@ -216,13 +219,12 @@ def tqdm(self) -> Iterator[None]:
216219
self._pbar_dl_size = pbar_dl_size
217220
yield
218221

219-
def increase_tqdm(self, dl_result: DownloadResult) -> None:
220-
"""Update the tqdm bars to visually indicate the dl_result is downloaded."""
222+
def increase_tqdm(self, url_info: checksums_lib.UrlInfo) -> None:
223+
"""Update the tqdm bars to visually indicate the url_info is downloaded."""
221224
self._pbar_url.update_total(1)
222225
self._pbar_url.update(1)
223-
if dl_result.url_info: # Info unknown for manually downloaded files
224-
self._pbar_dl_size.update_total(dl_result.url_info.size)
225-
self._pbar_dl_size.update(dl_result.url_info.size)
226+
self._pbar_dl_size.update_total(url_info.size)
227+
self._pbar_dl_size.update(url_info.size)
226228

227229
def download(
228230
self, url: str, destination_path: epath.Path, verify: bool = True

tensorflow_datasets/testing/test_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def _validate_out(self, out):
202202
def add_file(self, path, content=None) -> None:
203203
"""Add a file, creating all parent directories."""
204204
path = os.fspath(path)
205-
content = f'Content of {path}' if content is None else content
205+
content = content or f'Content of {path}'
206206
fpath = self._to_tmp(path)
207207
fpath.parent.mkdir(parents=True, exist_ok=True) # pytype: disable=attribute-error
208208
fpath.write_text(content) # pytype: disable=attribute-error

0 commit comments

Comments
 (0)