diff --git a/torchvision/datasets/places365.py b/torchvision/datasets/places365.py index a120e0e217a..c02fccb4154 100644 --- a/torchvision/datasets/places365.py +++ b/torchvision/datasets/places365.py @@ -1,7 +1,7 @@ import os from os import path from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union from urllib.parse import urljoin from .folder import default_loader @@ -15,7 +15,7 @@ class Places365(VisionDataset): Args: root (str or ``pathlib.Path``): Root directory of the Places365 dataset. split (string, optional): The dataset split. Can be one of ``train-standard`` (default), ``train-challenge``, - ``val``. + ``val``, ``test``. small (bool, optional): If ``True``, uses the small images, i.e. resized to 256 x 256 pixels, instead of the high resolution ones. download (bool, optional): If ``True``, downloads the dataset components and places them in ``root``. Already @@ -36,7 +36,8 @@ class Places365(VisionDataset): RuntimeError: If ``download is False`` and the meta files, i.e. the devkit, are not present or corrupted. RuntimeError: If ``download is True`` and the image archive is already extracted. """ - _SPLITS = ("train-standard", "train-challenge", "val") + + _SPLITS = ("train-standard", "train-challenge", "val", "test") _BASE_URL = "http://data.csail.mit.edu/places/places365/" # {variant: (archive, md5)} _DEVKIT_META = { @@ -50,15 +51,18 @@ class Places365(VisionDataset): "train-standard": ("places365_train_standard.txt", "30f37515461640559006b8329efbed1a"), "train-challenge": ("places365_train_challenge.txt", "b2931dc997b8c33c27e7329c073a6b57"), "val": ("places365_val.txt", "e9f2fd57bfd9d07630173f4e8708e4b1"), + "test": ("places365_test.txt", "2fce8233fe493576d724142e45d93653"), } # {(split, small): (file, md5)} _IMAGES_META = { ("train-standard", False): ("train_large_places365standard.tar", "67e186b496a84c929568076ed01a8aa1"), ("train-challenge", False): ("train_large_places365challenge.tar", "605f18e68e510c82b958664ea134545f"), ("val", False): ("val_large.tar", "9b71c4993ad89d2d8bcbdc4aef38042f"), + ("test", False): ("test_large.tar", "41a4b6b724b1d2cd862fb3871ed59913"), ("train-standard", True): ("train_256_places365standard.tar", "53ca1c756c3d1e7809517cc47c5561c5"), ("train-challenge", True): ("train_256_places365challenge.tar", "741915038a5e3471ec7332404dfb64ef"), ("val", True): ("val_256.tar", "e27b17d8d44f4af9a78502beb927f808"), + ("test", True): ("test_256.tar", "f532f6ad7b582262a2ec8009075e186b"), } def __init__( @@ -123,10 +127,14 @@ def process(line: str) -> Tuple[str, int]: return sorted(class_to_idx.keys()), class_to_idx - def load_file_list(self, download: bool = True) -> Tuple[List[Tuple[str, int]], List[int]]: - def process(line: str, sep="/") -> Tuple[str, int]: - image, idx = line.split() - return path.join(self.images_dir, image.lstrip(sep).replace(sep, os.sep)), int(idx) + def load_file_list( + self, download: bool = True + ) -> Tuple[List[Tuple[str, Union[int, None]]], List[Union[int, None]]]: + def process(line: str, sep="/") -> Tuple[str, Union[int, None]]: + image, idx = (line.split() + [None])[:2] + image = cast(str, image) + idx = int(idx) if idx is not None else None + return path.join(self.images_dir, image.lstrip(sep).replace(sep, os.sep)), idx file, md5 = self._FILE_LIST_META[self.split] file = path.join(self.root, file)