Skip to content

Commit fc31737

Browse files
fineguyThe TensorFlow Datasets Authors
authored andcommitted
Use epath.Path in downloader.py
PiperOrigin-RevId: 676837944
1 parent 3b0dab2 commit fc31737

File tree

3 files changed

+75
-87
lines changed

3 files changed

+75
-87
lines changed

tensorflow_datasets/core/download/download_manager.py

Lines changed: 58 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -297,21 +297,21 @@ def __getstate__(self):
297297
return state
298298

299299
@property
300-
def _downloader(self):
300+
def _downloader(self) -> downloader._Downloader:
301301
if not self.__downloader:
302302
self.__downloader = get_downloader(
303303
max_simultaneous_downloads=self._max_simultaneous_downloads
304304
)
305305
return self.__downloader
306306

307307
@property
308-
def _extractor(self):
308+
def _extractor(self) -> extractor._Extractor:
309309
if not self.__extractor:
310310
self.__extractor = extractor.get_extractor()
311311
return self.__extractor
312312

313313
@property
314-
def downloaded_size(self):
314+
def downloaded_size(self) -> int:
315315
"""Returns the total size of downloaded files."""
316316
return sum(url_info.size for url_info in self._recorded_url_infos.values())
317317

@@ -331,6 +331,22 @@ def _record_url_infos(self):
331331
self._recorded_url_infos,
332332
)
333333

