Skip to content

Commit a691e0a

Browse files
fineguyThe TensorFlow Datasets Authors
authored andcommitted
Fix typehints in download modules.
PiperOrigin-RevId: 672965431
1 parent 1b8b37a commit a691e0a

File tree

5 files changed

+76
-100
lines changed

5 files changed

+76
-100
lines changed

tensorflow_datasets/core/download/download_manager.py

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def replace(self, **kwargs: Any) -> DownloadConfig:
140140
return dataclasses.replace(self, **kwargs)
141141

142142

143-
class DownloadManager(object):
143+
class DownloadManager:
144144
"""Manages the download and extraction of files, as well as caching.
145145
146146
Downloaded files are cached under `download_dir`. The file name of downloaded
@@ -353,8 +353,9 @@ def _download(self, resource: Url) -> promise.Promise[epath.Path]:
353353
"""
354354
# Normalize the input
355355
if isinstance(resource, str):
356-
resource = resource_lib.Resource(url=resource)
357-
url = resource.url
356+
url = resource
357+
else:
358+
url = resource.url
358359
assert url is not None, 'URL is undefined from resource.'
359360

360361
expected_url_info = self._url_infos.get(url)
@@ -500,7 +501,7 @@ def _rename_and_get_final_dl_path(
500501
elif path == url_path:
501502
if checksum_path:
502503
# Checksums were registered: Rename -> checksums_path
503-
resource_lib.rename_info_file(path, checksum_path, overwrite=True)
504+
resource_lib.replace_info_file(path, checksum_path)
504505
return path.replace(checksum_path)
505506
else:
506507
# Checksums not registered: -> do nothing
@@ -522,7 +523,7 @@ def _rename_and_get_final_dl_path(
522523
@utils.memoize()
523524
def _extract(self, resource: ExtractPath) -> promise.Promise[epath.Path]:
524525
"""Extract a single archive, returns Promise->path to extraction result."""
525-
if isinstance(resource, epath.PathLikeCls):
526+
if not isinstance(resource, resource_lib.Resource):
526527
resource = resource_lib.Resource(path=resource)
527528
path = resource.path
528529
extract_method = resource.extract_method
@@ -613,7 +614,7 @@ def iter_archive(
613614
Returns:
614615
Generator yielding tuple (path_within_archive, file_obj).
615616
"""
616-
if isinstance(resource, epath.PathLikeCls):
617+
if not isinstance(resource, resource_lib.Resource):
617618
resource = resource_lib.Resource(path=resource)
618619
return extractor.iter_archive(resource.path, resource.extract_method)
619620

@@ -763,20 +764,6 @@ def _validate_checksums(
763764
raise NonMatchingChecksumError(msg)
764765

765766

766-
def _read_url_info(url_path: epath.PathLike) -> checksums.UrlInfo:
767-
"""Loads the `UrlInfo` from the `.INFO` file."""
768-
file_info = resource_lib.read_info_file(url_path)
769-
if 'url_info' not in file_info:
770-
raise ValueError(
771-
'Could not find `url_info` in {}. This likely indicates that '
772-
'the files where downloaded with a previous version of TFDS (<=3.1.0). '
773-
)
774-
url_info = file_info['url_info']
775-
url_info.setdefault('filename', None)
776-
url_info['size'] = utils.Size(url_info['size'])
777-
return checksums.UrlInfo(**url_info)
778-
779-
780767
def _map_promise(map_fn, all_inputs):
781768
"""Map the function into each element and resolve the promise."""
782769
all_promises = tree.map_structure(map_fn, all_inputs) # Apply the function

tensorflow_datasets/core/download/download_manager_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def __setitem__(self, key, value):
6464
return super().__setitem__(os.fspath(key), epath.Path(value))
6565

6666

67-
class Artifact(object):
67+
class Artifact:
6868
# For testing only.
6969

7070
def __init__(self, name, url=None, content=None):

tensorflow_datasets/core/download/downloader.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def get_downloader(*args: Any, **kwargs: Any) -> '_Downloader':
6161
return _Downloader(*args, **kwargs)
6262

6363

64-
def _read_url_info(url_path: epath.PathLike) -> checksums_lib.UrlInfo:
64+
def _read_url_info(url_path: epath.Path) -> checksums_lib.UrlInfo:
6565
"""Loads the `UrlInfo` from the `.INFO` file."""
6666
file_info = resource_lib.read_info_file(url_path)
6767
if 'url_info' not in file_info:
@@ -172,7 +172,7 @@ def _get_filename(response: Response) -> str:
172172
return utils.basename_from_url(response.url)
173173

174174

175-
class _Downloader(object):
175+
class _Downloader:
176176
"""Class providing async download API with checksum validation.
177177
178178
Do not instantiate this class directly. Instead, call `get_downloader()`.
@@ -192,9 +192,8 @@ def __init__(
192192
"""Init _Downloader instance.
193193
194194
Args:
195-
max_simultaneous_downloads: `int`, optional max number of simultaneous
196-
downloads. If None then it defaults to
197-
`self._DEFAULT_MAX_SIMULTANEOUS_DOWNLOADS`.
195+
max_simultaneous_downloads: Optional max number of simultaneous downloads.
196+
If None then it defaults to `self._DEFAULT_MAX_SIMULTANEOUS_DOWNLOADS`.
198197
checksumer: `hashlib.HASH`. Defaults to `hashlib.sha256`.
199198
"""
200199
self._executor = concurrent.futures.ThreadPoolExecutor(
@@ -227,18 +226,18 @@ def increase_tqdm(self, dl_result: DownloadResult) -> None:
227226

228227
def download(
229228
self, url: str, destination_path: str, verify: bool = True
230-
) -> 'promise.Promise[concurrent.futures.Future[DownloadResult]]':
229+
) -> promise.Promise[concurrent.futures.Future[DownloadResult]]:
231230
"""Download url to given path.
232231
233232
Returns Promise -> sha256 of downloaded file.
234233
235234
Args:
236-
url: address of resource to download.
237-
destination_path: `str`, path to directory where to download the resource.
238-
verify: whether to verify ssl certificates
235+
url: Address of resource to download.
236+
destination_path: Path to directory where to download the resource.
237+
verify: Whether to verify ssl certificates
239238
240239
Returns:
241-
Promise obj -> (`str`, int): (downloaded object checksum, size in bytes).
240+
Promise obj -> Download result.
242241
"""
243242
destination_path = os.fspath(destination_path)
244243
self._pbar_url.update_total(1)
@@ -250,19 +249,19 @@ def download(
250249
def _sync_file_copy(
251250
self,
252251
filepath: str,
253-
destination_path: str,
252+
destination_path: epath.Path,
254253
) -> DownloadResult:
255254
"""Downloads the file through `tf.io.gfile` API."""
256255
filename = os.path.basename(filepath)
257-
out_path = os.path.join(destination_path, filename)
256+
out_path = destination_path / filename
258257
tf.io.gfile.copy(filepath, out_path)
259258
url_info = checksums_lib.compute_url_info(
260259
out_path, checksum_cls=self._checksumer_cls
261260
)
262261
self._pbar_dl_size.update_total(url_info.size)
263262
self._pbar_dl_size.update(url_info.size)
264263
self._pbar_url.update(1)
265-
return DownloadResult(path=epath.Path(out_path), url_info=url_info)
264+
return DownloadResult(path=out_path, url_info=url_info)
266265

267266
def _sync_download(
268267
self, url: str, destination_path: str, verify: bool = True
@@ -275,16 +274,17 @@ def _sync_download(
275274
https://requests.readthedocs.io/en/master/user/advanced/#proxies
276275
277276
Args:
278-
url: url to download
279-
destination_path: path where to write it
280-
verify: whether to verify ssl certificates
277+
url: Url to download.
278+
destination_path: Path where to write it.
279+
verify: Whether to verify ssl certificates.
281280
282281
Returns:
283-
None
282+
Download result.
284283
285284
Raises:
286285
DownloadError: when download fails.
287286
"""
287+
destination_path = epath.Path(destination_path)
288288
try:
289289
# If url is on a filesystem that gfile understands, use copy. Otherwise,
290290
# use requests (http) or urllib (ftp).
@@ -295,15 +295,15 @@ def _sync_download(
295295

296296
with _open_url(url, verify=verify) as (response, iter_content):
297297
fname = _get_filename(response)
298-
path = os.path.join(destination_path, fname)
298+
path = destination_path / fname
299299
size = 0
300300

301301
# Initialize the download size progress bar
302302
size_mb = 0
303303
unit_mb = units.MiB
304304
total_size = int(response.headers.get('Content-length', 0)) // unit_mb
305305
self._pbar_dl_size.update_total(total_size)
306-
with tf.io.gfile.GFile(path, 'wb') as file_:
306+
with path.open('wb') as file_:
307307
checksum = self._checksumer_cls()
308308
for block in iter_content:
309309
size += len(block)
@@ -317,7 +317,7 @@ def _sync_download(
317317
size_mb %= unit_mb
318318
self._pbar_url.update(1)
319319
return DownloadResult(
320-
path=epath.Path(path),
320+
path=path,
321321
url_info=checksums_lib.UrlInfo(
322322
checksum=checksum.hexdigest(),
323323
size=utils.Size(size),

tensorflow_datasets/core/download/downloader_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from tensorflow_datasets.core.download import util
3131

3232

33-
class _FakeResponse(object):
33+
class _FakeResponse:
3434

3535
def __init__(self, url, content, cookies=None, headers=None, status_code=200):
3636
self.url = url

0 commit comments

Comments
 (0)