Skip to content

Commit 8b21b43

Browse files
committed
fix: check if the voc dataset folder exists before downloading.
1 parent 6bbe010 commit 8b21b43

File tree

1 file changed

+22
-6
lines changed

1 file changed

+22
-6
lines changed

torchvision/datasets/voc.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import collections
22
import os
33
from pathlib import Path
4-
from typing import Any, Callable, Optional, Union
4+
from typing import Any, Callable, Optional, Tuple, Union
55
from xml.etree.ElementTree import Element as ET_Element
66

77
try:
@@ -64,6 +64,8 @@ class _VOCBase(VisionDataset):
6464
_SPLITS_DIR: str
6565
_TARGET_DIR: str
6666
_TARGET_FILE_EXT: str
67+
_IMAGE_SET: str = "ImageSets"
68+
_IMAGE_DIR: str = "JPEGImages"
6769

6870
def __init__(
6971
self,
@@ -95,24 +97,38 @@ def __init__(
9597
voc_root = os.path.join(self.root, base_dir)
9698

9799
if download:
98-
download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.md5)
100+
self._download(voc_root)
99101

100-
if not os.path.isdir(voc_root):
102+
if not self._check_exists(voc_root):
101103
raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
102104

103-
splits_dir = os.path.join(voc_root, "ImageSets", self._SPLITS_DIR)
105+
splits_dir, image_dir, target_dir = self._voc_subfolders(voc_root)
104106
split_f = os.path.join(splits_dir, image_set.rstrip("\n") + ".txt")
105107
with open(os.path.join(split_f)) as f:
106108
file_names = [x.strip() for x in f.readlines()]
107109

108-
image_dir = os.path.join(voc_root, "JPEGImages")
109110
self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
110111

111-
target_dir = os.path.join(voc_root, self._TARGET_DIR)
112112
self.targets = [os.path.join(target_dir, x + self._TARGET_FILE_EXT) for x in file_names]
113113

114114
assert len(self.images) == len(self.targets)
115115

116+
def _voc_subfolders(self, voc_root) -> Tuple[str, str, str]:
117+
"""Returns the subfolders for the VOC dataset."""
118+
splits_dir = os.path.join(voc_root, self._IMAGE_SET, self._SPLITS_DIR)
119+
image_dir = os.path.join(voc_root, self._IMAGE_DIR)
120+
target_dir = os.path.join(voc_root, self._TARGET_DIR)
121+
return splits_dir, image_dir, target_dir
122+
123+
def _download(self, voc_root: str) -> None:
124+
if self._check_exists(voc_root):
125+
return
126+
download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.md5)
127+
128+
def _check_exists(self, voc_root: str) -> bool:
129+
"""Check if the dataset exists."""
130+
return all(os.path.isdir(d) and len(os.listdir(d)) for d in self._voc_subfolders(voc_root))
131+
116132
def __len__(self) -> int:
117133
return len(self.images)
118134

0 commit comments

Comments
 (0)