Skip to content

Commit 47c1912

Browse files
author
pytorchbot
committed
2025-02-20 nightly release (b5c7443)
1 parent 4219bf9 commit 47c1912

File tree

8 files changed

+59
-26
lines changed

8 files changed

+59
-26
lines changed

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def get_version():
7979

8080
def write_version_file(version, sha):
8181
# Exists for BC, probably completely useless.
82-
with open(ROOT_DIR / "torchvision/version.py", "w") as f:
82+
with open(ROOT_DIR / "torchvision" / "version.py", "w") as f:
8383
f.write(f"__version__ = '{version}'\n")
8484
f.write(f"git_version = {repr(sha)}\n")
8585
f.write("from torchvision.extension import _check_cuda_version\n")
@@ -194,7 +194,7 @@ def make_C_extension():
194194

195195
def find_libpng():
196196
# Returns (found, include dir, library dir, library name)
197-
if sys.platform in ("linux", "darwin"):
197+
if sys.platform in ("linux", "darwin", "aix"):
198198
libpng_config = shutil.which("libpng-config")
199199
if libpng_config is None:
200200
warnings.warn("libpng-config not found")

test/test_datasets.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -532,7 +532,8 @@ def inject_fake_data(self, tmpdir, config):
532532
self._create_bbox_txt(base_folder, num_images)
533533
self._create_landmarks_txt(base_folder, num_images)
534534

535-
return dict(num_examples=num_images_per_split[config["split"]], attr_names=attr_names)
535+
num_samples = num_images_per_split.get(config["split"], 0) if isinstance(config["split"], str) else 0
536+
return dict(num_examples=num_samples, attr_names=attr_names)
536537

537538
def _create_split_txt(self, root):
538539
num_images_per_split = dict(train=4, valid=3, test=2)
@@ -635,6 +636,28 @@ def test_transforms_v2_wrapper_spawn(self):
635636
with self.create_dataset(target_type=target_type, transform=v2.Resize(size=expected_size)) as (dataset, _):
636637
datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)
637638

639+
def test_invalid_split_list(self):
640+
with pytest.raises(ValueError, match="Expected type str for argument split, but got type <class 'list'>."):
641+
with self.create_dataset(split=[1]):
642+
pass
643+
644+
def test_invalid_split_int(self):
645+
with pytest.raises(ValueError, match="Expected type str for argument split, but got type <class 'int'>."):
646+
with self.create_dataset(split=1):
647+
pass
648+
649+
def test_invalid_split_value(self):
650+
with pytest.raises(
651+
ValueError,
652+
match="Unknown value '{value}' for argument {arg}. Valid values are {{{valid_values}}}.".format(
653+
value="invalid",
654+
arg="split",
655+
valid_values=("train", "valid", "test", "all"),
656+
),
657+
):
658+
with self.create_dataset(split="invalid"):
659+
pass
660+
638661

639662
class VOCSegmentationTestCase(datasets_utils.ImageDatasetTestCase):
640663
DATASET_CLASS = datasets.VOCSegmentation

torchvision/datasets/celeba.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,13 @@ def __init__(
9393
"test": 2,
9494
"all": None,
9595
}
96-
split_ = split_map[verify_str_arg(split.lower(), "split", ("train", "valid", "test", "all"))]
96+
split_ = split_map[
97+
verify_str_arg(
98+
split.lower() if isinstance(split, str) else split,
99+
"split",
100+
("train", "valid", "test", "all"),
101+
)
102+
]
97103
splits = self._load_csv("list_eval_partition.txt")
98104
identity = self._load_csv("identity_CelebA.txt")
99105
bbox = self._load_csv("list_bbox_celeba.txt", header=1)

