|
| 1 | +import glob |
| 2 | +import os |
| 3 | + |
| 4 | +from .video_utils import VideoClips |
| 5 | +from .utils import list_dir |
| 6 | +from .folder import make_dataset |
| 7 | +from .vision import VisionDataset |
| 8 | + |
| 9 | + |
| 10 | +class HMDB51(VisionDataset): |
| 11 | + |
| 12 | + data_url = "http://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/hmdb51_org.rar" |
| 13 | + splits = { |
| 14 | + "url": "http://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/test_train_splits.rar", |
| 15 | + "md5": "15e67781e70dcfbdce2d7dbb9b3344b5" |
| 16 | + } |
| 17 | + |
| 18 | + def __init__(self, root, annotation_path, frames_per_clip, step_between_clips=1, |
| 19 | + fold=1, train=True): |
| 20 | + super(HMDB51, self).__init__(root) |
| 21 | + extensions = ('avi',) |
| 22 | + self.fold = fold |
| 23 | + self.train = train |
| 24 | + |
| 25 | + classes = list(sorted(list_dir(root))) |
| 26 | + class_to_idx = {classes[i]: i for i in range(len(classes))} |
| 27 | + self.samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file=None) |
| 28 | + self.classes = classes |
| 29 | + video_list = [x[0] for x in self.samples] |
| 30 | + video_clips = VideoClips(video_list, frames_per_clip, step_between_clips) |
| 31 | + indices = self._select_fold(video_list, annotation_path, fold, train) |
| 32 | + self.video_clips = video_clips.subset(indices) |
| 33 | + |
| 34 | + def _select_fold(self, video_list, annotation_path, fold, train): |
| 35 | + target_tag = 1 if train else 2 |
| 36 | + name = "*test_split{}.txt".format(fold) |
| 37 | + files = glob.glob(os.path.join(annotation_path, name)) |
| 38 | + selected_files = [] |
| 39 | + for f in files: |
| 40 | + with open(f, "r") as fid: |
| 41 | + data = fid.readlines() |
| 42 | + data = [x.strip().split(" ") for x in data] |
| 43 | + data = [x[0] for x in data if int(x[1]) == target_tag] |
| 44 | + selected_files.extend(data) |
| 45 | + selected_files = set(selected_files) |
| 46 | + indices = [i for i in range(len(video_list)) if os.path.basename(video_list[i]) in selected_files] |
| 47 | + return indices |
| 48 | + |
| 49 | + def __len__(self): |
| 50 | + return self.video_clips.num_clips() |
| 51 | + |
| 52 | + def __getitem__(self, idx): |
| 53 | + video, audio, info, video_idx = self.video_clips.get_clip(idx) |
| 54 | + label = self.samples[video_idx][1] |
| 55 | + |
| 56 | + return video, audio, label |
0 commit comments