Skip to content

Commit 1c4636d

Browse files
committed
fix: fix lsun dataset error message.
1 parent b818d32 commit 1c4636d

File tree

2 files changed

+73
-15
lines changed

2 files changed

+73
-15
lines changed

test/test_datasets.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1056,6 +1056,50 @@ def test_not_found_or_corrupted(self):
10561056
with pytest.raises(datasets_utils.lazy_importer.lmdb.Error):
10571057
super().test_not_found_or_corrupted()
10581058

1059+
def test_class_name_verification(self):
1060+
err_msg = (
1061+
"Unknown value '{}' for LSUN class. "
1062+
"The valid value is one of 'train', 'val' or 'test' or a list of categories "
1063+
"e.g. ['bedroom_train', 'bedroom_val', 'bridge_train', 'bridge_val', "
1064+
"'church_outdoor_train', 'church_outdoor_val', 'classroom_train', 'classroom_val', "
1065+
"'conference_room_train', 'conference_room_val', 'dining_room_train', 'dining_room_val', "
1066+
"'kitchen_train', 'kitchen_val', 'living_room_train', 'living_room_val', "
1067+
"'restaurant_train', 'restaurant_val', 'tower_train', 'tower_val']."
1068+
)
1069+
1070+
cases = [
1071+
"bedroom",
1072+
"bedroom_train",
1073+
]
1074+
for case in cases:
1075+
with pytest.raises(
1076+
ValueError,
1077+
match=re.escape(err_msg.format(case)),
1078+
):
1079+
with self.create_dataset(classes=case):
1080+
pass
1081+
1082+
for case in [
1083+
["bedroom_train", "bedroom"],
1084+
["bedroom_train", "bedroommmmmmmm_val"],
1085+
]:
1086+
with pytest.raises(
1087+
ValueError,
1088+
match=re.escape(err_msg.format(case[-1])),
1089+
):
1090+
with self.create_dataset(classes=case):
1091+
pass
1092+
1093+
for case in [[None], [1]]:
1094+
with pytest.raises(
1095+
TypeError,
1096+
match=re.escape(
1097+
f"Expected type str for elements in argument classes, but got type {type(case[0]).__name__}."
1098+
),
1099+
):
1100+
with self.create_dataset(classes=case):
1101+
pass
1102+
10591103

10601104
class KineticsTestCase(datasets_utils.VideoDatasetTestCase):
10611105
DATASET_CLASS = datasets.Kinetics

torchvision/datasets/lsun.py

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@
44
import string
55
from collections.abc import Iterable
66
from pathlib import Path
7-
from typing import Any, Callable, cast, Optional, Union
7+
from typing import Any, Callable, Optional, Union
88

99
from PIL import Image
1010

11-
from .utils import iterable_to_str, verify_str_arg
11+
from .utils import verify_str_arg
1212
from .vision import VisionDataset
1313

1414

@@ -108,31 +108,45 @@ def _verify_classes(self, classes: Union[str, list[str]]) -> list[str]:
108108
]
109109
dset_opts = ["train", "val", "test"]
110110

111-
try:
112-
classes = cast(str, classes)
113-
verify_str_arg(classes, "classes", dset_opts)
111+
err_msg = (
112+
"Unknown value '{classes}' for LSUN class. "
113+
"The valid value is one of 'train', 'val' or 'test' or a list of categories "
114+
"e.g. ['bedroom_train', 'bedroom_val', 'bridge_train', 'bridge_val', "
115+
"'church_outdoor_train', 'church_outdoor_val', 'classroom_train', 'classroom_val', "
116+
"'conference_room_train', 'conference_room_val', 'dining_room_train', 'dining_room_val', "
117+
"'kitchen_train', 'kitchen_val', 'living_room_train', 'living_room_val', "
118+
"'restaurant_train', 'restaurant_val', 'tower_train', 'tower_val']."
119+
)
120+
121+
if isinstance(classes, str):
122+
if classes not in dset_opts:
123+
raise ValueError(err_msg.format(classes=classes))
124+
# If classes is a string, it should be one of the dataset options
125+
# and not a specific category.
114126
if classes == "test":
115127
classes = [classes]
116128
else:
117129
classes = [c + "_" + classes for c in categories]
118-
except ValueError:
119-
if not isinstance(classes, Iterable):
120-
msg = "Expected type str or Iterable for argument classes, but got type {}."
121-
raise ValueError(msg.format(type(classes)))
122-
130+
elif isinstance(classes, Iterable):
123131
classes = list(classes)
124132
msg_fmtstr_type = "Expected type str for elements in argument classes, but got type {}."
133+
125134
for c in classes:
126-
verify_str_arg(c, custom_msg=msg_fmtstr_type.format(type(c)))
135+
if not isinstance(c, str):
136+
raise TypeError(msg_fmtstr_type.format(type(c).__name__))
137+
msg = err_msg.format(classes=c)
138+
127139
c_short = c.split("_")
140+
if len(c_short) < 2:
141+
raise ValueError(msg)
128142
category, dset_opt = "_".join(c_short[:-1]), c_short[-1]
129143

130-
msg_fmtstr = "Unknown value '{}' for {}. Valid values are {{{}}}."
131-
msg = msg_fmtstr.format(category, "LSUN class", iterable_to_str(categories))
132144
verify_str_arg(category, valid_values=categories, custom_msg=msg)
133-
134-
msg = msg_fmtstr.format(dset_opt, "postfix", iterable_to_str(dset_opts))
135145
verify_str_arg(dset_opt, valid_values=dset_opts, custom_msg=msg)
146+
else:
147+
raise TypeError(
148+
f"Expected type str or Iterable for argument classes, but got type {type(classes).__name__}."
149+
)
136150

137151
return classes
138152

0 commit comments

Comments
 (0)