Skip to content

Commit 9982226

Browse files
author
The TensorFlow Datasets Authors
committed
Handle empty defined splits rather than raise error.
If the split is defined as empty in the metadata file, then treat the split as empty.

PiperOrigin-RevId: 649076154
1 parent 4fd6d59 commit 9982226

File tree

2 files changed

+16
-17
lines changed

2 files changed

+16
-17
lines changed

tensorflow_datasets/core/splits.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -476,10 +476,11 @@ def _file_instructions_for_split(
476476
) -> List[shard_utils.FileInstruction]:
477477
"""Returns the file instructions from the given instruction applied to the given split info."""
478478
if not split_info.num_examples:
479-
raise ValueError(
480-
"Shard empty. This might means that dataset hasn't been generated "
481-
'yet and info not restored from GCS, or that legacy dataset is used.'
479+
logging.warning(
480+
'Split %s has no examples. Skipping file instructions.',
481+
split_info.name,
482482
)
483+
return []
483484
to = split_info.num_examples if instruction.to is None else instruction.to
484485
return shard_utils.get_file_instructions(
485486
from_=instruction.from_ or 0,

tensorflow_datasets/core/splits_test.py

Lines changed: 12 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -627,20 +627,18 @@ def test_touching_boundaries(self):
627627
self.assertEqual(files, [])
628628

629629
def test_missing_shard_lengths(self):
630-
with self.assertRaisesWithPredicateMatch(ValueError, 'Shard empty.'):
631-
filename_template = _filename_template(
632-
split='train', dataset_name='mnist'
633-
)
634-
split_infos = [
635-
splits.SplitInfo(
636-
name='train',
637-
shard_lengths=[],
638-
num_bytes=0,
639-
filename_template=filename_template,
640-
),
641-
]
642-
splits_dict = splits.SplitDict(split_infos=split_infos)
643-
_ = splits_dict['train'].file_instructions
630+
filename_template = _filename_template(split='train', dataset_name='mnist')
631+
split_infos = [
632+
splits.SplitInfo(
633+
name='train',
634+
shard_lengths=[],
635+
num_bytes=0,
636+
filename_template=filename_template,
637+
),
638+
]
639+
splits_dict = splits.SplitDict(split_infos=split_infos)
640+
files = splits_dict['train'].file_instructions
641+
self.assertEqual(files, [])
644642

645643

646644
if __name__ == '__main__':

0 commit comments

Comments
 (0)