From 551863e664200e00a884e9b198175cf20c4f0194 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Thu, 20 Feb 2025 21:46:23 +0800 Subject: [PATCH 1/3] feat: expose loader argument in Country211 and EuroSAT. --- torchvision/datasets/country211.py | 13 ++++++++++--- torchvision/datasets/eurosat.py | 13 ++++++++++--- 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/torchvision/datasets/country211.py b/torchvision/datasets/country211.py index a0f82ee1226..d3383dc97b6 100644 --- a/torchvision/datasets/country211.py +++ b/torchvision/datasets/country211.py @@ -1,7 +1,7 @@ from pathlib import Path -from typing import Callable, Optional, Union +from typing import Any, Callable, Optional, Union -from .folder import ImageFolder +from .folder import default_loader, ImageFolder from .utils import download_and_extract_archive, verify_str_arg @@ -21,6 +21,7 @@ class Country211(ImageFolder): target_transform (callable, optional): A function/transform that takes in the target and transforms it. download (bool, optional): If True, downloads the dataset from the internet and puts it into ``root/country211/``. If dataset is already downloaded, it is not downloaded again. + loader (callable, optional): A function to load an image given its path. """ _URL = "https://openaipublic.azureedge.net/clip/data/country211.tgz" @@ -33,6 +34,7 @@ def __init__( transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False, + loader: Callable[[str], Any] = default_loader, ) -> None: self._split = verify_str_arg(split, "split", ("train", "valid", "test")) @@ -46,7 +48,12 @@ def __init__( if not self._check_exists(): raise RuntimeError("Dataset not found. You can use download=True to download it") - super().__init__(str(self._base_folder / self._split), transform=transform, target_transform=target_transform) + super().__init__( + str(self._base_folder / self._split), + transform=transform, + target_transform=target_transform, + loader=loader, + ) self.root = str(root) def _check_exists(self) -> bool: diff --git a/torchvision/datasets/eurosat.py b/torchvision/datasets/eurosat.py index c6571d2abab..0d3dd2c1fad 100644 --- a/torchvision/datasets/eurosat.py +++ b/torchvision/datasets/eurosat.py @@ -1,8 +1,8 @@ import os from pathlib import Path -from typing import Callable, Optional, Union +from typing import Any, Callable, Optional, Union -from .folder import ImageFolder +from .folder import default_loader, ImageFolder from .utils import download_and_extract_archive @@ -21,6 +21,7 @@ class EuroSAT(ImageFolder): download (bool, optional): If True, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again. Default is False. + loader (callable, optional): A function to load an image given its path. """ def __init__( @@ -29,6 +30,7 @@ def __init__( transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False, + loader: Callable[[str], Any] = default_loader, ) -> None: self.root = os.path.expanduser(root) self._base_folder = os.path.join(self.root, "eurosat") @@ -40,7 +42,12 @@ def __init__( if not self._check_exists(): raise RuntimeError("Dataset not found. You can use download=True to download it") - super().__init__(self._data_folder, transform=transform, target_transform=target_transform) + super().__init__( + self._data_folder, + transform=transform, + target_transform=target_transform, + loader=loader, + ) self.root = os.path.expanduser(root) def __len__(self) -> int: From 8ed594755fd0f6788ddf4207baba75126fe1923f Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Fri, 21 Feb 2025 21:57:50 +0800 Subject: [PATCH 2/3] docs: update loader argument docstring. --- torchvision/datasets/country211.py | 2 ++ torchvision/datasets/eurosat.py | 2 ++ torchvision/datasets/imagenet.py | 2 ++ 3 files changed, 6 insertions(+) diff --git a/torchvision/datasets/country211.py b/torchvision/datasets/country211.py index d3383dc97b6..26b49552771 100644 --- a/torchvision/datasets/country211.py +++ b/torchvision/datasets/country211.py @@ -22,6 +22,8 @@ class Country211(ImageFolder): download (bool, optional): If True, downloads the dataset from the internet and puts it into ``root/country211/``. If dataset is already downloaded, it is not downloaded again. loader (callable, optional): A function to load an image given its path. + By default, it uses PIL as its image loader, but users could also pass in + ``torchvision.io.decode_image`` for decoding image data into tensors directly. """ _URL = "https://openaipublic.azureedge.net/clip/data/country211.tgz" diff --git a/torchvision/datasets/eurosat.py b/torchvision/datasets/eurosat.py index 0d3dd2c1fad..5b96b067fba 100644 --- a/torchvision/datasets/eurosat.py +++ b/torchvision/datasets/eurosat.py @@ -22,6 +22,8 @@ class EuroSAT(ImageFolder): puts it in root directory. If dataset is already downloaded, it is not downloaded again. Default is False. loader (callable, optional): A function to load an image given its path. + By default, it uses PIL as its image loader, but users could also pass in + ``torchvision.io.decode_image`` for decoding image data into tensors directly. """ def __init__( diff --git a/torchvision/datasets/imagenet.py b/torchvision/datasets/imagenet.py index d7caf328d2b..2d7e1e2f4d7 100644 --- a/torchvision/datasets/imagenet.py +++ b/torchvision/datasets/imagenet.py @@ -36,6 +36,8 @@ class ImageNet(ImageFolder): target_transform (callable, optional): A function/transform that takes in the target and transforms it. loader (callable, optional): A function to load an image given its path. + By default, it uses PIL as its image loader, but users could also pass in + ``torchvision.io.decode_image`` for decoding image data into tensors directly. Attributes: classes (list): List of the class name tuples. From b4877bc584cce7af960c1dd3293f41d8af2bfee2 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Sat, 22 Feb 2025 13:59:07 +0800 Subject: [PATCH 3/3] test: add `test_tv_decode_image_support` to check with the image output type. --- test/datasets_utils.py | 23 ++++++++++++++++++----- test/test_datasets.py | 5 +++++ 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/test/datasets_utils.py b/test/datasets_utils.py index 43b4103646a..6a552a96923 100644 --- a/test/datasets_utils.py +++ b/test/datasets_utils.py @@ -611,6 +611,7 @@ class ImageDatasetTestCase(DatasetTestCase): """ FEATURE_TYPES = (PIL.Image.Image, int) + SUPPORT_TV_IMAGE_DECODE: bool = False @contextlib.contextmanager def create_dataset( @@ -632,22 +633,34 @@ def create_dataset( # This problem only occurs during testing since some tests, e.g. DatasetTestCase.test_feature_types open an # image, but never use the underlying data. During normal operation it is reasonable to assume that the # user wants to work with the image he just opened rather than deleting the underlying file. - with self._force_load_images(): + with self._force_load_images(loader=(config or {}).get("loader", None)): yield dataset, info @contextlib.contextmanager - def _force_load_images(self): - open = PIL.Image.open + def _force_load_images(self, loader: Optional[Callable[[str], Any]] = None): + open = loader or PIL.Image.open def new(fp, *args, **kwargs): image = open(fp, *args, **kwargs) - if isinstance(fp, (str, pathlib.Path)): + if isinstance(fp, (str, pathlib.Path)) and isinstance(image, PIL.Image.Image): image.load() return image - with unittest.mock.patch("PIL.Image.open", new=new): + with unittest.mock.patch(open.__module__ + "." + open.__qualname__, new=new): yield + def test_tv_decode_image_support(self): + if not self.SUPPORT_TV_IMAGE_DECODE: + pytest.skip(f"{self.DATASET_CLASS.__name__} does not support torchvision.io.decode_image.") + + with self.create_dataset( + config=dict( + loader=torchvision.io.decode_image, + ) + ) as (dataset, _): + image = dataset[0][0] + assert isinstance(image, torch.Tensor) + class VideoDatasetTestCase(DatasetTestCase): """Abstract base class for video dataset testcases. diff --git a/test/test_datasets.py b/test/test_datasets.py index 1c1d05ac42a..f98a18372a5 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -405,6 +405,8 @@ class ImageNetTestCase(datasets_utils.ImageDatasetTestCase): REQUIRED_PACKAGES = ("scipy",) ADDITIONAL_CONFIGS = combinations_grid(split=("train", "val")) + SUPPORT_TV_IMAGE_DECODE = True + def inject_fake_data(self, tmpdir, config): tmpdir = pathlib.Path(tmpdir) @@ -2308,6 +2310,7 @@ def inject_fake_data(self, tmpdir, config): class EuroSATTestCase(datasets_utils.ImageDatasetTestCase): DATASET_CLASS = datasets.EuroSAT FEATURE_TYPES = (PIL.Image.Image, int) + SUPPORT_TV_IMAGE_DECODE = True def inject_fake_data(self, tmpdir, config): data_folder = os.path.join(tmpdir, "eurosat", "2750") @@ -2749,6 +2752,8 @@ class Country211TestCase(datasets_utils.ImageDatasetTestCase): ADDITIONAL_CONFIGS = combinations_grid(split=("train", "valid", "test")) + SUPPORT_TV_IMAGE_DECODE = True + def inject_fake_data(self, tmpdir: str, config): split_folder = pathlib.Path(tmpdir) / "country211" / config["split"] split_folder.mkdir(parents=True, exist_ok=True)