Skip to content

Commit 40c8330

Browse files
JMistele, RyanCao7, and Etang21
authored and committed
Fixed video labelling after subset call for HMDB51 dataset (hmdb51.py) (EDIT: UCF101 as well) (#1240)
* Fixed video labelling after subset for HMDB51 dataset

* Fixed video labelling after subset for HMDB51 dataset

  Co-authored-by: Eric Tang <[email protected]>
  Co-authored-by: Ryan Cao <[email protected]>

* UCF 101 Labeling fixes

  - Analogous fix to HMDB51 to maintain correct labels after the train-test split
  - Additional change to the `_select_fold` method in `ucf101.py` to correctly reflect the annotation format

  Co-authored-by: Ryan Cao <[email protected]>
  Co-authored-by: Eric Tang <[email protected]>
1 parent 88ef5e1 commit 40c8330

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

torchvision/datasets/hmdb51.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -65,8 +65,8 @@ def __init__(self, root, annotation_path, frames_per_clip, step_between_clips=1,
6565
self.classes = classes
6666
video_list = [x[0] for x in self.samples]
6767
video_clips = VideoClips(video_list, frames_per_clip, step_between_clips)
68-
indices = self._select_fold(video_list, annotation_path, fold, train)
69-
self.video_clips = video_clips.subset(indices)
68+
self.indices = self._select_fold(video_list, annotation_path, fold, train)
69+
self.video_clips = video_clips.subset(self.indices)
7070
self.transform = transform
7171

7272
def _select_fold(self, video_list, annotation_path, fold, train):
@@ -89,7 +89,7 @@ def __len__(self):
8989

9090
def __getitem__(self, idx):
9191
video, audio, info, video_idx = self.video_clips.get_clip(idx)
92-
label = self.samples[video_idx][1]
92+
label = self.samples[self.indices[video_idx]][1]
9393

9494
if self.transform is not None:
9595
video = self.transform(video)

torchvision/datasets/ucf101.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -58,8 +58,8 @@ def __init__(self, root, annotation_path, frames_per_clip, step_between_clips=1,
5858
self.classes = classes
5959
video_list = [x[0] for x in self.samples]
6060
video_clips = VideoClips(video_list, frames_per_clip, step_between_clips)
61-
indices = self._select_fold(video_list, annotation_path, fold, train)
62-
self.video_clips = video_clips.subset(indices)
61+
self.indices = self._select_fold(video_list, annotation_path, fold, train)
62+
self.video_clips = video_clips.subset(self.indices)
6363
self.transform = transform
6464

6565
def _select_fold(self, video_list, annotation_path, fold, train):
@@ -81,7 +81,7 @@ def __len__(self):
8181

8282
def __getitem__(self, idx):
8383
video, audio, info, video_idx = self.video_clips.get_clip(idx)
84-
label = self.samples[video_idx][1]
84+
label = self.samples[self.indices[video_idx]][1]
8585

8686
if self.transform is not None:
8787
video = self.transform(video)

0 commit comments

Comments
 (0)