Skip to content

Commit 83c6361

Browse files
committed
Merge branch 'main' of github.com:pytorch/vision into extra_decoders
2 parents 8bbebcb + 66c5629 commit 83c6361

29 files changed

+112
-98
lines changed

.github/workflows/update-viablestrict.yml

Lines changed: 0 additions & 24 deletions
This file was deleted.

docs/source/datasets.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,15 @@ All the datasets have almost similar API. They all have two common arguments:
2727
``transform`` and ``target_transform`` to transform the input and target respectively.
2828
You can also create your own datasets using the provided :ref:`base classes <base_classes_datasets>`.
2929

30+
.. warning::
31+
32+
When a dataset object is created with ``download=True``, the files are first
33+
downloaded and extracted in the root directory. This download logic is not
34+
multi-process safe, so it may lead to conflicts / race conditions if it is
35+
run within a distributed setting. In distributed mode, we recommend creating
36+
a dummy dataset object to trigger the download logic *before* setting up
37+
distributed mode.
38+
3039
Image classification
3140
~~~~~~~~~~~~~~~~~~~~
3241

setup.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
IS_ROCM = (torch.version.hip is not None) and (ROCM_HOME is not None)
4343
BUILD_CUDA_SOURCES = (torch.cuda.is_available() and ((CUDA_HOME is not None) or IS_ROCM)) or FORCE_CUDA
4444

45-
PACKAGE_NAME = "torchvision"
45+
package_name = os.getenv("TORCHVISION_PACKAGE_NAME", "torchvision")
4646

4747
print("Torchvision build configuration:")
4848
print(f"{FORCE_CUDA = }")
@@ -98,7 +98,7 @@ def get_dist(pkgname):
9898
except DistributionNotFound:
9999
return None
100100

101-
pytorch_dep = "torch"
101+
pytorch_dep = os.getenv("TORCH_PACKAGE_NAME", "torch")
102102
if os.getenv("PYTORCH_VERSION"):
103103
pytorch_dep += "==" + os.getenv("PYTORCH_VERSION")
104104

@@ -127,7 +127,7 @@ def get_macros_and_flags():
127127
if NVCC_FLAGS is None:
128128
nvcc_flags = []
129129
else:
130-
nvcc_flags = nvcc_flags.split(" ")
130+
nvcc_flags = NVCC_FLAGS.split(" ")
131131
extra_compile_args["nvcc"] = nvcc_flags
132132

133133
if sys.platform == "win32":
@@ -561,7 +561,7 @@ def run(self):
561561
version, sha = get_version()
562562
write_version_file(version, sha)
563563

564-
print(f"Building wheel {PACKAGE_NAME}-{version}")
564+
print(f"Building wheel {package_name}-{version}")
565565

566566
with open("README.md") as f:
567567
readme = f.read()
@@ -573,7 +573,7 @@ def run(self):
573573
]
574574

