Skip to content

Commit 3c06489

Browse files
authored
Merge branch 'main' into fix/celeb_a_split
2 parents 39042fc + b199170 commit 3c06489

File tree

4 files changed

+114
-6
lines changed

4 files changed

+114
-6
lines changed

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def get_version():
7979

8080
def write_version_file(version, sha):
8181
# Exists for BC, probably completely useless.
82-
with open(ROOT_DIR / "torchvision/version.py", "w") as f:
82+
with open(ROOT_DIR / "torchvision" / "version.py", "w") as f:
8383
f.write(f"__version__ = '{version}'\n")
8484
f.write(f"git_version = {repr(sha)}\n")
8585
f.write("from torchvision.extension import _check_cuda_version\n")
@@ -194,7 +194,7 @@ def make_C_extension():
194194

195195
def find_libpng():
196196
# Returns (found, include dir, library dir, library name)
197-
if sys.platform in ("linux", "darwin"):
197+
if sys.platform in ("linux", "darwin", "aix"):
198198
libpng_config = shutil.which("libpng-config")
199199
if libpng_config is None:
200200
warnings.warn("libpng-config not found")

torchvision/datasets/flowers102.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,3 +112,108 @@ def download(self):
112112
for id in ["label", "setid"]:
113113
filename, md5 = self._file_dict[id]
114114
download_url(self._download_url_prefix + filename, str(self._base_folder), md5=md5)
115+
116+
classes = [
117+
"pink primrose",
118+
"hard-leaved pocket orchid",
119+
"canterbury bells",
120+
"sweet pea",
121+
"english marigold",
122+
"tiger lily",
123+
"moon orchid",
124+
"bird of paradise",
125+
"monkshood",
126+
"globe thistle",
127+
"snapdragon",
128+
"colt's foot",
129+
"king protea",
130+
"spear thistle",
131+
"yellow iris",
132+
"globe-flower",
133+
"purple coneflower",
134+
"peruvian lily",
135+
"balloon flower",
136+
"giant white arum lily",
137+
"fire lily",
138+
"pincushion flower",
139+
"fritillary",
140+
"red ginger",
141+
"grape hyacinth",
142+
"corn poppy",
143+
"prince of wales feathers",
144+
"stemless gentian",
145+
"artichoke",
146+
"sweet william",
147+
"carnation",
148+
"garden phlox",
149+
"love in the mist",
150+
"mexican aster",
151+
"alpine sea holly",
152+
"ruby-lipped cattleya",
153+
"cape flower",
154+
"great masterwort",
155+
"siam tulip",
156+
"lenten rose",
157+
"barbeton daisy",
158+
"daffodil",
159+
"sword lily",
160+
"poinsettia",
161+
"bolero deep blue",
162+
"wallflower",
163+
"marigold",
164+
"buttercup",
165+
"oxeye daisy",
166+
"common dandelion",
167+
"petunia",
168+
"wild pansy",
169+
"primula",
170+
"sunflower",
171+
"pelargonium",
172+
"bishop of llandaff",
173+
"gaura",
174+
"geranium",
175+
"orange dahlia",
176+
"pink-yellow dahlia?",
177+
"cautleya spicata",
178+
"japanese anemone",
179+
"black-eyed susan",
180+
"silverbush",
181+
"californian poppy",
182+
"osteospermum",
183+
"spring crocus",
184+
"bearded iris",
185+
"windflower",
186+
"tree poppy",
187+
"gazania",
188+
"azalea",
189+
"water lily",
190+
"rose",
191+
"thorn apple",
192+
"morning glory",
193+
"passion flower",
194+
"lotus",
195+
"toad lily",
196+
"anthurium",
197+
"frangipani",
198+
"clematis",
199+
"hibiscus",
200+
"columbine",
201+
"desert-rose",
202+
"tree mallow",
203+
"magnolia",
204+
"cyclamen",
205+
"watercress",
206+
"canna lily",
207+
"hippeastrum",
208+
"bee balm",
209+
"ball moss",
210+
"foxglove",
211+
"bougainvillea",
212+
"camellia",
213+
"mallow",
214+
"mexican petunia",
215+
"bromelia",
216+
"blanket flower",
217+
"trumpet creeper",
218+
"blackberry lily",
219+
]

torchvision/datasets/mnist.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ class MNIST(VisionDataset):
3535
"""
3636

3737
mirrors = [
38-
"http://yann.lecun.com/exdb/mnist/",
3938
"https://ossci-datasets.s3.amazonaws.com/mnist/",
39+
"http://yann.lecun.com/exdb/mnist/",
4040
]
4141

4242
resources = [
@@ -514,7 +514,7 @@ def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tenso
514514
data = f.read()
515515

516516
# parse
517-
if sys.byteorder == "little":
517+
if sys.byteorder == "little" or sys.platform == "aix":
518518
magic = get_int(data[0:4])
519519
nd = magic % 256
520520
ty = magic // 256
@@ -527,7 +527,7 @@ def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tenso
527527
torch_type = SN3_PASCALVINCENT_TYPEMAP[ty]
528528
s = [get_int(data[4 * (i + 1) : 4 * (i + 2)]) for i in range(nd)]
529529

530-
if sys.byteorder == "big":
530+
if sys.byteorder == "big" and not sys.platform == "aix":
531531
for i in range(len(s)):
532532
s[i] = int.from_bytes(s[i].to_bytes(4, byteorder="little"), byteorder="big", signed=False)
533533

torchvision/ops/focal_loss.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def sigmoid_focal_loss(
2020
targets (Tensor): A float tensor with the same shape as inputs. Stores the binary
2121
classification label for each element in inputs
2222
(0 for the negative class and 1 for the positive class).
23-
alpha (float): Weighting factor in range (0,1) to balance
23+
alpha (float): Weighting factor in range [0, 1] to balance
2424
positive vs negative examples or -1 for ignore. Default: ``0.25``.
2525
gamma (float): Exponent of the modulating factor (1 - p_t) to
2626
balance easy vs hard examples. Default: ``2``.
@@ -33,6 +33,9 @@ def sigmoid_focal_loss(
3333
"""
3434
# Original implementation from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py
3535

36+
if not (0 <= alpha <= 1) or alpha != -1:
37+
raise ValueError(f"Invalid alpha value: {alpha}. alpha must be in the range [0,1] or -1 for ignore.")
38+
3639
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
3740
_log_api_usage_once(sigmoid_focal_loss)
3841
p = torch.sigmoid(inputs)

0 commit comments

Comments
 (0)