334+
def _get_manually_downloaded_path(
335+
self, expected_url_info: checksums.UrlInfo | None
336+
) -> epath.Path | None:
337+
"""Checks if file is already downloaded in manual_dir."""
338+
if not self._manual_dir: # Manual dir not passed
339+
return None
340+
341+
if not expected_url_info or not expected_url_info.filename:
342+
return None # Filename unknown.
343+
344+
manual_path = self._manual_dir / expected_url_info.filename
345+
if not manual_path.exists(): # File not manually downloaded
346+
return None
347+
348+
return manual_path
349+
334350
# Synchronize and memoize decorators ensure same resource will only be
335351
# processed once, even if passed twice to download_manager.
336352
@utils.build_synchronize_decorator()
@@ -363,9 +379,8 @@ def _download(self, resource: Url) -> promise.Promise[epath.Path]:
363379
# * In `manual_dir` (manually downloaded data)
364380
# * In `downloads/url_path` (checksum unknown)
365381
# * In `downloads/checksum_path` (checksum registered)
366-
manually_downloaded_path = _get_manually_downloaded_path(
367-
manual_dir=self._manual_dir,
368-
expected_url_info=expected_url_info,
382+
manually_downloaded_path = self._get_manually_downloaded_path(
383+
expected_url_info=expected_url_info
369384
)
370385
url_path = self._get_dl_path(url)
371386
checksum_path = (
@@ -459,12 +474,11 @@ def _register_or_validate_checksums(
459474
# the download isn't cached (re-running build will retrigger a new
460475
# download). This is expected as it might mean the downloaded file
461476
# was corrupted. Note: The tmp file isn't deleted to allow inspection.
462-
_validate_checksums(
477+
self._validate_checksums(
463478
url=url,
464479
path=path,
465480
expected_url_info=expected_url_info,
466481
computed_url_info=computed_url_info,
467-
force_checksums_validation=self._force_checksums_validation,
468482
)
469483

470484
return self._rename_and_get_final_dl_path(
@@ -476,6 +490,42 @@ def _register_or_validate_checksums(
476490
url_path=url_path,
477491
)
478492

493+
def _validate_checksums(
494+
self,
495+
url: str,
496+
path: epath.Path,
497+
computed_url_info: checksums.UrlInfo | None,
498+
expected_url_info: checksums.UrlInfo | None,
499+
) -> None:
500+
"""Validate computed_url_info match expected_url_info."""
501+
# If force-checksums validations, both expected and computed url_info
502+
# should exists
503+
if self._force_checksums_validation:
504+
# Checksum of the downloaded file unknown (for manually downloaded file)
505+
if not computed_url_info:
506+
computed_url_info = checksums.compute_url_info(path)
507+
# Checksums have not been registered
508+
if not expected_url_info:
509+
raise ValueError(
510+
f'Missing checksums url: {url}, yet '
511+
'`force_checksums_validation=True`. '
512+
'Did you forget to register checksums?'
513+
)
514+
515+
if (
516+
expected_url_info
517+
and computed_url_info
518+
and expected_url_info != computed_url_info
519+
):
520+
msg = (
521+
f'Artifact {url}, downloaded to {path}, has wrong checksum:\n'
522+
f'* Expected: {expected_url_info}\n'
523+
f'* Got: {computed_url_info}\n'
524+
'To debug, see: '
525+
'https://www.tensorflow.org/datasets/overview#fixing_nonmatchingchecksumerror'
526+
)
527+
raise NonMatchingChecksumError(msg)
528+
479529
def _rename_and_get_final_dl_path(
480530
self,
481531
url: str,
@@ -707,61 +757,6 @@ def manual_dir(self) -> epath.Path:
707757
return self._manual_dir
708758

709759

710-
def _get_manually_downloaded_path(
711-
manual_dir: epath.Path | None,
712-
expected_url_info: checksums.UrlInfo | None,
713-
) -> epath.Path | None:
714-
"""Checks if file is already downloaded in manual_dir."""
715-
if not manual_dir: # Manual dir not passed
716-
return None
717-
718-
if not expected_url_info or not expected_url_info.filename:
719-
return None # Filename unknown.
720-
721-
manual_path = manual_dir / expected_url_info.filename
722-
if not manual_path.exists(): # File not manually downloaded
723-
return None
724-
725-
return manual_path
726-
727-
728-
def _validate_checksums(
729-
url: str,
730-
path: epath.Path,
731-
computed_url_info: checksums.UrlInfo | None,
732-
expected_url_info: checksums.UrlInfo | None,
733-
force_checksums_validation: bool,
734-
) -> None:
735-
"""Validate computed_url_info match expected_url_info."""
736-
# If force-checksums validations, both expected and computed url_info
737-
# should exists
738-
if force_checksums_validation:
739-
# Checksum of the downloaded file unknown (for manually downloaded file)
740-
if not computed_url_info:
741-
computed_url_info = checksums.compute_url_info(path)
742-
# Checksums have not been registered
743-
if not expected_url_info:
744-
raise ValueError(
745-
f'Missing checksums url: {url}, yet '
746-
'`force_checksums_validation=True`. '
747-
'Did you forget to register checksums?'
748-
)
749-
750-
if (
751-
expected_url_info
752-
and computed_url_info
753-
and expected_url_info != computed_url_info
754-
):
755-
msg = (
756-
f'Artifact {url}, downloaded to {path}, has wrong checksum:\n'
757-
f'* Expected: {expected_url_info}\n'
758-
f'* Got: {computed_url_info}\n'
759-
'To debug, see: '
760-
'https://www.tensorflow.org/datasets/overview#fixing_nonmatchingchecksumerror'
761-
)
762-
raise NonMatchingChecksumError(msg)
763-
764-
765760
def _map_promise(map_fn, all_inputs):
766761
"""Map the function into each element and resolve the promise."""
767762
all_promises = tree.map_structure(map_fn, all_inputs) # Apply the function

tensorflow_datasets/core/download/downloader.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ def increase_tqdm(self, dl_result: DownloadResult) -> None:
225225
self._pbar_dl_size.update(dl_result.url_info.size)
226226

227227
def download(
228-
self, url: str, destination_path: str, verify: bool = True
228+
self, url: str, destination_path: epath.Path, verify: bool = True
229229
) -> promise.Promise[concurrent.futures.Future[DownloadResult]]:
230230
"""Download url to given path.
231231
@@ -239,7 +239,6 @@ def download(
239239
Returns:
240240
Promise obj -> Download result.
241241
"""
242-
destination_path = os.fspath(destination_path)
243242
self._pbar_url.update_total(1)
244243
future = self._executor.submit(
245244
self._sync_download, url, destination_path, verify
@@ -264,7 +263,7 @@ def _sync_file_copy(
264263
return DownloadResult(path=out_path, url_info=url_info)
265264

266265
def _sync_download(
267-
self, url: str, destination_path: str, verify: bool = True
266+
self, url: str, destination_path: epath.Path, verify: bool = True
268267
) -> DownloadResult:
269268
"""Synchronous version of `download` method.
270269
@@ -284,7 +283,6 @@ def _sync_download(
284283
Raises:
285284
DownloadError: when download fails.
286285
"""
287-
destination_path = epath.Path(destination_path)
288286
try:
289287
# If url is on a filesystem that gfile understands, use copy. Otherwise,
290288
# use requests (http) or urllib (ftp).

tensorflow_datasets/core/download/downloader_test.py

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,13 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
"""Tests for downloader."""
17-
1816
import hashlib
1917
import io
20-
import os
21-
import tempfile
2218
from typing import Optional
2319
from unittest import mock
2420

21+
from etils import epath
2522
import pytest
26-
import tensorflow as tf
2723
from tensorflow_datasets import testing
2824
from tensorflow_datasets.core.download import downloader
2925
from tensorflow_datasets.core.download import resource as resource_lib
@@ -59,11 +55,13 @@ def setUp(self):
5955
super(DownloaderTest, self).setUp()
6056
self.addCleanup(mock.patch.stopall)
6157
self.downloader = downloader.get_downloader(10, hashlib.sha256)
62-
self.tmp_dir = tempfile.mkdtemp(dir=tf.compat.v1.test.get_temp_dir())
58+
self.tmp_dir = epath.Path(self.tmp_dir)
6359
self.url = 'http://example.com/foo.tar.gz'
6460
self.resource = resource_lib.Resource(url=self.url)
65-
self.path = os.path.join(self.tmp_dir, 'foo.tar.gz')
66-
self.incomplete_path = '%s.incomplete' % self.path
61+
self.path = self.tmp_dir / 'foo.tar.gz'
62+
self.incomplete_path = self.path.with_suffix(
63+
self.path.suffix + '.incomplete'
64+
)
6765
self.response = b'This \nis an \nawesome\n response!'
6866
self.resp_checksum = hashlib.sha256(self.response).hexdigest()
6967
self.cookies = {}
@@ -84,22 +82,20 @@ def test_ok(self):
8482
promise = self.downloader.download(self.url, self.tmp_dir)
8583
future = promise.get()
8684
url_info = future.url_info
87-
self.assertEqual(self.path, os.fspath(future.path))
85+
self.assertEqual(self.path, future.path)
8886
self.assertEqual(url_info.checksum, self.resp_checksum)
89-
with tf.io.gfile.GFile(self.path, 'rb') as result:
90-
self.assertEqual(result.read(), self.response)
91-
self.assertFalse(tf.io.gfile.exists(self.incomplete_path))
87+
self.assertEqual(self.path.read_bytes(), self.response)
88+
self.assertFalse(self.incomplete_path.exists())
9289

9390
def test_drive_no_cookies(self):
9491
url = 'https://drive.google.com/uc?export=download&id=a1b2bc3'
9592
promise = self.downloader.download(url, self.tmp_dir)
9693
future = promise.get()
9794
url_info = future.url_info
98-
self.assertEqual(self.path, os.fspath(future.path))
95+
self.assertEqual(self.path, future.path)
9996
self.assertEqual(url_info.checksum, self.resp_checksum)
100-
with tf.io.gfile.GFile(self.path, 'rb') as result:
101-
self.assertEqual(result.read(), self.response)
102-
self.assertFalse(tf.io.gfile.exists(self.incomplete_path))
97+
self.assertEqual(self.path.read_bytes(), self.response)
98+
self.assertFalse(self.incomplete_path.exists())
10399

104100
def test_drive(self):
105101
self.cookies = {'foo': 'bar', 'download_warning_a': 'token', 'a': 'b'}
@@ -129,11 +125,10 @@ def test_ftp(self):
129125
promise = self.downloader.download(url, self.tmp_dir)
130126
future = promise.get()
131127
url_info = future.url_info
132-
self.assertEqual(self.path, os.fspath(future.path))
128+
self.assertEqual(self.path, future.path)
133129
self.assertEqual(url_info.checksum, self.resp_checksum)
134-
with tf.io.gfile.GFile(self.path, 'rb') as result:
135-
self.assertEqual(result.read(), self.response)
136-
self.assertFalse(tf.io.gfile.exists(self.incomplete_path))
130+
self.assertEqual(self.path.read_bytes(), self.response)
131+
self.assertFalse(self.incomplete_path.exists())
137132

138133
def test_ftp_error(self):
139134
error = downloader.urllib.error.URLError('Problem serving file.')

0 commit comments

Comments
 (0)