Skip to content

Commit 19ee88d

Browse files
fineguyThe TensorFlow Datasets Authors
authored andcommitted
Add subfolder option to incomplete_files and use it in writer.py.
PiperOrigin-RevId: 804868748
1 parent 57cad96 commit 19ee88d

File tree

3 files changed

+37
-18
lines changed

3 files changed

+37
-18
lines changed

tensorflow_datasets/core/utils/py_utils.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,7 @@ def _tmp_file_name(
312312
path: epath.PathLike,
313313
subfolder: str | None = None,
314314
) -> epath.Path:
315-
"""Returns the temporary file name for the given path.
315+
"""Returns the temporary file path dependent on the given path and subfolder.
316316
317317
Args:
318318
path: The path to the file.
@@ -322,9 +322,12 @@ def _tmp_file_name(
322322
path = epath.Path(path)
323323
file_name = f'{_tmp_file_prefix()}.{path.name}'
324324
if subfolder:
325-
return path.parent / subfolder / file_name
325+
tmp_path = path.parent / subfolder / file_name
326326
else:
327-
return path.parent / file_name
327+
tmp_path = path.parent / file_name
328+
# Create the parent directory if it doesn't exist.
329+
tmp_path.parent.mkdir(parents=True, exist_ok=True)
330+
return tmp_path
328331

329332

330333
@contextlib.contextmanager
@@ -334,32 +337,38 @@ def incomplete_file(
334337
) -> Iterator[epath.Path]:
335338
"""Writes to path atomically, by writing to temp file and renaming it."""
336339
tmp_path = _tmp_file_name(path, subfolder=subfolder)
337-
tmp_path.parent.mkdir(exist_ok=True)
338340
try:
339341
yield tmp_path
340342
tmp_path.replace(path)
341343
finally:
342344
# Eventually delete the tmp_path if exception was raised
343-
tmp_path.unlink(missing_ok=True)
345+
if subfolder:
346+
tmp_path.parent.unlink(missing_ok=True)
347+
else:
348+
tmp_path.unlink(missing_ok=True)
344349

345350

346351
@contextlib.contextmanager
347352
def incomplete_files(
348353
path: epath.Path,
354+
subfolder: str | None = None,
349355
) -> Iterator[epath.Path]:
350356
"""Writes to path atomically, by writing to temp file and renaming it."""
351-
tmp_file_prefix = _tmp_file_prefix()
352-
tmp_path = path.parent / f'{tmp_file_prefix}.{path.name}'
357+
tmp_path = _tmp_file_name(path, subfolder=subfolder)
358+
tmp_file_prefix = tmp_path.name.removesuffix(f'.{path.name}')
353359
try:
354360
yield tmp_path
355361
# Rename all tmp files to their final name.
356-
for tmp_file in path.parent.glob(f'{tmp_file_prefix}.*'):
362+
for tmp_file in tmp_path.parent.glob(f'{tmp_file_prefix}.*'):
357363
file_name = tmp_file.name.removeprefix(tmp_file_prefix + '.')
358364
tmp_file.replace(path.parent / file_name)
359365
finally:
360366
# Eventually delete the tmp_path if exception was raised
361-
for tmp_file in path.parent.glob(f'{tmp_file_prefix}.*'):
362-
tmp_file.unlink(missing_ok=True)
367+
if subfolder:
368+
tmp_path.parent.unlink(missing_ok=True)
369+
else:
370+
for tmp_file in tmp_path.parent.glob(f'{tmp_file_prefix}.*'):
371+
tmp_file.unlink(missing_ok=True)
363372

364373

365374
def is_incomplete_file(path: epath.Path) -> bool:

tensorflow_datasets/core/utils/py_utils_test.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -373,15 +373,23 @@ def test_make_valid_name(name: str, expected: str):
373373

374374

375375
@pytest.mark.parametrize(
376-
['path', 'subfolder', 'expected'],
376+
['file_name', 'subfolder', 'expected_tmp_file_name'],
377377
[
378-
('/a/file.ext', None, '/a/foobar.file.ext'),
379-
('/a/file.ext', 'sub', '/a/sub/foobar.file.ext'),
378+
('file.ext', None, 'foobar.file.ext'),
379+
('file.ext', 'sub', 'sub/foobar.file.ext'),
380380
],
381381
)
382-
def test_tmp_file_name(path, subfolder, expected):
382+
def test_tmp_file_name(
383+
tmp_path: pathlib.Path,
384+
file_name: str,
385+
subfolder: str | None,
386+
expected_tmp_file_name: str,
387+
):
383388
with mock.patch.object(py_utils, '_tmp_file_prefix', return_value='foobar'):
384-
assert os.fspath(py_utils._tmp_file_name(path, subfolder)) == expected
389+
assert (
390+
py_utils._tmp_file_name(tmp_path / file_name, subfolder)
391+
== tmp_path / expected_tmp_file_name
392+
)
385393

386394

387395
if __name__ == '__main__':

tensorflow_datasets/core/writer.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -572,12 +572,14 @@ def _write_final_shard(
572572
shard_path = self._filename_template.sharded_filepath(
573573
shard_index=shard_id, num_shards=len(non_empty_shard_ids)
574574
)
575-
with utils.incomplete_files(epath.Path(shard_path)) as tmp_path:
575+
with utils.incomplete_files(shard_path, subfolder="tmp") as tmp_shard_path:
576576
logging.info(
577-
"Writing %d examples to %s.", len(example_by_key), os.fspath(tmp_path)
577+
"Writing %d examples to %s.",
578+
len(example_by_key),
579+
os.fspath(tmp_shard_path),
578580
)
579581
record_keys = self._example_writer.write(
580-
tmp_path, sorted(example_by_key.items())
582+
tmp_shard_path, sorted(example_by_key.items())
581583
)
582584
self.inc_counter(name="written_shards")
583585
# If there are record_keys, create index files.

0 commit comments

Comments
 (0)