Skip to content

Commit 3698e04

Browse files
Merge branch 'arm64_enablement' of https://github.com/alinpahontu2912/vision into arm64_enablement
2 parents 839d069 + c72d258 commit 3698e04

24 files changed

+259
-101
lines changed

setup.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import distutils.spawn
33
import glob
44
import os
5+
import shlex
56
import shutil
67
import subprocess
78
import sys
@@ -95,8 +96,14 @@ def get_dist(pkgname):
9596
return None
9697

9798
pytorch_dep = os.getenv("TORCH_PACKAGE_NAME", "torch")
98-
if os.getenv("PYTORCH_VERSION"):
99-
pytorch_dep += "==" + os.getenv("PYTORCH_VERSION")
99+
if version_pin := os.getenv("PYTORCH_VERSION"):
100+
pytorch_dep += "==" + version_pin
101+
elif (version_pin_ge := os.getenv("PYTORCH_VERSION_GE")) and (version_pin_lt := os.getenv("PYTORCH_VERSION_LT")):
102+
# This branch and the associated env vars exist to help third-party
103+
# builds like in https://github.com/pytorch/vision/pull/8936. This is
104+
# supported on a best-effort basis, we don't guarantee that this won't
105+
# eventually break (and we don't test it.)
106+
pytorch_dep += f">={version_pin_ge},<{version_pin_lt}"
100107

