@@ -61,7 +61,7 @@ def get_downloader(*args: Any, **kwargs: Any) -> '_Downloader':
61
61
return _Downloader (* args , ** kwargs )
62
62
63
63
64
- def _read_url_info (url_path : epath .PathLike ) -> checksums_lib .UrlInfo :
64
+ def _read_url_info (url_path : epath .Path ) -> checksums_lib .UrlInfo :
65
65
"""Loads the `UrlInfo` from the `.INFO` file."""
66
66
file_info = resource_lib .read_info_file (url_path )
67
67
if 'url_info' not in file_info :
@@ -172,7 +172,7 @@ def _get_filename(response: Response) -> str:
172
172
return utils .basename_from_url (response .url )
173
173
174
174
175
- class _Downloader ( object ) :
175
+ class _Downloader :
176
176
"""Class providing async download API with checksum validation.
177
177
178
178
Do not instantiate this class directly. Instead, call `get_downloader()`.
@@ -192,9 +192,8 @@ def __init__(
192
192
"""Init _Downloader instance.
193
193
194
194
Args:
195
- max_simultaneous_downloads: `int`, optional max number of simultaneous
196
- downloads. If None then it defaults to
197
- `self._DEFAULT_MAX_SIMULTANEOUS_DOWNLOADS`.
195
+ max_simultaneous_downloads: Optional max number of simultaneous downloads.
196
+ If None then it defaults to `self._DEFAULT_MAX_SIMULTANEOUS_DOWNLOADS`.
198
197
checksumer: `hashlib.HASH`. Defaults to `hashlib.sha256`.
199
198
"""
200
199
self ._executor = concurrent .futures .ThreadPoolExecutor (
@@ -227,18 +226,18 @@ def increase_tqdm(self, dl_result: DownloadResult) -> None:
227
226
228
227
def download (
229
228
self , url : str , destination_path : str , verify : bool = True
230
- ) -> ' promise.Promise[concurrent.futures.Future[DownloadResult]]' :
229
+ ) -> promise .Promise [concurrent .futures .Future [DownloadResult ]]:
231
230
"""Download url to given path.
232
231
233
232
Returns Promise -> sha256 of downloaded file.
234
233
235
234
Args:
236
- url: address of resource to download.
237
- destination_path: `str`, path to directory where to download the resource.
238
- verify: whether to verify ssl certificates
235
+ url: Address of resource to download.
236
+ destination_path: Path to directory where to download the resource.
237
+ verify: Whether to verify ssl certificates
239
238
240
239
Returns:
241
- Promise obj -> (`str`, int): (downloaded object checksum, size in bytes) .
240
+ Promise obj -> Download result .
242
241
"""
243
242
destination_path = os .fspath (destination_path )
244
243
self ._pbar_url .update_total (1 )
@@ -250,19 +249,19 @@ def download(
250
249
def _sync_file_copy (
251
250
self ,
252
251
filepath : str ,
253
- destination_path : str ,
252
+ destination_path : epath . Path ,
254
253
) -> DownloadResult :
255
254
"""Downloads the file through `tf.io.gfile` API."""
256
255
filename = os .path .basename (filepath )
257
- out_path = os . path . join ( destination_path , filename )
256
+ out_path = destination_path / filename
258
257
tf .io .gfile .copy (filepath , out_path )
259
258
url_info = checksums_lib .compute_url_info (
260
259
out_path , checksum_cls = self ._checksumer_cls
261
260
)
262
261
self ._pbar_dl_size .update_total (url_info .size )
263
262
self ._pbar_dl_size .update (url_info .size )
264
263
self ._pbar_url .update (1 )
265
- return DownloadResult (path = epath . Path ( out_path ) , url_info = url_info )
264
+ return DownloadResult (path = out_path , url_info = url_info )
266
265
267
266
def _sync_download (
268
267
self , url : str , destination_path : str , verify : bool = True
@@ -275,16 +274,17 @@ def _sync_download(
275
274
https://requests.readthedocs.io/en/master/user/advanced/#proxies
276
275
277
276
Args:
278
- url: url to download
279
- destination_path: path where to write it
280
- verify: whether to verify ssl certificates
277
+ url: Url to download.
278
+ destination_path: Path where to write it.
279
+ verify: Whether to verify ssl certificates.
281
280
282
281
Returns:
283
- None
282
+ Download result.
284
283
285
284
Raises:
286
285
DownloadError: when download fails.
287
286
"""
287
+ destination_path = epath .Path (destination_path )
288
288
try :
289
289
# If url is on a filesystem that gfile understands, use copy. Otherwise,
290
290
# use requests (http) or urllib (ftp).
@@ -295,15 +295,15 @@ def _sync_download(
295
295
296
296
with _open_url (url , verify = verify ) as (response , iter_content ):
297
297
fname = _get_filename (response )
298
- path = os . path . join ( destination_path , fname )
298
+ path = destination_path / fname
299
299
size = 0
300
300
301
301
# Initialize the download size progress bar
302
302
size_mb = 0
303
303
unit_mb = units .MiB
304
304
total_size = int (response .headers .get ('Content-length' , 0 )) // unit_mb
305
305
self ._pbar_dl_size .update_total (total_size )
306
- with tf . io . gfile . GFile ( path , 'wb' ) as file_ :
306
+ with path . open ( 'wb' ) as file_ :
307
307
checksum = self ._checksumer_cls ()
308
308
for block in iter_content :
309
309
size += len (block )
@@ -317,7 +317,7 @@ def _sync_download(
317
317
size_mb %= unit_mb
318
318
self ._pbar_url .update (1 )
319
319
return DownloadResult (
320
- path = epath . Path ( path ) ,
320
+ path = path ,
321
321
url_info = checksums_lib .UrlInfo (
322
322
checksum = checksum .hexdigest (),
323
323
size = utils .Size (size ),
0 commit comments