Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions torchvision/datasets/places365.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
from os import path
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
from urllib.parse import urljoin

from .folder import default_loader
Expand All @@ -15,7 +15,7 @@ class Places365(VisionDataset):
Args:
root (str or ``pathlib.Path``): Root directory of the Places365 dataset.
split (string, optional): The dataset split. Can be one of ``train-standard`` (default), ``train-challenge``,
``val``.
``val``, ``test``.
small (bool, optional): If ``True``, uses the small images, i.e. resized to 256 x 256 pixels, instead of the
high resolution ones.
download (bool, optional): If ``True``, downloads the dataset components and places them in ``root``. Already
Expand All @@ -36,7 +36,8 @@ class Places365(VisionDataset):
RuntimeError: If ``download is False`` and the meta files, i.e. the devkit, are not present or corrupted.
RuntimeError: If ``download is True`` and the image archive is already extracted.
"""
_SPLITS = ("train-standard", "train-challenge", "val")

_SPLITS = ("train-standard", "train-challenge", "val", "test")
_BASE_URL = "http://data.csail.mit.edu/places/places365/"
# {variant: (archive, md5)}
_DEVKIT_META = {
Expand All @@ -50,15 +51,18 @@ class Places365(VisionDataset):
"train-standard": ("places365_train_standard.txt", "30f37515461640559006b8329efbed1a"),
"train-challenge": ("places365_train_challenge.txt", "b2931dc997b8c33c27e7329c073a6b57"),
"val": ("places365_val.txt", "e9f2fd57bfd9d07630173f4e8708e4b1"),
"test": ("places365_test.txt", "2fce8233fe493576d724142e45d93653"),
}
# {(split, small): (file, md5)}
_IMAGES_META = {
("train-standard", False): ("train_large_places365standard.tar", "67e186b496a84c929568076ed01a8aa1"),
("train-challenge", False): ("train_large_places365challenge.tar", "605f18e68e510c82b958664ea134545f"),
("val", False): ("val_large.tar", "9b71c4993ad89d2d8bcbdc4aef38042f"),
("test", False): ("test_large.tar", "41a4b6b724b1d2cd862fb3871ed59913"),
("train-standard", True): ("train_256_places365standard.tar", "53ca1c756c3d1e7809517cc47c5561c5"),
("train-challenge", True): ("train_256_places365challenge.tar", "741915038a5e3471ec7332404dfb64ef"),
("val", True): ("val_256.tar", "e27b17d8d44f4af9a78502beb927f808"),
("test", True): ("test_256.tar", "f532f6ad7b582262a2ec8009075e186b"),
}

def __init__(
Expand Down Expand Up @@ -123,10 +127,14 @@ def process(line: str) -> Tuple[str, int]:

return sorted(class_to_idx.keys()), class_to_idx

def load_file_list(self, download: bool = True) -> Tuple[List[Tuple[str, int]], List[int]]:
def process(line: str, sep="/") -> Tuple[str, int]:
image, idx = line.split()
return path.join(self.images_dir, image.lstrip(sep).replace(sep, os.sep)), int(idx)
def load_file_list(
self, download: bool = True
) -> Tuple[List[Tuple[str, Union[int, None]]], List[Union[int, None]]]:
def process(line: str, sep="/") -> Tuple[str, Union[int, None]]:
image, idx = (line.split() + [None])[:2]
image = cast(str, image)
idx = int(idx) if idx is not None else None
return path.join(self.images_dir, image.lstrip(sep).replace(sep, os.sep)), idx

file, md5 = self._FILE_LIST_META[self.split]
file = path.join(self.root, file)
Expand Down