diff --git a/tensorflow_datasets/core/utils/py_utils.py b/tensorflow_datasets/core/utils/py_utils.py
index c50ada43b21..81594e82b69 100644
--- a/tensorflow_datasets/core/utils/py_utils.py
+++ b/tensorflow_datasets/core/utils/py_utils.py
@@ -312,7 +312,7 @@ def _tmp_file_name(
     path: epath.PathLike,
     subfolder: str | None = None,
 ) -> epath.Path:
-  """Returns the temporary file name for the given path.
+  """Returns the temporary file path dependent on the given path and subfolder.
 
   Args:
     path: The path to the file.
@@ -322,9 +322,12 @@ def _tmp_file_name(
   path = epath.Path(path)
   file_name = f'{_tmp_file_prefix()}.{path.name}'
   if subfolder:
-    return path.parent / subfolder / file_name
+    tmp_path = path.parent / subfolder / file_name
   else:
-    return path.parent / file_name
+    tmp_path = path.parent / file_name
+  # Create the parent directory if it doesn't exist.
+  tmp_path.parent.mkdir(parents=True, exist_ok=True)
+  return tmp_path
 
 
 @contextlib.contextmanager
@@ -334,7 +337,6 @@ def incomplete_file(
 ) -> Iterator[epath.Path]:
   """Writes to path atomically, by writing to temp file and renaming it."""
   tmp_path = _tmp_file_name(path, subfolder=subfolder)
-  tmp_path.parent.mkdir(exist_ok=True)
   try:
     yield tmp_path
     tmp_path.replace(path)
@@ -346,20 +348,24 @@ def incomplete_file(
 @contextlib.contextmanager
 def incomplete_files(
     path: epath.Path,
+    subfolder: str | None = None,
 ) -> Iterator[epath.Path]:
   """Writes to path atomically, by writing to temp file and renaming it."""
-  tmp_file_prefix = _tmp_file_prefix()
-  tmp_path = path.parent / f'{tmp_file_prefix}.{path.name}'
+  tmp_path = _tmp_file_name(path, subfolder=subfolder)
+  tmp_file_prefix = tmp_path.name.removesuffix(f'.{path.name}')
   try:
     yield tmp_path
     # Rename all tmp files to their final name.
-    for tmp_file in path.parent.glob(f'{tmp_file_prefix}.*'):
+    for tmp_file in tmp_path.parent.glob(f'{tmp_file_prefix}.*'):
       file_name = tmp_file.name.removeprefix(tmp_file_prefix + '.')
       tmp_file.replace(path.parent / file_name)
   finally:
     # Eventually delete the tmp_path if exception was raised
-    for tmp_file in path.parent.glob(f'{tmp_file_prefix}.*'):
-      tmp_file.unlink(missing_ok=True)
+    if subfolder:
+      tmp_path.parent.unlink(missing_ok=True)
+    else:
+      for tmp_file in tmp_path.parent.glob(f'{tmp_file_prefix}.*'):
+        tmp_file.unlink(missing_ok=True)
 
 
 def is_incomplete_file(path: epath.Path) -> bool:
diff --git a/tensorflow_datasets/core/utils/py_utils_test.py b/tensorflow_datasets/core/utils/py_utils_test.py
index 548bb4a43d2..8d774f621ed 100644
--- a/tensorflow_datasets/core/utils/py_utils_test.py
+++ b/tensorflow_datasets/core/utils/py_utils_test.py
@@ -373,15 +373,23 @@ def test_make_valid_name(name: str, expected: str):
 
 
 @pytest.mark.parametrize(
-    ['path', 'subfolder', 'expected'],
+    ['file_name', 'subfolder', 'expected_tmp_file_name'],
     [
-        ('/a/file.ext', None, '/a/foobar.file.ext'),
-        ('/a/file.ext', 'sub', '/a/sub/foobar.file.ext'),
+        ('file.ext', None, 'foobar.file.ext'),
+        ('file.ext', 'sub', 'sub/foobar.file.ext'),
     ],
 )
-def test_tmp_file_name(path, subfolder, expected):
+def test_tmp_file_name(
+    tmp_path: pathlib.Path,
+    file_name: str,
+    subfolder: str | None,
+    expected_tmp_file_name: str,
+):
   with mock.patch.object(py_utils, '_tmp_file_prefix', return_value='foobar'):
-    assert os.fspath(py_utils._tmp_file_name(path, subfolder)) == expected
+    assert (
+        py_utils._tmp_file_name(tmp_path / file_name, subfolder)
+        == tmp_path / expected_tmp_file_name
+    )
 
 
 if __name__ == '__main__':
diff --git a/tensorflow_datasets/core/writer.py b/tensorflow_datasets/core/writer.py
index 45237d32e45..66651afc538 100644
--- a/tensorflow_datasets/core/writer.py
+++ b/tensorflow_datasets/core/writer.py
@@ -572,12 +572,14 @@ def _write_final_shard(
     shard_path = self._filename_template.sharded_filepath(
        shard_index=shard_id, num_shards=len(non_empty_shard_ids)
     )
-    with utils.incomplete_files(epath.Path(shard_path)) as tmp_path:
+    with utils.incomplete_files(shard_path, subfolder="tmp") as tmp_shard_path:
       logging.info(
-          "Writing %d examples to %s.", len(example_by_key), os.fspath(tmp_path)
+          "Writing %d examples to %s.",
+          len(example_by_key),
+          os.fspath(tmp_shard_path),
       )
       record_keys = self._example_writer.write(
-          tmp_path, sorted(example_by_key.items())
+          tmp_shard_path, sorted(example_by_key.items())
       )
       self.inc_counter(name="written_shards")
       # If there are record_keys, create index files.
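
For context, a minimal usage sketch of incomplete_files with the new subfolder argument. The target path and JSON payload below are made up for illustration only; the helper is assumed to be reachable as utils.incomplete_files, the same way writer.py calls it above.

from etils import epath

from tensorflow_datasets.core import utils

# Illustrative target file; the name and location are placeholders.
final_path = epath.Path('/data/my_dataset/dataset_info.json')

# With subfolder='tmp', the temporary file is created under a tmp/ directory
# next to the target (the directory is created by _tmp_file_name) and is
# renamed to final_path when the block exits without raising.
with utils.incomplete_files(final_path, subfolder='tmp') as tmp_path:
  tmp_path.write_text('{"name": "my_dataset"}')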