
Commit 43c7381

fineguy authored and The TensorFlow Datasets Authors committed
Simplify resource.write_info_file
* Use `dict` instead of `Mapping` for `Json` type.
* Use `_add_value_to_list()` and `_set_value()` to update `info` dict.
* Use `epath.Path.replace()` instead of `unlink()` and `rename()`.

PiperOrigin-RevId: 683157254
1 parent 7b33592 commit 43c7381
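An editorial note on the first bullet (my reading, not stated in the commit): the new helpers mutate `info` in place, and `collections.abc.Mapping` does not declare `__setitem__`, so `dict` is the annotation that type-checks. A minimal sketch, with a hypothetical URL:

    from collections.abc import Mapping
    from typing import Any

    Json = dict[str, Any]  # the new alias: mutable, so item assignment type-checks

    def record_url(info: Json) -> None:
      info['urls'] = ['https://example.com/a.zip']  # fine: dict has __setitem__

    info: Json = {}
    record_url(info)

    read_only: Mapping[str, Any] = {}
    # read_only['urls'] = []  # rejected by type checkers: Mapping has no __setitem__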

3 files changed: 27 additions & 36 deletions

tensorflow_datasets/core/download/download_manager_test.py

Lines changed: 1 addition & 1 deletion

@@ -553,7 +553,7 @@ def test_download_url_info_in_info_file_missmatch(self):
         register_checksums=False,
         force_download=True,
     )
-    with self.assertRaisesRegex(ValueError, 'contains a different checksum'):
+    with self.assertRaisesRegex(ValueError, 'contains a different "url_info"'):
       dl_manager.download(a.url)

     # If the url is re-downloaded with the same hash, no error is raised

tensorflow_datasets/core/download/resource.py

Lines changed: 25 additions & 32 deletions

@@ -17,7 +17,6 @@

 import base64
 import codecs
-from collections.abc import Mapping
 import enum
 import itertools
 import json
@@ -30,8 +29,7 @@
 from tensorflow_datasets.core import utils
 from tensorflow_datasets.core.download import checksums as checksums_lib

-# Should be `Union[int, float, bool, str, Dict[str, Json], List[Json]]`
-Json = Mapping[str, Any]
+Json = dict[str, Any]

 _hex_codec = codecs.getdecoder('hex_codec')

@@ -205,9 +203,7 @@ def is_locally_cached(path: epath.Path) -> bool:

 def _read_info(info_path: epath.Path) -> Json:
   """Returns info dict."""
-  if not info_path.exists():
-    return {}
-  return json.loads(info_path.read_text())
+  return json.loads(info_path.read_text()) if info_path.exists() else {}


 # TODO(pierrot): one lock per info path instead of locking everything.
@@ -223,6 +219,22 @@ def read_info_file(info_path: epath.Path) -> Json:
   return _read_info(get_info_path(info_path))


+def _add_value_to_list(info: Json, key: str, value: str) -> None:
+  """Adds `value` to list `key` in `info` dict."""
+  if value and value not in (stored_values := info.get(key, [])):
+    info[key] = stored_values + [value]
+
+
+def _set_value(info_path: epath.Path, info: Json, key: str, value: Any) -> None:
+  """Sets `value` to `key` in `info` dict."""
+  if (stored_value := info.get(key)) and stored_value != value:
+    raise ValueError(
+        f'File info "{info_path}" contains a different "{key}" than the'
+        f' downloaded one: Stored: {stored_value}; Expected: {value}'
+    )
+  info[key] = value
+
+
 @synchronize_decorator
 def write_info_file(
     url: str,
@@ -244,40 +256,21 @@ def write_info_file(
     original_fname: name of file as downloaded.
     url_info: checksums/size info of the url
   """
-  url_info_dict = url_info.asdict()
   info_path = get_info_path(path)
   info = _read_info(info_path)
-  urls = set(info.get('urls', []) + [url])
-  dataset_names = info.get('dataset_names', [])
-  if dataset_name:
-    dataset_names.append(dataset_name)
-  if info.get('original_fname', original_fname) != original_fname:
-    raise ValueError(
-        '`original_fname` "{}" stored in {} does NOT match "{}".'.format(
-            info['original_fname'], info_path, original_fname
-        )
-    )
-  if info.get('url_info', url_info_dict) != url_info_dict:
-    raise ValueError(
-        'File info {} contains a different checksum that the downloaded one: '
-        'Stored: {}; Expected: {}'.format(
-            info_path, info['url_info'], url_info_dict
-        )
-    )
-  info = dict(
-      urls=list(urls),
-      dataset_names=list(set(dataset_names)),
-      original_fname=original_fname,
-      url_info=url_info_dict,
-  )
+
+  _add_value_to_list(info, 'urls', url)
+  _add_value_to_list(info, 'dataset_names', dataset_name)
+  _set_value(info_path, info, 'original_fname', original_fname)
+  _set_value(info_path, info, 'url_info', url_info.asdict())
+
   with utils.atomic_write(info_path, 'w') as info_f:
     json.dump(info, info_f, sort_keys=True)


 def _get_extract_method(path: epath.Path) -> ExtractMethod:
   """Returns `ExtractMethod` to use on resource at path. Cannot be None."""
-  info_path = get_info_path(path)
-  info = _read_info(info_path)
+  info = _read_info(get_info_path(path))
   fname = info.get('original_fname', os.fspath(path))
   return guess_extract_method(fname)
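For illustration, a self-contained sketch of how the two new helpers behave. The bodies are copied from the hunk above; `info_path` is typed loosely to avoid the `etils.epath` dependency, and the file and URL names are hypothetical:

    from typing import Any

    Json = dict[str, Any]

    def _add_value_to_list(info: Json, key: str, value: str) -> None:
      """Adds `value` to list `key` in `info` dict (skips falsy or duplicate values)."""
      if value and value not in (stored_values := info.get(key, [])):
        info[key] = stored_values + [value]

    def _set_value(info_path: Any, info: Json, key: str, value: Any) -> None:
      """Sets `value` to `key` in `info` dict, rejecting conflicting updates."""
      if (stored_value := info.get(key)) and stored_value != value:
        raise ValueError(
            f'File info "{info_path}" contains a different "{key}" than the'
            f' downloaded one: Stored: {stored_value}; Expected: {value}'
        )
      info[key] = value

    info: Json = {}
    _add_value_to_list(info, 'urls', 'https://example.com/a.zip')
    _add_value_to_list(info, 'urls', 'https://example.com/a.zip')  # duplicate: skipped
    _add_value_to_list(info, 'dataset_names', '')                  # falsy: skipped
    _set_value('info.json', info, 'original_fname', 'a.zip')
    _set_value('info.json', info, 'original_fname', 'a.zip')       # same value: ok
    assert info == {'urls': ['https://example.com/a.zip'], 'original_fname': 'a.zip'}
    try:
      _set_value('info.json', info, 'original_fname', 'b.zip')     # conflict
    except ValueError as e:
      print(e)  # File info "info.json" contains a different "original_fname" ...

This is also why the test above changes its expected message: the per-key checks that used to raise bespoke errors in `write_info_file` now all go through `_set_value`, whose message names the conflicting key (here `"url_info"`) instead of saying "checksum".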

tensorflow_datasets/core/utils/py_utils.py

Lines changed: 1 addition & 3 deletions
@@ -361,12 +361,10 @@ def is_incomplete_file(path: epath.Path) -> bool:
 @contextlib.contextmanager
 def atomic_write(path: epath.PathLike, mode: str):
   """Writes to path atomically, by writing to temp file and renaming it."""
-  path = epath.Path(path)
   tmp_path = _tmp_file_name(path)
   with tmp_path.open(mode=mode) as file_:
     yield file_
-  path.unlink(missing_ok=True)
-  tmp_path.rename(path)
+  tmp_path.replace(path)


 def reraise(
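The `atomic_write` simplification leans on rename-over semantics: `Path.replace()` moves the temp file over an existing destination in one call, whereas the old `unlink()` + `rename()` pair briefly left no file at `path`. A minimal sketch using stdlib `pathlib`, whose API `epath.Path` mirrors here (the file names are made up):

    import pathlib
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_dir:
      path = pathlib.Path(tmp_dir) / 'info.json'
      path.write_text('{"old": true}')

      # Write the new content to a sibling temp file first...
      tmp_path = path.with_name(path.name + '.incomplete')
      tmp_path.write_text('{"new": true}')

      # ...then move it into place. replace() succeeds even though `path`
      # already exists, and on POSIX the rename is atomic within a single
      # filesystem, so readers see the old file or the new one, never neither.
      tmp_path.replace(path)
      assert path.read_text() == '{"new": true}'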
