Commit 94c9417

Move RandomClipSampler to references (#1186)
* Move RandomClipSampler to references
* Lint and bugfix
1 parent fe4d17f commit 94c9417

File tree

4 files changed: +42 -41 lines changed


references/video_classification/sampler.py

Lines changed: 35 additions & 0 deletions
@@ -87,3 +87,38 @@ def __iter__(self):
 
     def __len__(self):
         return sum(min(len(c), self.max_clips_per_video) for c in self.video_clips.clips)
+
+
+class RandomClipSampler(torch.utils.data.Sampler):
+    """
+    Samples at most `max_video_clips_per_video` clips for each video randomly
+
+    Arguments:
+        video_clips (VideoClips): video clips to sample from
+        max_clips_per_video (int): maximum number of clips to be sampled per video
+    """
+    def __init__(self, video_clips, max_clips_per_video):
+        if not isinstance(video_clips, torchvision.datasets.video_utils.VideoClips):
+            raise TypeError("Expected video_clips to be an instance of VideoClips, "
+                            "got {}".format(type(video_clips)))
+        self.video_clips = video_clips
+        self.max_clips_per_video = max_clips_per_video
+
+    def __iter__(self):
+        idxs = []
+        s = 0
+        # select at most max_clips_per_video for each video, randomly
+        for c in self.video_clips.clips:
+            length = len(c)
+            size = min(length, self.max_clips_per_video)
+            sampled = torch.randperm(length)[:size] + s
+            s += length
+            idxs.append(sampled)
+        idxs = torch.cat(idxs)
+        # shuffle all clips randomly
+        perm = torch.randperm(len(idxs))
+        idxs = idxs[perm].tolist()
+        return iter(idxs)
+
+    def __len__(self):
+        return sum(min(len(c), self.max_clips_per_video) for c in self.video_clips.clips)

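For context, here is a small, self-contained sketch of the index bookkeeping the relocated sampler performs. The clip counts below are made up, and plain Python lists stand in for VideoClips.clips (which, in the real class, holds the clips found for each video), so the snippet runs with only torch installed:

import torch

# Stand-in for VideoClips.clips: three videos with 4, 2, and 6 clips each.
clips = [list(range(4)), list(range(2)), list(range(6))]
max_clips_per_video = 3

idxs = []
s = 0  # running offset: clip i of video v maps to a single flat index
for c in clips:
    length = len(c)
    size = min(length, max_clips_per_video)
    # pick at most max_clips_per_video clips of this video, as flat indices
    sampled = torch.randperm(length)[:size] + s
    s += length
    idxs.append(sampled)
idxs = torch.cat(idxs)
# finally shuffle the selected clips across videos, as __iter__ does
perm = torch.randperm(len(idxs))
print(idxs[perm].tolist())  # e.g. [7, 0, 5, 2, 3, 10, 4, 8]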
references/video_classification/train.py

Lines changed: 2 additions & 2 deletions
@@ -13,7 +13,7 @@
 from torchvision import transforms
 
 import utils
-from sampler import DistributedSampler, UniformClipSampler
+from sampler import DistributedSampler, UniformClipSampler, RandomClipSampler
 from scheduler import WarmupMultiStepLR
 import transforms as T
 
@@ -184,7 +184,7 @@ def main(args):
     dataset_test.video_clips.compute_clips(args.clip_len, 1, frame_rate=15)
 
     print("Creating data loaders")
-    train_sampler = torchvision.datasets.video_utils.RandomClipSampler(dataset.video_clips, args.clips_per_video)
+    train_sampler = RandomClipSampler(dataset.video_clips, args.clips_per_video)
     test_sampler = UniformClipSampler(dataset_test.video_clips, args.clips_per_video)
     if args.distributed:
         train_sampler = DistributedSampler(train_sampler)

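For orientation, a hedged sketch of how the training script composes the new sampler with a DataLoader; `dataset`, `dataset_test`, and `args` are assumed to exist as elsewhere in train.py and are not defined here, so this fragment is illustrative rather than runnable on its own:

import torch
from sampler import DistributedSampler, UniformClipSampler, RandomClipSampler

# `dataset` / `dataset_test` are assumed to be video datasets exposing a
# `video_clips` index; `args` mirrors the script's command-line flags.
train_sampler = RandomClipSampler(dataset.video_clips, args.clips_per_video)
test_sampler = UniformClipSampler(dataset_test.video_clips, args.clips_per_video)
if args.distributed:
    # The DistributedSampler defined in the local sampler.py wraps another
    # sampler and shards its indices across processes.
    train_sampler = DistributedSampler(train_sampler)

data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=args.batch_size, sampler=train_sampler,
    num_workers=args.workers, pin_memory=True)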
test/test_datasets_video_utils.py

Lines changed: 5 additions & 3 deletions
@@ -4,7 +4,7 @@
 import unittest
 
 from torchvision import io
-from torchvision.datasets.video_utils import VideoClips, unfold, RandomClipSampler
+from torchvision.datasets.video_utils import VideoClips, unfold
 
 from common_utils import get_tmp_dir
 
@@ -80,21 +80,23 @@ def test_video_clips(self):
                 self.assertEqual(video_idx, v_idx)
                 self.assertEqual(clip_idx, c_idx)
 
+    @unittest.skip("Moved to reference scripts for now")
     def test_video_sampler(self):
         with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list:
             video_clips = VideoClips(video_list, 5, 5)
-            sampler = RandomClipSampler(video_clips, 3)
+            sampler = RandomClipSampler(video_clips, 3)  # noqa: F821
             self.assertEqual(len(sampler), 3 * 3)
             indices = torch.tensor(list(iter(sampler)))
             videos = indices // 5
             v_idxs, count = torch.unique(videos, return_counts=True)
             self.assertTrue(v_idxs.equal(torch.tensor([0, 1, 2])))
             self.assertTrue(count.equal(torch.tensor([3, 3, 3])))
 
+    @unittest.skip("Moved to reference scripts for now")
     def test_video_sampler_unequal(self):
         with get_list_of_videos(num_videos=3, sizes=[10, 25, 25]) as video_list:
             video_clips = VideoClips(video_list, 5, 5)
-            sampler = RandomClipSampler(video_clips, 3)
+            sampler = RandomClipSampler(video_clips, 3)  # noqa: F821
             self.assertEqual(len(sampler), 2 + 3 + 3)
             indices = list(iter(sampler))
             self.assertIn(0, indices)

torchvision/datasets/video_utils.py

Lines changed: 0 additions & 36 deletions
@@ -1,7 +1,6 @@
 import bisect
 import math
 import torch
-import torch.utils.data
 from torchvision.io import read_video_timestamps, read_video
 
 from .utils import tqdm
@@ -214,38 +213,3 @@ def get_clip(self, idx):
         info["video_fps"] = self.frame_rate
         assert len(video) == self.num_frames, "{} x {}".format(video.shape, self.num_frames)
         return video, audio, info, video_idx
-
-
-class RandomClipSampler(torch.utils.data.Sampler):
-    """
-    Samples at most `max_video_clips_per_video` clips for each video randomly
-
-    Arguments:
-        video_clips (VideoClips): video clips to sample from
-        max_clips_per_video (int): maximum number of clips to be sampled per video
-    """
-    def __init__(self, video_clips, max_clips_per_video):
-        if not isinstance(video_clips, VideoClips):
-            raise TypeError("Expected video_clips to be an instance of VideoClips, "
-                            "got {}".format(type(video_clips)))
-        self.video_clips = video_clips
-        self.max_clips_per_video = max_clips_per_video
-
-    def __iter__(self):
-        idxs = []
-        s = 0
-        # select at most max_clips_per_video for each video, randomly
-        for c in self.video_clips.clips:
-            length = len(c)
-            size = min(length, self.max_clips_per_video)
-            sampled = torch.randperm(length)[:size] + s
-            s += length
-            idxs.append(sampled)
-        idxs = torch.cat(idxs)
-        # shuffle all clips randomly
-        perm = torch.randperm(len(idxs))
-        idxs = idxs[perm].tolist()
-        return iter(idxs)
-
-    def __len__(self):
-        return sum(min(len(c), self.max_clips_per_video) for c in self.video_clips.clips)
