|  | 
| 4 | 4 | import string | 
| 5 | 5 | from collections.abc import Iterable | 
| 6 | 6 | from pathlib import Path | 
| 7 |  | -from typing import Any, Callable, cast, Optional, Union | 
|  | 7 | +from typing import Any, Callable, Optional, Union | 
| 8 | 8 | 
 | 
| 9 | 9 | from PIL import Image | 
| 10 | 10 | 
 | 
| 11 |  | -from .utils import iterable_to_str, verify_str_arg | 
|  | 11 | +from .utils import verify_str_arg | 
| 12 | 12 | from .vision import VisionDataset | 
| 13 | 13 | 
 | 
| 14 | 14 | 
 | 
| @@ -108,31 +108,45 @@ def _verify_classes(self, classes: Union[str, list[str]]) -> list[str]: | 
| 108 | 108 |         ] | 
| 109 | 109 |         dset_opts = ["train", "val", "test"] | 
| 110 | 110 | 
 | 
| 111 |  | -        try: | 
| 112 |  | -            classes = cast(str, classes) | 
| 113 |  | -            verify_str_arg(classes, "classes", dset_opts) | 
|  | 111 | +        err_msg = ( | 
|  | 112 | +            "Unknown value '{classes}' for LSUN class. " | 
|  | 113 | +            "The valid value is one of 'train', 'val' or 'test' or a list of categories " | 
|  | 114 | +            "e.g. ['bedroom_train', 'bedroom_val', 'bridge_train', 'bridge_val', " | 
|  | 115 | +            "'church_outdoor_train', 'church_outdoor_val', 'classroom_train', 'classroom_val', " | 
|  | 116 | +            "'conference_room_train', 'conference_room_val', 'dining_room_train', 'dining_room_val', " | 
|  | 117 | +            "'kitchen_train', 'kitchen_val', 'living_room_train', 'living_room_val', " | 
|  | 118 | +            "'restaurant_train', 'restaurant_val', 'tower_train', 'tower_val']." | 
|  | 119 | +        ) | 
|  | 120 | + | 
|  | 121 | +        if isinstance(classes, str): | 
|  | 122 | +            if classes not in dset_opts: | 
|  | 123 | +                raise ValueError(err_msg.format(classes=classes)) | 
|  | 124 | +            # If classes is a string, it should be one of the dataset options | 
|  | 125 | +            # and not a specific category. | 
| 114 | 126 |             if classes == "test": | 
| 115 | 127 |                 classes = [classes] | 
| 116 | 128 |             else: | 
| 117 | 129 |                 classes = [c + "_" + classes for c in categories] | 
| 118 |  | -        except ValueError: | 
| 119 |  | -            if not isinstance(classes, Iterable): | 
| 120 |  | -                msg = "Expected type str or Iterable for argument classes, but got type {}." | 
| 121 |  | -                raise ValueError(msg.format(type(classes))) | 
| 122 |  | - | 
|  | 130 | +        elif isinstance(classes, Iterable): | 
| 123 | 131 |             classes = list(classes) | 
| 124 | 132 |             msg_fmtstr_type = "Expected type str for elements in argument classes, but got type {}." | 
|  | 133 | + | 
| 125 | 134 |             for c in classes: | 
| 126 |  | -                verify_str_arg(c, custom_msg=msg_fmtstr_type.format(type(c))) | 
|  | 135 | +                if not isinstance(c, str): | 
|  | 136 | +                    raise TypeError(msg_fmtstr_type.format(type(c).__name__)) | 
|  | 137 | +                msg = err_msg.format(classes=c) | 
|  | 138 | + | 
| 127 | 139 |                 c_short = c.split("_") | 
|  | 140 | +                if len(c_short) < 2: | 
|  | 141 | +                    raise ValueError(msg) | 
| 128 | 142 |                 category, dset_opt = "_".join(c_short[:-1]), c_short[-1] | 
| 129 | 143 | 
 | 
| 130 |  | -                msg_fmtstr = "Unknown value '{}' for {}. Valid values are {{{}}}." | 
| 131 |  | -                msg = msg_fmtstr.format(category, "LSUN class", iterable_to_str(categories)) | 
| 132 | 144 |                 verify_str_arg(category, valid_values=categories, custom_msg=msg) | 
| 133 |  | - | 
| 134 |  | -                msg = msg_fmtstr.format(dset_opt, "postfix", iterable_to_str(dset_opts)) | 
| 135 | 145 |                 verify_str_arg(dset_opt, valid_values=dset_opts, custom_msg=msg) | 
|  | 146 | +        else: | 
|  | 147 | +            raise TypeError( | 
|  | 148 | +                f"Expected type str or Iterable for argument classes, but got type {type(classes).__name__}." | 
|  | 149 | +            ) | 
| 136 | 150 | 
 | 
| 137 | 151 |         return classes | 
| 138 | 152 | 
 | 
|  | 
0 commit comments