101108
requirements = [
102109
"numpy",
@@ -123,7 +130,7 @@ def get_macros_and_flags():
123130
if NVCC_FLAGS is None:
124131
nvcc_flags = []
125132
else:
126-
nvcc_flags = NVCC_FLAGS.split(" ")
133+
nvcc_flags = shlex.split(NVCC_FLAGS)
127134
extra_compile_args["nvcc"] = nvcc_flags
128135

129136
if sys.platform == "win32":

test/datasets_utils.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -611,6 +611,7 @@ class ImageDatasetTestCase(DatasetTestCase):
611611
"""
612612

613613
FEATURE_TYPES = (PIL.Image.Image, int)
614+
SUPPORT_TV_IMAGE_DECODE: bool = False
614615

615616
@contextlib.contextmanager
616617
def create_dataset(
@@ -632,22 +633,34 @@ def create_dataset(
632633
# This problem only occurs during testing since some tests, e.g. DatasetTestCase.test_feature_types open an
633634
# image, but never use the underlying data. During normal operation it is reasonable to assume that the
634635
# user wants to work with the image he just opened rather than deleting the underlying file.
635-
with self._force_load_images():
636+
with self._force_load_images(loader=(config or {}).get("loader", None)):
636637
yield dataset, info
637638

638639
@contextlib.contextmanager
639-
def _force_load_images(self):
640-
open = PIL.Image.open
640+
def _force_load_images(self, loader: Optional[Callable[[str], Any]] = None):
641+
open = loader or PIL.Image.open
641642

642643
def new(fp, *args, **kwargs):
643644
image = open(fp, *args, **kwargs)
644-
if isinstance(fp, (str, pathlib.Path)):
645+
if isinstance(fp, (str, pathlib.Path)) and isinstance(image, PIL.Image.Image):
645646
image.load()
646647
return image
647648

648-
with unittest.mock.patch("PIL.Image.open", new=new):
649+
with unittest.mock.patch(open.__module__ + "." + open.__qualname__, new=new):
649650
yield
650651

652+
def test_tv_decode_image_support(self):
653+
if not self.SUPPORT_TV_IMAGE_DECODE:
654+
pytest.skip(f"{self.DATASET_CLASS.__name__} does not support torchvision.io.decode_image.")
655+
656+
with self.create_dataset(
657+
config=dict(
658+
loader=torchvision.io.decode_image,
659+
)
660+
) as (dataset, _):
661+
image = dataset[0][0]
662+
assert isinstance(image, torch.Tensor)
663+
651664

652665
class VideoDatasetTestCase(DatasetTestCase):
653666
"""Abstract base class for video dataset testcases.

test/test_datasets.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import torch.nn.functional as F
2525
from common_utils import combinations_grid
2626
from torchvision import datasets
27+
from torchvision.io import decode_image
2728
from torchvision.transforms import v2
2829

2930

@@ -405,6 +406,8 @@ class ImageNetTestCase(datasets_utils.ImageDatasetTestCase):
405406
REQUIRED_PACKAGES = ("scipy",)
406407
ADDITIONAL_CONFIGS = combinations_grid(split=("train", "val"))
407408

409+
SUPPORT_TV_IMAGE_DECODE = True
410+
408411
def inject_fake_data(self, tmpdir, config):
409412
tmpdir = pathlib.Path(tmpdir)
410413

@@ -1173,6 +1176,8 @@ class SBUTestCase(datasets_utils.ImageDatasetTestCase):
11731176
DATASET_CLASS = datasets.SBU
11741177
FEATURE_TYPES = (PIL.Image.Image, str)
11751178

1179+
SUPPORT_TV_IMAGE_DECODE = True
1180+
11761181
def inject_fake_data(self, tmpdir, config):
11771182
num_images = 3
11781183

@@ -1411,6 +1416,8 @@ class Flickr8kTestCase(datasets_utils.ImageDatasetTestCase):
14111416
_IMAGES_FOLDER = "images"
14121417
_ANNOTATIONS_FILE = "captions.html"
14131418

1419+
SUPPORT_TV_IMAGE_DECODE = True
1420+
14141421
def dataset_args(self, tmpdir, config):
14151422
tmpdir = pathlib.Path(tmpdir)
14161423
root = tmpdir / self._IMAGES_FOLDER
@@ -1480,6 +1487,8 @@ class Flickr30kTestCase(Flickr8kTestCase):
14801487

14811488
_ANNOTATIONS_FILE = "captions.token"
14821489

1490+
SUPPORT_TV_IMAGE_DECODE = True
1491+
14831492
def _image_file_name(self, idx):
14841493
return f"{idx}.jpg"
14851494

@@ -1940,6 +1949,8 @@ class LFWPeopleTestCase(datasets_utils.DatasetTestCase):
19401949
_IMAGES_DIR = {"original": "lfw", "funneled": "lfw_funneled", "deepfunneled": "lfw-deepfunneled"}
19411950
_file_id = {"10fold": "", "train": "DevTrain", "test": "DevTest"}
19421951

1952+
SUPPORT_TV_IMAGE_DECODE = True
1953+
19431954
def inject_fake_data(self, tmpdir, config):
19441955
tmpdir = pathlib.Path(tmpdir) / "lfw-py"
19451956
os.makedirs(tmpdir, exist_ok=True)
@@ -1976,6 +1987,18 @@ def _create_random_id(self):
19761987
part2 = datasets_utils.create_random_string(random.randint(4, 7))
19771988
return f"{part1}_{part2}"
19781989

1990+
def test_tv_decode_image_support(self):
1991+
if not self.SUPPORT_TV_IMAGE_DECODE:
1992+
pytest.skip(f"{self.DATASET_CLASS.__name__} does not support torchvision.io.decode_image.")
1993+
1994+
with self.create_dataset(
1995+
config=dict(
1996+
loader=decode_image,
1997+
)
1998+
) as (dataset, _):
1999+
image = dataset[0][0]
2000+
assert isinstance(image, torch.Tensor)
2001+
19792002

19802003
class LFWPairsTestCase(LFWPeopleTestCase):
19812004
DATASET_CLASS = datasets.LFWPairs
@@ -2308,6 +2331,7 @@ def inject_fake_data(self, tmpdir, config):
23082331
class EuroSATTestCase(datasets_utils.ImageDatasetTestCase):
23092332
DATASET_CLASS = datasets.EuroSAT
23102333
FEATURE_TYPES = (PIL.Image.Image, int)
2334+
SUPPORT_TV_IMAGE_DECODE = True
23112335

23122336
def inject_fake_data(self, tmpdir, config):
23132337
data_folder = os.path.join(tmpdir, "eurosat", "2750")
@@ -2332,6 +2356,8 @@ class Food101TestCase(datasets_utils.ImageDatasetTestCase):
23322356

23332357
ADDITIONAL_CONFIGS = combinations_grid(split=("train", "test"))
23342358

2359+
SUPPORT_TV_IMAGE_DECODE = True
2360+
23352361
def inject_fake_data(self, tmpdir: str, config):
23362362
root_folder = pathlib.Path(tmpdir) / "food-101"
23372363
image_folder = root_folder / "images"
@@ -2368,6 +2394,7 @@ class FGVCAircraftTestCase(datasets_utils.ImageDatasetTestCase):
23682394
ADDITIONAL_CONFIGS = combinations_grid(
23692395
split=("train", "val", "trainval", "test"), annotation_level=("variant", "family", "manufacturer")
23702396
)
2397+
SUPPORT_TV_IMAGE_DECODE = True
23712398

23722399
def inject_fake_data(self, tmpdir: str, config):
23732400
split = config["split"]
@@ -2417,6 +2444,8 @@ def inject_fake_data(self, tmpdir: str, config):
24172444
class SUN397TestCase(datasets_utils.ImageDatasetTestCase):
24182445
DATASET_CLASS = datasets.SUN397
24192446

2447+
SUPPORT_TV_IMAGE_DECODE = True
2448+
24202449
def inject_fake_data(self, tmpdir: str, config):
24212450
data_dir = pathlib.Path(tmpdir) / "SUN397"
24222451
data_dir.mkdir()
@@ -2448,6 +2477,8 @@ class DTDTestCase(datasets_utils.ImageDatasetTestCase):
24482477
DATASET_CLASS = datasets.DTD
24492478
FEATURE_TYPES = (PIL.Image.Image, int)
24502479

2480+
SUPPORT_TV_IMAGE_DECODE = True
2481+
24512482
ADDITIONAL_CONFIGS = combinations_grid(
24522483
split=("train", "test", "val"),
24532484
# There is no need to test the whole matrix here, since each fold is treated exactly the same
@@ -2608,6 +2639,7 @@ class CLEVRClassificationTestCase(datasets_utils.ImageDatasetTestCase):
26082639
FEATURE_TYPES = (PIL.Image.Image, (int, type(None)))
26092640

26102641
ADDITIONAL_CONFIGS = combinations_grid(split=("train", "val", "test"))
2642+
SUPPORT_TV_IMAGE_DECODE = True
26112643

26122644
def inject_fake_data(self, tmpdir, config):
26132645
data_folder = pathlib.Path(tmpdir) / "clevr" / "CLEVR_v1.0"
@@ -2705,6 +2737,8 @@ class StanfordCarsTestCase(datasets_utils.ImageDatasetTestCase):
27052737
REQUIRED_PACKAGES = ("scipy",)
27062738
ADDITIONAL_CONFIGS = combinations_grid(split=("train", "test"))
27072739

2740+
SUPPORT_TV_IMAGE_DECODE = True
2741+
27082742
def inject_fake_data(self, tmpdir, config):
27092743
import scipy.io as io
27102744
from numpy.core.records import fromarrays
@@ -2749,6 +2783,8 @@ class Country211TestCase(datasets_utils.ImageDatasetTestCase):
27492783

27502784
ADDITIONAL_CONFIGS = combinations_grid(split=("train", "valid", "test"))
27512785

2786+
SUPPORT_TV_IMAGE_DECODE = True
2787+
27522788
def inject_fake_data(self, tmpdir: str, config):
27532789
split_folder = pathlib.Path(tmpdir) / "country211" / config["split"]
27542790
split_folder.mkdir(parents=True, exist_ok=True)
@@ -2777,6 +2813,8 @@ class Flowers102TestCase(datasets_utils.ImageDatasetTestCase):
27772813
ADDITIONAL_CONFIGS = combinations_grid(split=("train", "val", "test"))
27782814
REQUIRED_PACKAGES = ("scipy",)
27792815

2816+
SUPPORT_TV_IMAGE_DECODE = True
2817+
27802818
def inject_fake_data(self, tmpdir: str, config):
27812819
base_folder = pathlib.Path(tmpdir) / "flowers-102"
27822820

@@ -2835,6 +2873,8 @@ class RenderedSST2TestCase(datasets_utils.ImageDatasetTestCase):
28352873
ADDITIONAL_CONFIGS = combinations_grid(split=("train", "val", "test"))
28362874
SPLIT_TO_FOLDER = {"train": "train", "val": "valid", "test": "test"}
28372875

2876+
SUPPORT_TV_IMAGE_DECODE = True
2877+
28382878
def inject_fake_data(self, tmpdir: str, config):
28392879
root_folder = pathlib.Path(tmpdir) / "rendered-sst2"
28402880
image_folder = root_folder / self.SPLIT_TO_FOLDER[config["split"]]
@@ -3495,6 +3535,8 @@ class ImagenetteTestCase(datasets_utils.ImageDatasetTestCase):
34953535
DATASET_CLASS = datasets.Imagenette
34963536
ADDITIONAL_CONFIGS = combinations_grid(split=["train", "val"], size=["full", "320px", "160px"])
34973537

3538+
SUPPORT_TV_IMAGE_DECODE = True
3539+
34983540
_WNIDS = [
34993541
"n01440764",
35003542
"n02102040",

test/test_transforms_v2.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3758,12 +3758,18 @@ def test_transform_errors_warnings(self):
37583758
with pytest.raises(ValueError, match="provide only two dimensions"):
37593759
transforms.RandomResizedCrop(size=(1, 2, 3))
37603760

3761-
with pytest.raises(TypeError, match="Scale should be a sequence"):
3761+
with pytest.raises(TypeError, match="Scale should be a sequence of two floats."):
37623762
transforms.RandomResizedCrop(size=self.INPUT_SIZE, scale=123)
37633763

3764-
with pytest.raises(TypeError, match="Ratio should be a sequence"):
3764+
with pytest.raises(TypeError, match="Ratio should be a sequence of two floats."):
37653765
transforms.RandomResizedCrop(size=self.INPUT_SIZE, ratio=123)
37663766

3767+
with pytest.raises(TypeError, match="Ratio should be a sequence of two floats."):
3768+
transforms.RandomResizedCrop(size=self.INPUT_SIZE, ratio=[1, 2, 3])
3769+
3770+
with pytest.raises(TypeError, match="Scale should be a sequence of two floats."):
3771+
transforms.RandomResizedCrop(size=self.INPUT_SIZE, scale=[1, 2, 3])
3772+
37673773
for param in ["scale", "ratio"]:
37683774
with pytest.warns(match="Scale and ratio should be of kind"):
37693775
transforms.RandomResizedCrop(size=self.INPUT_SIZE, **{param: [1, 0]})

torchvision/datasets/clevr.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import Any, Callable, List, Optional, Tuple, Union
44
from urllib.parse import urlparse
55

6-
from PIL import Image
6+
from .folder import default_loader
77

88
from .utils import download_and_extract_archive, verify_str_arg
99
from .vision import VisionDataset
@@ -18,11 +18,14 @@ class CLEVRClassification(VisionDataset):
1818
root (str or ``pathlib.Path``): Root directory of dataset where directory ``root/clevr`` exists or will be saved to if download is
1919
set to True.
2020
split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
21-
transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
22-
version. E.g, ``transforms.RandomCrop``
21+
transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on the given loader,
22+
and returns a transformed version. E.g, ``transforms.RandomCrop``
2323
target_transform (callable, optional): A function/transform that takes in them target and transforms it.
2424
download (bool, optional): If true, downloads the dataset from the internet and puts it in root directory. If
2525
dataset is already downloaded, it is not downloaded again.
26+
loader (callable, optional): A function to load an image given its path.
27+
By default, it uses PIL as its image loader, but users could also pass in
28+
``torchvision.io.decode_image`` for decoding image data into tensors directly.
2629
"""
2730

2831
_URL = "https://dl.fbaipublicfiles.com/clevr/CLEVR_v1.0.zip"
@@ -35,9 +38,11 @@ def __init__(
3538
transform: Optional[Callable] = None,
3639
target_transform: Optional[Callable] = None,
3740
download: bool = False,
41+
loader: Callable[[Union[str, pathlib.Path]], Any] = default_loader,
3842
) -> None:
3943
self._split = verify_str_arg(split, "split", ("train", "val", "test"))
4044
super().__init__(root, transform=transform, target_transform=target_transform)
45+
self.loader = loader
4146
self._base_folder = pathlib.Path(self.root) / "clevr"
4247
self._data_folder = self._base_folder / pathlib.Path(urlparse(self._URL).path).stem
4348

@@ -65,7 +70,7 @@ def __getitem__(self, idx: int) -> Tuple[Any, Any]:
6570
image_file = self._image_files[idx]
6671
label = self._labels[idx]
6772

68-
image = Image.open(image_file).convert("RGB")
73+
image = self.loader(image_file)
6974

7075
if self.transform:
7176
image = self.transform(image)

torchvision/datasets/coco.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
class CocoDetection(VisionDataset):
1111
"""`MS Coco Detection <https://cocodataset.org/#detection-2016>`_ Dataset.
1212
13-
It requires the `COCO API to be installed <https://github.com/pdollar/coco/tree/master/PythonAPI>`_.
13+
It requires `pycocotools <https://github.com/ppwwyyxx/cocoapi>`_ to be installed,
14+
which could be installed via ``pip install pycocotools`` or ``conda install conda-forge::pycocotools``.
1415
1516
Args:
1617
root (str or ``pathlib.Path``): Root directory where images are downloaded to.
@@ -65,7 +66,8 @@ def __len__(self) -> int:
6566
class CocoCaptions(CocoDetection):
6667
"""`MS Coco Captions <https://cocodataset.org/#captions-2015>`_ Dataset.
6768
68-
It requires the `COCO API to be installed <https://github.com/pdollar/coco/tree/master/PythonAPI>`_.
69+
It requires `pycocotools <https://github.com/ppwwyyxx/cocoapi>`_ to be installed,
70+
which could be installed via ``pip install pycocotools`` or ``conda install conda-forge::pycocotools``.
6971
7072
Args:
7173
root (str or ``pathlib.Path``): Root directory where images are downloaded to.

torchvision/datasets/country211.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from pathlib import Path
2-
from typing import Callable, Optional, Union
2+
from typing import Any, Callable, Optional, Union
33

4-
from .folder import ImageFolder
4+
from .folder import default_loader, ImageFolder
55
from .utils import download_and_extract_archive, verify_str_arg
66

77

@@ -16,11 +16,14 @@ class Country211(ImageFolder):
1616
Args:
1717
root (str or ``pathlib.Path``): Root directory of the dataset.
1818
split (string, optional): The dataset split, supports ``"train"`` (default), ``"valid"`` and ``"test"``.
19-
transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
20-
version. E.g, ``transforms.RandomCrop``.
19+
transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on the given loader,
20+
and returns a transformed version. E.g, ``transforms.RandomCrop``
2121
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
2222
download (bool, optional): If True, downloads the dataset from the internet and puts it into
2323
``root/country211/``. If dataset is already downloaded, it is not downloaded again.
24+
loader (callable, optional): A function to load an image given its path.
25+
By default, it uses PIL as its image loader, but users could also pass in
26+
``torchvision.io.decode_image`` for decoding image data into tensors directly.
2427
"""
2528

2629
_URL = "https://openaipublic.azureedge.net/clip/data/country211.tgz"
@@ -33,6 +36,7 @@ def __init__(
3336
transform: Optional[Callable] = None,
3437
target_transform: Optional[Callable] = None,
3538
download: bool = False,
39+
loader: Callable[[str], Any] = default_loader,
3640
) -> None:
3741
self._split = verify_str_arg(split, "split", ("train", "valid", "test"))
3842

@@ -46,7 +50,12 @@ def __init__(
4650
if not self._check_exists():
4751
raise RuntimeError("Dataset not found. You can use download=True to download it")
4852

49-
super().__init__(str(self._base_folder / self._split), transform=transform, target_transform=target_transform)
53+
super().__init__(
54+
str(self._base_folder / self._split),
55+
transform=transform,
56+
target_transform=target_transform,
57+
loader=loader,
58+
)
5059
self.root = str(root)
5160

5261
def _check_exists(self) -> bool:

0 commit comments

Comments
 (0)