diff --git a/test/test_datasets.py b/test/test_datasets.py index 8d4eca688a2..feaabd7acd2 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -11,6 +11,7 @@ import re import shutil import string +import sys import unittest import xml.etree.ElementTree as ET import zipfile @@ -1146,6 +1147,7 @@ class OmniglotTestCase(datasets_utils.ImageDatasetTestCase): DATASET_CLASS = datasets.Omniglot ADDITIONAL_CONFIGS = combinations_grid(background=(True, False)) + SUPPORT_TV_IMAGE_DECODE = True def inject_fake_data(self, tmpdir, config): target_folder = ( @@ -1902,6 +1904,7 @@ def test_class_to_idx(self): assert dataset.class_to_idx == class_to_idx +@pytest.mark.skipif(sys.platform in ("win32", "cygwin"), reason="temporarily disabled on Windows") class INaturalistTestCase(datasets_utils.ImageDatasetTestCase): DATASET_CLASS = datasets.INaturalist FEATURE_TYPES = (PIL.Image.Image, (int, tuple)) @@ -1910,6 +1913,7 @@ class INaturalistTestCase(datasets_utils.ImageDatasetTestCase): target_type=("kingdom", "full", "genus", ["kingdom", "phylum", "class", "order", "family", "genus", "full"]), version=("2021_train",), ) + SUPPORT_TV_IMAGE_DECODE = True def inject_fake_data(self, tmpdir, config): categories = [ diff --git a/torchvision/datasets/inaturalist.py b/torchvision/datasets/inaturalist.py index e041d41f4a2..8713bc041db 100644 --- a/torchvision/datasets/inaturalist.py +++ b/torchvision/datasets/inaturalist.py @@ -62,6 +62,9 @@ class INaturalist(VisionDataset): download (bool, optional): If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again. + loader (callable, optional): A function to load an image given its path. + By default, it uses PIL as its image loader, but users could also pass in + ``torchvision.io.decode_image`` for decoding image data into tensors directly. """ def __init__( @@ -72,6 +75,7 @@ def __init__( transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False, + loader: Optional[Callable[[Union[str, Path]], Any]] = None, ) -> None: self.version = verify_str_arg(version, "version", DATASET_URLS.keys()) @@ -109,6 +113,8 @@ def __init__( for fname in files: self.index.append((dir_index, fname)) + self.loader = loader or Image.open + def _init_2021(self) -> None: """Initialize based on 2021 layout""" @@ -178,7 +184,7 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]: """ cat_id, fname = self.index[index] - img = Image.open(os.path.join(self.root, self.all_categories[cat_id], fname)) + img = self.loader(os.path.join(self.root, self.all_categories[cat_id], fname)) target: Any = [] for t in self.target_type: diff --git a/torchvision/datasets/omniglot.py b/torchvision/datasets/omniglot.py index c3434a72456..f8d182cdb25 100644 --- a/torchvision/datasets/omniglot.py +++ b/torchvision/datasets/omniglot.py @@ -23,6 +23,9 @@ class Omniglot(VisionDataset): download (bool, optional): If true, downloads the dataset zip files from the internet and puts it in root directory. If the zip files are already downloaded, they are not downloaded again. + loader (callable, optional): A function to load an image given its path. + By default, it uses PIL as its image loader, but users could also pass in + ``torchvision.io.decode_image`` for decoding image data into tensors directly. """ folder = "omniglot-py" @@ -39,6 +42,7 @@ def __init__( transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False, + loader: Optional[Callable[[Union[str, Path]], Any]] = None, ) -> None: super().__init__(join(root, self.folder), transform=transform, target_transform=target_transform) self.background = background @@ -59,6 +63,7 @@ def __init__( for idx, character in enumerate(self._characters) ] self._flat_character_images: List[Tuple[str, int]] = sum(self._character_images, []) + self.loader = loader def __len__(self) -> int: return len(self._flat_character_images) @@ -73,7 +78,7 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]: """ image_name, character_class = self._flat_character_images[index] image_path = join(self.target_folder, self._characters[character_class], image_name) - image = Image.open(image_path, mode="r").convert("L") + image = Image.open(image_path, mode="r").convert("L") if self.loader is None else self.loader(image_path) if self.transform: image = self.transform(image)