|
1 | 1 | import collections |
2 | 2 | import os |
3 | 3 | from pathlib import Path |
4 | | -from typing import Any, Callable, Optional, Union |
| 4 | +from typing import Any, Callable, Optional, Tuple, Union |
5 | 5 | from xml.etree.ElementTree import Element as ET_Element |
6 | 6 |
|
7 | 7 | try: |
@@ -64,6 +64,8 @@ class _VOCBase(VisionDataset): |
64 | 64 | _SPLITS_DIR: str |
65 | 65 | _TARGET_DIR: str |
66 | 66 | _TARGET_FILE_EXT: str |
| 67 | + _IMAGE_SET: str = "ImageSets" |
| 68 | + _IMAGE_DIR: str = "JPEGImages" |
67 | 69 |
|
68 | 70 | def __init__( |
69 | 71 | self, |
@@ -95,24 +97,38 @@ def __init__( |
95 | 97 | voc_root = os.path.join(self.root, base_dir) |
96 | 98 |
|
97 | 99 | if download: |
98 | | - download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.md5) |
| 100 | + self._download(voc_root) |
99 | 101 |
|
100 | | - if not os.path.isdir(voc_root): |
| 102 | + if not self._check_exists(voc_root): |
101 | 103 | raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it") |
102 | 104 |
|
103 | | - splits_dir = os.path.join(voc_root, "ImageSets", self._SPLITS_DIR) |
| 105 | + splits_dir, image_dir, target_dir = self._voc_subfolders(voc_root) |
104 | 106 | split_f = os.path.join(splits_dir, image_set.rstrip("\n") + ".txt") |
105 | 107 | with open(os.path.join(split_f)) as f: |
106 | 108 | file_names = [x.strip() for x in f.readlines()] |
107 | 109 |
|
108 | | - image_dir = os.path.join(voc_root, "JPEGImages") |
109 | 110 | self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names] |
110 | 111 |
|
111 | | - target_dir = os.path.join(voc_root, self._TARGET_DIR) |
112 | 112 | self.targets = [os.path.join(target_dir, x + self._TARGET_FILE_EXT) for x in file_names] |
113 | 113 |
|
114 | 114 | assert len(self.images) == len(self.targets) |
115 | 115 |
|
| 116 | + def _voc_subfolders(self, voc_root) -> Tuple[str, str, str]: |
| 117 | + """Returns the subfolders for the VOC dataset.""" |
| 118 | + splits_dir = os.path.join(voc_root, self._IMAGE_SET, self._SPLITS_DIR) |
| 119 | + image_dir = os.path.join(voc_root, self._IMAGE_DIR) |
| 120 | + target_dir = os.path.join(voc_root, self._TARGET_DIR) |
| 121 | + return splits_dir, image_dir, target_dir |
| 122 | + |
| 123 | + def _download(self, voc_root: str) -> None: |
| 124 | + if self._check_exists(voc_root): |
| 125 | + return |
| 126 | + download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.md5) |
| 127 | + |
| 128 | + def _check_exists(self, voc_root: str) -> bool: |
| 129 | + """Check if the dataset exists.""" |
| 130 | + return all(os.path.isdir(d) and len(os.listdir(d)) for d in self._voc_subfolders(voc_root)) |
| 131 | + |
116 | 132 | def __len__(self) -> int: |
117 | 133 | return len(self.images) |
118 | 134 |
|
|
0 commit comments