Skip to content

Commit 7c3d7a6

Browse files
author
The TensorFlow Datasets Authors
committed
Fix incompatibility between MultiSplitInfo/SubSplitInfo and EvenSplit
PiperOrigin-RevId: 728628832
1 parent 281ce2d commit 7c3d7a6

File tree

3 files changed

+61
-2
lines changed

3 files changed

+61
-2
lines changed

tensorflow_datasets/core/splits.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -536,11 +536,18 @@ def _file_instructions_for_split(
536536
)
537537
return []
538538
to = split_info.num_examples if instruction.to is None else instruction.to
539+
if isinstance(split_info, (SubSplitInfo, MultiSplitInfo)):
540+
examples_in_shards = [
541+
f.examples_in_shard for f in split_info.file_instructions
542+
]
543+
else:
544+
examples_in_shards = None
539545
return shard_utils.get_file_instructions(
540546
from_=instruction.from_ or 0,
541547
to=to,
542548
filenames=[os.fspath(fp) for fp in split_info.filepaths],
543549
shard_lengths=split_info.shard_lengths,
550+
examples_in_shards=examples_in_shards,
544551
)
545552

546553

tensorflow_datasets/core/splits_test.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,44 @@ def split_info_for(name: str, shard_lengths, template) -> splits.SplitInfo:
217217
assert merged.get('test').split_infos == [split_info_a2, split_info_b2]
218218
assert merged.get('banana').split_infos == [split_info_a3]
219219

220+
def test_multi_split_sub_split(self):
221+
split_info = splits.MultiSplitInfo(
222+
name='train',
223+
split_infos=[
224+
splits.SubSplitInfo(
225+
name='train[:2]',
226+
file_instructions=[
227+
shard_utils.FileInstruction(
228+
filename='/a/file-00000-of-00001',
229+
skip=0,
230+
take=2,
231+
examples_in_shard=10,
232+
)
233+
],
234+
),
235+
splits.SubSplitInfo(
236+
name='train[:10]',
237+
file_instructions=[
238+
shard_utils.FileInstruction(
239+
filename='/b/file-00000-of-00001',
240+
skip=0,
241+
take=10,
242+
examples_in_shard=20,
243+
)
244+
],
245+
),
246+
],
247+
)
248+
split_dict = splits.SplitDict([split_info])
249+
sub_split = split_dict['train[:2]']
250+
self.assertEqual(sub_split.name, 'train[:2]')
251+
self.assertLen(sub_split.file_instructions, 1)
252+
file_instruction = sub_split.file_instructions[0]
253+
self.assertEqual(file_instruction.filename, '/a/file-00000-of-00001')
254+
self.assertEqual(file_instruction.skip, 0)
255+
self.assertEqual(file_instruction.take, 2)
256+
self.assertEqual(file_instruction.examples_in_shard, 10)
257+
220258

221259
class SplitsTest(testing.TestCase):
222260

tensorflow_datasets/core/utils/shard_utils.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@ def get_file_instructions(
216216
to: int,
217217
filenames: Sequence[str],
218218
shard_lengths: Sequence[int],
219+
examples_in_shards: Sequence[int] | None = None,
219220
) -> list[FileInstruction]:
220221
"""Returns a list of files (+skip/take) to read [from_:to] items from shards.
221222
@@ -225,14 +226,18 @@ def get_file_instructions(
225226
filenames: list of strings or ints, the filenames of the shards. Not really
226227
used, but to place in result.
227228
shard_lengths: the number of elements in every shard.
229+
examples_in_shards: the number of examples in every shard. If not provided,
230+
then `shard_lengths` is used.
228231
229232
Returns:
230233
list of dict(filename, skip, take).
231234
"""
232235
index_start = 0 # Beginning (included) of moving window.
233236
index_end = 0 # End (excluded) of moving window.
234237
file_instructions = []
235-
for filename, length in zip(filenames, shard_lengths):
238+
for shard_index, (filename, length) in enumerate(
239+
zip(filenames, shard_lengths)
240+
):
236241
if not length:
237242
continue # Empty shard - can happen with temporary buckets.
238243
index_end += length
@@ -241,9 +246,18 @@ def get_file_instructions(
241246
take = to - index_start - skip if to < index_end else -1
242247
if take == 0:
243248
continue
249+
if examples_in_shards is not None:
250+
examples_in_shard = examples_in_shards[shard_index]
251+
if take == -1 and examples_in_shard != length:
252+
take = length
253+
else:
254+
examples_in_shard = length
244255
file_instructions.append(
245256
FileInstruction(
246-
filename=filename, skip=skip, take=take, examples_in_shard=length
257+
filename=filename,
258+
skip=skip,
259+
take=take,
260+
examples_in_shard=examples_in_shard,
247261
)
248262
)
249263
index_start += length

0 commit comments

Comments
 (0)