Skip to content

Commit 8afd547

Browse files
RazaProdigyRazaProdigyNicolasHug
authored
Added Type Check for cocodetection dataset (#8227)
Co-authored-by: RazaProdigy <[email protected]> Co-authored-by: Nicolas Hug <[email protected]>
1 parent 6d64cb3 commit 8afd547

File tree

2 files changed

+9
-0
lines changed

2 files changed

+9
-0
lines changed

test/test_datasets.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -827,6 +827,11 @@ def test_transforms_v2_wrapper_spawn(self):
827827
with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _):
828828
datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)
829829

830+
def test_slice_error(self):
831+
with self.create_dataset() as (dataset, _):
832+
with pytest.raises(ValueError, match="Index must be of type integer"):
833+
dataset[:2]
834+
830835

831836
class CocoCaptionsTestCase(CocoDetectionTestCase):
832837
DATASET_CLASS = datasets.CocoCaptions

torchvision/datasets/coco.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,10 @@ def _load_target(self, id: int) -> List[Any]:
4444
return self.coco.loadAnns(self.coco.getAnnIds(id))
4545

4646
def __getitem__(self, index: int) -> Tuple[Any, Any]:
47+
48+
if not isinstance(index, int):
49+
raise ValueError(f"Index must be of type integer, got {type(index)} instead.")
50+
4751
id = self.ids[index]
4852
image = self._load_image(id)
4953
target = self._load_target(id)

0 commit comments

Comments
 (0)