Skip to content

Commit 2638dc2

Browse files
author
The TensorFlow Datasets Authors
committed
Add the test split to TAO dataset.
PiperOrigin-RevId: 681794893
1 parent 4a42fe1 commit 2638dc2

File tree

3 files changed

+17
-2
lines changed

3 files changed

+17
-2
lines changed

tensorflow_datasets/video/tao/tao.py

Lines changed: 11 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -101,6 +101,7 @@ def _maybe_prepare_manual_data(
101101
manually_downloaded_files = [
102102
'1_AVA_HACS_TRAIN_*.zip',
103103
'2_AVA_HACS_VAL_*.zip',
104+
'3_AVA_HACS_TEST_*.zip',
104105
]
105106
files = []
106107
for file in manually_downloaded_files:
@@ -282,7 +283,7 @@ class Tao(tfds.core.BeamBasedBuilder):
282283
]
283284
VERSION = tfds.core.Version('1.0.0')
284285
RELEASE_NOTES = {
285-
'1.0.0': 'Initial release.',
286+
'1.1.0': 'Added test split.',
286287
}
287288

288289
def _info(self) -> tfds.core.DatasetInfo:
@@ -336,6 +337,7 @@ def _split_generators(self, dl_manager: tfds.download.DownloadManager):
336337
data = dl_manager.download_and_extract({
337338
'train': _VIDEO_URL + '1-TAO_TRAIN.zip',
338339
'val': _VIDEO_URL + '2-TAO_VAL.zip',
340+
'test': _VIDEO_URL + '3-TAO_TEST.zip',
339341
'annotations': _ANNOTATIONS_URL,
340342
})
341343

@@ -359,6 +361,14 @@ def _split_generators(self, dl_manager: tfds.download.DownloadManager):
359361
/ 'validation.json',
360362
id_map=id_map,
361363
),
364+
tfds.Split.TEST: self._generate_examples(
365+
data_path=data['test'],
366+
manual_path=None,
367+
annotations_path=data['annotations']
368+
/ 'annotations-1.2'
369+
/ 'test_without_annotations.json',
370+
id_map=id_map,
371+
),
362372
}
363373

364374
def _maybe_resize_video(self, frames_list):

tensorflow_datasets/video/tao/tao_test.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -27,10 +27,12 @@ class TaoTest(tfds.testing.DatasetBuilderTestCase):
2727
SPLITS = {
2828
tfds.Split.TRAIN: 1,
2929
tfds.Split.VALIDATION: 1,
30+
tfds.Split.TEST: 1,
3031
}
3132
DL_EXTRACT_RESULT = {
3233
'train': '',
3334
'val': '',
35+
'test': '',
3436
'annotations': '',
3537
}
3638

@@ -58,7 +60,8 @@ def _download_and_prepare_as_dataset(self, builder):
5860
splits = builder.as_dataset()
5961
train_ex = list(splits[tfds.Split.TRAIN])[0]
6062
val_ex = list(splits[tfds.Split.VALIDATION])[0]
61-
for ex in [train_ex, val_ex]:
63+
test_ex = list(splits[tfds.Split.TEST])[0]
64+
for ex in [train_ex, val_ex, test_ex]:
6265
# There should be the same number of each of these; a number
6366
# per group of bboxes indicating which frame they correspond to.
6467
self.assertEqual(
@@ -69,10 +72,12 @@ def _download_and_prepare_as_dataset(self, builder):
6972
splits = builder.as_dataset()
7073
train_ex = list(splits[tfds.Split.TRAIN])[0]
7174
val_ex = list(splits[tfds.Split.VALIDATION])[0]
75+
test_ex = list(splits[tfds.Split.TEST])[0]
7276
# NOTE: For real images, this will be a list of potentially a thousand or
7377
# more frames. For testing purposes we load a single dummy 10 X 10 image.
7478
self.assertEqual(train_ex['video'].shape, (1, 28, 42, 3))
7579
self.assertEqual(val_ex['video'].shape, (1, 28, 42, 3))
80+
self.assertEqual(test_ex['video'].shape, (1, 28, 42, 3))
7681

7782

7883
if __name__ == '__main__':

0 commit comments

Comments
 (0)