Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 15 additions & 9 deletions tensorflow_datasets/core/utils/py_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ def _tmp_file_name(
path: epath.PathLike,
subfolder: str | None = None,
) -> epath.Path:
"""Returns the temporary file name for the given path.
"""Returns the temporary file path dependent on the given path and subfolder.

Args:
path: The path to the file.
Expand All @@ -322,9 +322,12 @@ def _tmp_file_name(
path = epath.Path(path)
file_name = f'{_tmp_file_prefix()}.{path.name}'
if subfolder:
return path.parent / subfolder / file_name
tmp_path = path.parent / subfolder / file_name
else:
return path.parent / file_name
tmp_path = path.parent / file_name
# Create the parent directory if it doesn't exist.
tmp_path.parent.mkdir(parents=True, exist_ok=True)
return tmp_path


@contextlib.contextmanager
Expand All @@ -334,7 +337,6 @@ def incomplete_file(
) -> Iterator[epath.Path]:
"""Writes to path atomically, by writing to temp file and renaming it."""
tmp_path = _tmp_file_name(path, subfolder=subfolder)
tmp_path.parent.mkdir(exist_ok=True)
try:
yield tmp_path
tmp_path.replace(path)
Expand All @@ -346,20 +348,24 @@ def incomplete_file(
@contextlib.contextmanager
def incomplete_files(
path: epath.Path,
subfolder: str | None = None,
) -> Iterator[epath.Path]:
"""Writes to path atomically, by writing to temp file and renaming it."""
tmp_file_prefix = _tmp_file_prefix()
tmp_path = path.parent / f'{tmp_file_prefix}.{path.name}'
tmp_path = _tmp_file_name(path, subfolder=subfolder)
tmp_file_prefix = tmp_path.name.removesuffix(f'.{path.name}')
try:
yield tmp_path
# Rename all tmp files to their final name.
for tmp_file in path.parent.glob(f'{tmp_file_prefix}.*'):
for tmp_file in tmp_path.parent.glob(f'{tmp_file_prefix}.*'):
file_name = tmp_file.name.removeprefix(tmp_file_prefix + '.')
tmp_file.replace(path.parent / file_name)
finally:
# Eventually delete the tmp_path if exception was raised
for tmp_file in path.parent.glob(f'{tmp_file_prefix}.*'):
tmp_file.unlink(missing_ok=True)
if subfolder:
tmp_path.parent.unlink(missing_ok=True)
else:
for tmp_file in tmp_path.parent.glob(f'{tmp_file_prefix}.*'):
tmp_file.unlink(missing_ok=True)


def is_incomplete_file(path: epath.Path) -> bool:
Expand Down
18 changes: 13 additions & 5 deletions tensorflow_datasets/core/utils/py_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,15 +373,23 @@ def test_make_valid_name(name: str, expected: str):


@pytest.mark.parametrize(
['path', 'subfolder', 'expected'],
['file_name', 'subfolder', 'expected_tmp_file_name'],
[
('/a/file.ext', None, '/a/foobar.file.ext'),
('/a/file.ext', 'sub', '/a/sub/foobar.file.ext'),
('file.ext', None, 'foobar.file.ext'),
('file.ext', 'sub', 'sub/foobar.file.ext'),
],
)
def test_tmp_file_name(path, subfolder, expected):
def test_tmp_file_name(
tmp_path: pathlib.Path,
file_name: str,
subfolder: str | None,
expected_tmp_file_name: str,
):
with mock.patch.object(py_utils, '_tmp_file_prefix', return_value='foobar'):
assert os.fspath(py_utils._tmp_file_name(path, subfolder)) == expected
assert (
py_utils._tmp_file_name(tmp_path / file_name, subfolder)
== tmp_path / expected_tmp_file_name
)


if __name__ == '__main__':
Expand Down
8 changes: 5 additions & 3 deletions tensorflow_datasets/core/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,12 +572,14 @@ def _write_final_shard(
shard_path = self._filename_template.sharded_filepath(
shard_index=shard_id, num_shards=len(non_empty_shard_ids)
)
with utils.incomplete_files(epath.Path(shard_path)) as tmp_path:
with utils.incomplete_files(shard_path, subfolder="tmp") as tmp_shard_path:
logging.info(
"Writing %d examples to %s.", len(example_by_key), os.fspath(tmp_path)
"Writing %d examples to %s.",
len(example_by_key),
os.fspath(tmp_shard_path),
)
record_keys = self._example_writer.write(
tmp_path, sorted(example_by_key.items())
tmp_shard_path, sorted(example_by_key.items())
)
self.inc_counter(name="written_shards")
# If there are record_keys, create index files.
Expand Down