
Commit d9830d8

Add HMDB51 and UCF101 datasets (#1156)
* Add HMDB51 and UCF101
* Remove debug code

1 parent 010984d · commit d9830d8

3 files changed: +108 −1 lines

torchvision/datasets/__init__.py (3 additions, 1 deletion)

@@ -20,6 +20,8 @@
 from .vision import VisionDataset
 from .usps import USPS
 from .kinetics import KineticsVideo
+from .hmdb51 import HMDB51
+from .ucf101 import UCF101

 __all__ = ('LSUN', 'LSUNClass',
            'ImageFolder', 'DatasetFolder', 'FakeData',
@@ -29,4 +31,4 @@
            'Omniglot', 'SBU', 'Flickr8k', 'Flickr30k',
            'VOCSegmentation', 'VOCDetection', 'Cityscapes', 'ImageNet',
            'Caltech101', 'Caltech256', 'CelebA', 'SBDataset', 'VisionDataset',
-           'USPS', 'KineticsVideo')
+           'USPS', 'KineticsVideo', 'HMDB51', 'UCF101')
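With these two lines added, both datasets are re-exported at the package root. A quick sanity check, assuming a torchvision build that includes this commit:

import torchvision.datasets as datasets

# Both names should now appear in the package's public API.
assert "HMDB51" in datasets.__all__
assert "UCF101" in datasets.__all__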

torchvision/datasets/hmdb51.py (new file, 56 additions)

import glob
import os

from .video_utils import VideoClips
from .utils import list_dir
from .folder import make_dataset
from .vision import VisionDataset


class HMDB51(VisionDataset):

    data_url = "http://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/hmdb51_org.rar"
    splits = {
        "url": "http://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/test_train_splits.rar",
        "md5": "15e67781e70dcfbdce2d7dbb9b3344b5"
    }

    def __init__(self, root, annotation_path, frames_per_clip, step_between_clips=1,
                 fold=1, train=True):
        super(HMDB51, self).__init__(root)
        extensions = ('avi',)
        self.fold = fold
        self.train = train

        # One class per subdirectory of root, in sorted order.
        classes = list(sorted(list_dir(root)))
        class_to_idx = {classes[i]: i for i in range(len(classes))}
        self.samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file=None)
        self.classes = classes
        video_list = [x[0] for x in self.samples]
        # Index every video into fixed-length clips, then keep only the
        # clips that belong to the requested fold/split.
        video_clips = VideoClips(video_list, frames_per_clip, step_between_clips)
        indices = self._select_fold(video_list, annotation_path, fold, train)
        self.video_clips = video_clips.subset(indices)

    def _select_fold(self, video_list, annotation_path, fold, train):
        # In the official HMDB51 split files, a trailing tag of 1 marks
        # training videos and 2 marks test videos.
        target_tag = 1 if train else 2
        name = "*test_split{}.txt".format(fold)
        files = glob.glob(os.path.join(annotation_path, name))
        selected_files = []
        for f in files:
            with open(f, "r") as fid:
                data = fid.readlines()
                data = [x.strip().split(" ") for x in data]
                data = [x[0] for x in data if int(x[1]) == target_tag]
                selected_files.extend(data)
        selected_files = set(selected_files)
        # Split files list bare video filenames, so match on basenames.
        indices = [i for i in range(len(video_list)) if os.path.basename(video_list[i]) in selected_files]
        return indices

    def __len__(self):
        return self.video_clips.num_clips()

    def __getitem__(self, idx):
        video, audio, info, video_idx = self.video_clips.get_clip(idx)
        label = self.samples[video_idx][1]

        return video, audio, label
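For context, a minimal usage sketch for the class above. The paths are hypothetical; HMDB51 expects one subdirectory per action class under root and the extracted *test_split*.txt files under annotation_path. Each dataset item is a single clip:

from torchvision.datasets import HMDB51

# Hypothetical paths; point these at the extracted videos and split files.
dataset = HMDB51(root="data/hmdb51/videos",
                 annotation_path="data/hmdb51/test_train_splits",
                 frames_per_clip=16,
                 step_between_clips=8,
                 fold=1,
                 train=True)

print(len(dataset))               # number of clips, not number of videos
video, audio, label = dataset[0]  # one 16-frame clip, its audio, and a class index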

torchvision/datasets/ucf101.py (new file, 49 additions)

import glob
import os

from .video_utils import VideoClips
from .utils import list_dir
from .folder import make_dataset
from .vision import VisionDataset


class UCF101(VisionDataset):

    def __init__(self, root, annotation_path, frames_per_clip, step_between_clips=1,
                 fold=1, train=True):
        super(UCF101, self).__init__(root)
        extensions = ('avi',)
        self.fold = fold
        self.train = train

        # One class per subdirectory of root, in sorted order.
        classes = list(sorted(list_dir(root)))
        class_to_idx = {classes[i]: i for i in range(len(classes))}
        self.samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file=None)
        self.classes = classes
        video_list = [x[0] for x in self.samples]
        # Index every video into fixed-length clips, then keep only the
        # clips that belong to the requested fold/split.
        video_clips = VideoClips(video_list, frames_per_clip, step_between_clips)
        indices = self._select_fold(video_list, annotation_path, fold, train)
        self.video_clips = video_clips.subset(indices)

    def _select_fold(self, video_list, annotation_path, fold, train):
        # The official split files are named trainlist01.txt .. testlist03.txt.
        name = "train" if train else "test"
        name = "{}list{:02d}.txt".format(name, fold)
        f = os.path.join(annotation_path, name)
        selected_files = []
        with open(f, "r") as fid:
            data = fid.readlines()
            data = [x.strip().split(" ") for x in data]
            # Train lists carry "path label"; test lists carry only "path".
            data = [x[0] for x in data]
            selected_files.extend(data)
        selected_files = set(selected_files)
        # Split files list paths relative to root, so strip the root prefix
        # (plus the path separator) before matching.
        indices = [i for i in range(len(video_list)) if video_list[i][len(self.root) + 1:] in selected_files]
        return indices

    def __len__(self):
        return self.video_clips.num_clips()

    def __getitem__(self, idx):
        video, audio, info, video_idx = self.video_clips.get_clip(idx)
        label = self.samples[video_idx][1]

        return video, audio, label
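The same kind of sketch for UCF101, again with hypothetical paths. fold and train together select the split file, so fold=1 with train=False reads testlist01.txt:

from torchvision.datasets import UCF101

# Hypothetical paths; annotation_path holds the official ucfTrainTestlist files.
test_set = UCF101(root="data/ucf101/videos",
                  annotation_path="data/ucf101/ucfTrainTestlist",
                  frames_per_clip=16,
                  fold=1,
                  train=False)

video, audio, label = test_set[0]
print(test_set.classes[label])    # human-readable action name for this clip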