torchvision/datasets/mnist.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,18 +25,18 @@ class MNIST(VisionDataset):
2525
and ``MNIST/raw/t10k-images-idx3-ubyte`` exist.
2626
train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``,
2727
otherwise from ``t10k-images-idx3-ubyte``.
28-
download (bool, optional): If True, downloads the dataset from the internet and
29-
puts it in root directory. If dataset is already downloaded, it is not
30-
downloaded again.
3128
transform (callable, optional): A function/transform that takes in a PIL image
3229
and returns a transformed version. E.g, ``transforms.RandomCrop``
3330
target_transform (callable, optional): A function/transform that takes in the
3431
target and transforms it.
32+
download (bool, optional): If True, downloads the dataset from the internet and
33+
puts it in root directory. If dataset is already downloaded, it is not
34+
downloaded again.
3535
"""
3636

3737
mirrors = [
38-
"http://yann.lecun.com/exdb/mnist/",
3938
"https://ossci-datasets.s3.amazonaws.com/mnist/",
39+
"http://yann.lecun.com/exdb/mnist/",
4040
]
4141

4242
resources = [
@@ -209,13 +209,13 @@ class FashionMNIST(MNIST):
209209
and ``FashionMNIST/raw/t10k-images-idx3-ubyte`` exist.
210210
train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``,
211211
otherwise from ``t10k-images-idx3-ubyte``.
212-
download (bool, optional): If True, downloads the dataset from the internet and
213-
puts it in root directory. If dataset is already downloaded, it is not
214-
downloaded again.
215212
transform (callable, optional): A function/transform that takes in a PIL image
216213
and returns a transformed version. E.g, ``transforms.RandomCrop``
217214
target_transform (callable, optional): A function/transform that takes in the
218215
target and transforms it.
216+
download (bool, optional): If True, downloads the dataset from the internet and
217+
puts it in root directory. If dataset is already downloaded, it is not
218+
downloaded again.
219219
"""
220220

221221
mirrors = ["http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/"]
@@ -237,13 +237,13 @@ class KMNIST(MNIST):
237237
and ``KMNIST/raw/t10k-images-idx3-ubyte`` exist.
238238
train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``,
239239
otherwise from ``t10k-images-idx3-ubyte``.
240-
download (bool, optional): If True, downloads the dataset from the internet and
241-
puts it in root directory. If dataset is already downloaded, it is not
242-
downloaded again.
243240
transform (callable, optional): A function/transform that takes in a PIL image
244241
and returns a transformed version. E.g, ``transforms.RandomCrop``
245242
target_transform (callable, optional): A function/transform that takes in the
246243
target and transforms it.
244+
download (bool, optional): If True, downloads the dataset from the internet and
245+
puts it in root directory. If dataset is already downloaded, it is not
246+
downloaded again.
247247
"""
248248

249249
mirrors = ["http://codh.rois.ac.jp/kmnist/dataset/kmnist/"]
@@ -358,6 +358,9 @@ class QMNIST(MNIST):
358358
for each example is class number (for compatibility with
359359
the MNIST dataloader) or a torch vector containing the
360360
full qmnist information. Default=True.
361+
train (bool,optional,compatibility): When argument 'what' is
362+
not specified, this boolean decides whether to load the
363+
training set or the testing set. Default: True.
361364
download (bool, optional): If True, downloads the dataset from
362365
the internet and puts it in root directory. If dataset is
363366
already downloaded, it is not downloaded again.
@@ -366,9 +369,6 @@ class QMNIST(MNIST):
366369
version. E.g, ``transforms.RandomCrop``
367370
target_transform (callable, optional): A function/transform
368371
that takes in the target and transforms it.
369-
train (bool,optional,compatibility): When argument 'what' is
370-
not specified, this boolean decides whether to load the
371-
training set or the testing set. Default: True.
372372
"""
373373

374374
subsets = {"train": "train", "test": "test", "test10k": "test", "test50k": "test", "nist": "nist"}
@@ -514,7 +514,7 @@ def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tenso
514514
data = f.read()
515515

516516
# parse
517-
if sys.byteorder == "little":
517+
if sys.byteorder == "little" or sys.platform == "aix":
518518
magic = get_int(data[0:4])
519519
nd = magic % 256
520520
ty = magic // 256
@@ -527,7 +527,7 @@ def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tenso
527527
torch_type = SN3_PASCALVINCENT_TYPEMAP[ty]
528528
s = [get_int(data[4 * (i + 1) : 4 * (i + 2)]) for i in range(nd)]
529529