575575
setup(
576-
name=PACKAGE_NAME,
576+
name=package_name,
577577
version=version,
578578
author="PyTorch Core Team",
579579
author_email="[email protected]",
@@ -583,7 +583,7 @@ def run(self):
583583
long_description_content_type="text/markdown",
584584
license="BSD",
585585
packages=find_packages(exclude=("test",)),
586-
package_data={PACKAGE_NAME: ["*.dll", "*.dylib", "*.so", "prototype/datasets/_builtin/*.categories"]},
586+
package_data={package_name: ["*.dll", "*.dylib", "*.so", "prototype/datasets/_builtin/*.categories"]},
587587
zip_safe=False,
588588
install_requires=get_requirements(),
589589
extras_require={

test/test_datasets.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1869,11 +1869,6 @@ def test_class_to_idx(self):
18691869
with self.create_dataset() as (dataset, _):
18701870
assert dataset.class_to_idx == class_to_idx
18711871

1872-
def test_images_download_preexisting(self):
1873-
with pytest.raises(RuntimeError):
1874-
with self.create_dataset({"download": True}):
1875-
pass
1876-
18771872

18781873
class INaturalistTestCase(datasets_utils.ImageDatasetTestCase):
18791874
DATASET_CLASS = datasets.INaturalist

test/test_transforms_v2.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6169,3 +6169,50 @@ def test_transform_sequence_len_error(self, quality):
61696169
def test_transform_invalid_quality_error(self, quality):
61706170
with pytest.raises(ValueError, match="quality must be an integer from 1 to 100"):
61716171
transforms.JPEG(quality=quality)
6172+
6173+
6174+
class TestUtils:
6175+
# TODO: Still need to test has_all, has_any, check_type and get_bouding_boxes
6176+
@pytest.mark.parametrize(
6177+
"make_input1", [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask]
6178+
)
6179+
@pytest.mark.parametrize(
6180+
"make_input2", [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask]
6181+
)
6182+
@pytest.mark.parametrize("query", [transforms.query_size, transforms.query_chw])
6183+
def test_query_size_and_query_chw(self, make_input1, make_input2, query):
6184+
size = (32, 64)
6185+
input1 = make_input1(size)
6186+
input2 = make_input2(size)
6187+
6188+
if query is transforms.query_chw and not any(
6189+
transforms.check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video))
6190+
for inpt in (input1, input2)
6191+
):
6192+
return
6193+
6194+
expected = size if query is transforms.query_size else ((3,) + size)
6195+
assert query([input1, input2]) == expected
6196+
6197+
@pytest.mark.parametrize(
6198+
"make_input1", [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask]
6199+
)
6200+
@pytest.mark.parametrize(
6201+
"make_input2", [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask]
6202+
)
6203+
@pytest.mark.parametrize("query", [transforms.query_size, transforms.query_chw])
6204+
def test_different_sizes(self, make_input1, make_input2, query):
6205+
input1 = make_input1((10, 10))
6206+
input2 = make_input2((20, 20))
6207+
if query is transforms.query_chw and not all(
6208+
transforms.check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video))
6209+
for inpt in (input1, input2)
6210+
):
6211+
return
6212+
with pytest.raises(ValueError, match="Found multiple"):
6213+
query([input1, input2])
6214+
6215+
@pytest.mark.parametrize("query", [transforms.query_size, transforms.query_chw])
6216+
def test_no_valid_input(self, query):
6217+
with pytest.raises(TypeError, match="No image"):
6218+
query(["blah"])

torchvision/csrc/io/image/cpu/decode_webp.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,12 @@ torch::Tensor decode_webp(
4444

4545
auto decoded_data =
4646
decoding_func(encoded_data_p, encoded_data_size, &width, &height);
47+
4748
TORCH_CHECK(decoded_data != nullptr, "WebPDecodeRGB[A] failed.");
4849

50+
auto deleter = [decoded_data](void*) { WebPFree(decoded_data); };
4951
auto out = torch::from_blob(
50-
decoded_data, {height, width, num_channels}, torch::kUInt8);
52+
decoded_data, {height, width, num_channels}, deleter, torch::kUInt8);
5153

5254
return out.permute({2, 0, 1});
5355
}

torchvision/datasets/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@
7272
"QMNIST",
7373
"MNIST",
7474
"KMNIST",
75+
"MovingMNIST",
7576
"StanfordCars",
7677
"STL10",
7778
"SUN397",

torchvision/datasets/_stereo_matching.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -588,7 +588,6 @@ def _download_dataset(self, root: Union[str, Path]) -> None:
588588
for calibration in ["perfect", "imperfect"]:
589589
scene_name = f"{split_scene}-{calibration}"
590590
scene_url = f"{base_url}/{scene_name}.zip"
591-
print(f"Downloading {scene_url}")
592591
# download the scene only if it doesn't exist
593592
if not (split_root / scene_name).exists():
594593
download_and_extract_archive(

torchvision/datasets/caltech.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,6 @@ def __len__(self) -> int:
130130

131131
def download(self) -> None:
132132
if self._check_integrity():
133-
print("Files already downloaded and verified")
134133
return
135134

136135
download_and_extract_archive(
@@ -231,7 +230,6 @@ def __len__(self) -> int:
231230

232231
def download(self) -> None:
233232
if self._check_integrity():
234-
print("Files already downloaded and verified")
235233
return
236234

237235
download_and_extract_archive(

torchvision/datasets/celeba.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def __init__(
105105
if mask == slice(None): # if split == "all"
106106
self.filename = splits.index
107107
else:
108-
self.filename = [splits.index[i] for i in torch.squeeze(torch.nonzero(mask))]
108+
self.filename = [splits.index[i] for i in torch.squeeze(torch.nonzero(mask))] # type: ignore[arg-type]
109109
self.identity = identity.data[mask]
110110
self.bbox = bbox.data[mask]
111111
self.landmarks_align = landmarks_align.data[mask]
@@ -148,7 +148,6 @@ def _check_integrity(self) -> bool:
148148

149149
def download(self) -> None:
150150
if self._check_integrity():
151-
print("Files already downloaded and verified")
152151
return
153152

154153
for (file_id, md5, filename) in self.file_list:

0 commit comments

Comments
 (0)