530-
if sys.byteorder == "big":
530+
if sys.byteorder == "big" and not sys.platform == "aix":
531531
for i in range(len(s)):
532532
s[i] = int.from_bytes(s[i].to_bytes(4, byteorder="little"), byteorder="big", signed=False)
533533

torchvision/datasets/moving_mnist.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@ class MovingMNIST(VisionDataset):
1818
split_ratio (int, optional): The split ratio of number of frames. If ``split="train"``, the first split
1919
frames ``data[:, :split_ratio]`` is returned. If ``split="test"``, the last split frames ``data[:, split_ratio:]``
2020
is returned. If ``split=None``, this parameter is ignored and all frames are returned.
21-
transform (callable, optional): A function/transform that takes in a torch Tensor
22-
and returns a transformed version. E.g, ``transforms.RandomCrop``
2321
download (bool, optional): If true, downloads the dataset from the internet and
2422
puts it in root directory. If dataset is already downloaded, it is not
2523
downloaded again.
24+
transform (callable, optional): A function/transform that takes in a torch Tensor
25+
and returns a transformed version. E.g, ``transforms.RandomCrop``
2626
"""
2727

2828
_URL = "http://www.cs.toronto.edu/~nitish/unsupervised_video/mnist_test_seq.npy"

torchvision/datasets/oxford_iiit_pet.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ class OxfordIIITPet(VisionDataset):
2727
transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
2828
version. E.g, ``transforms.RandomCrop``.
2929
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
30+
transforms (callable, optional): A function/transform that takes input sample
31+
and its target as entry and returns a transformed version.
3032
download (bool, optional): If True, downloads the dataset from the internet and puts it into
3133
``root/oxford-iiit-pet``. If dataset is already downloaded, it is not downloaded again.
3234
"""

torchvision/models/_api.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from functools import partial
88
from inspect import signature
99
from types import ModuleType
10-
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Set, Type, TypeVar, Union
10+
from typing import Any, Callable, Dict, get_args, Iterable, List, Mapping, Optional, Set, Type, TypeVar, Union
1111

1212
from torch import nn
1313

@@ -168,14 +168,13 @@ def _get_enum_from_fn(fn: Callable) -> Type[WeightsEnum]:
168168
if "weights" not in sig.parameters:
169169
raise ValueError("The method is missing the 'weights' argument.")
170170

171-
ann = signature(fn).parameters["weights"].annotation
171+
ann = sig.parameters["weights"].annotation
172172
weights_enum = None
173173
if isinstance(ann, type) and issubclass(ann, WeightsEnum):
174174
weights_enum = ann
175175
else:
176176
# handle cases like Union[Optional, T]
177-
# TODO: Replace ann.__args__ with typing.get_args(ann) after python >= 3.8
178-
for t in ann.__args__: # type: ignore[union-attr]
177+
for t in get_args(ann): # type: ignore[union-attr]
179178
if isinstance(t, type) and issubclass(t, WeightsEnum):
180179
weights_enum = t
181180
break

torchvision/ops/focal_loss.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def sigmoid_focal_loss(
2020
targets (Tensor): A float tensor with the same shape as inputs. Stores the binary
2121
classification label for each element in inputs
2222
(0 for the negative class and 1 for the positive class).
23-
alpha (float): Weighting factor in range (0,1) to balance
23+
alpha (float): Weighting factor in range [0, 1] to balance
2424
positive vs negative examples or -1 for ignore. Default: ``0.25``.
2525
gamma (float): Exponent of the modulating factor (1 - p_t) to
2626
balance easy vs hard examples. Default: ``2``.
@@ -33,6 +33,9 @@ def sigmoid_focal_loss(
3333
"""
3434
# Original implementation from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py
3535

36+
if not (0 <= alpha <= 1) or alpha != -1:
37+
raise ValueError(f"Invalid alpha value: {alpha}. alpha must be in the range [0,1] or -1 for ignore.")
38+
3639
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
3740
_log_api_usage_once(sigmoid_focal_loss)
3841
p = torch.sigmoid(inputs)

0 commit comments

Comments
 (0)