5
5
6
6
from torchvision import io
7
7
from torchvision .datasets .video_utils import VideoClips , unfold
8
- from torchvision import get_video_backend
9
8
10
9
from common_utils import get_tmp_dir
11
10
@@ -60,23 +59,22 @@ def test_unfold(self):
60
59
61
60
@unittest .skipIf (not io .video ._av_available (), "this test requires av" )
62
61
def test_video_clips (self ):
63
- _backend = get_video_backend ()
64
62
with get_list_of_videos (num_videos = 3 ) as video_list :
65
- video_clips = VideoClips (video_list , 5 , 5 , _backend = _backend )
63
+ video_clips = VideoClips (video_list , 5 , 5 )
66
64
self .assertEqual (video_clips .num_clips (), 1 + 2 + 3 )
67
65
for i , (v_idx , c_idx ) in enumerate ([(0 , 0 ), (1 , 0 ), (1 , 1 ), (2 , 0 ), (2 , 1 ), (2 , 2 )]):
68
66
video_idx , clip_idx = video_clips .get_clip_location (i )
69
67
self .assertEqual (video_idx , v_idx )
70
68
self .assertEqual (clip_idx , c_idx )
71
69
72
- video_clips = VideoClips (video_list , 6 , 6 , _backend = _backend )
70
+ video_clips = VideoClips (video_list , 6 , 6 )
73
71
self .assertEqual (video_clips .num_clips (), 0 + 1 + 2 )
74
72
for i , (v_idx , c_idx ) in enumerate ([(1 , 0 ), (2 , 0 ), (2 , 1 )]):
75
73
video_idx , clip_idx = video_clips .get_clip_location (i )
76
74
self .assertEqual (video_idx , v_idx )
77
75
self .assertEqual (clip_idx , c_idx )
78
76
79
- video_clips = VideoClips (video_list , 6 , 1 , _backend = _backend )
77
+ video_clips = VideoClips (video_list , 6 , 1 )
80
78
self .assertEqual (video_clips .num_clips (), 0 + (10 - 6 + 1 ) + (15 - 6 + 1 ))
81
79
for i , v_idx , c_idx in [(0 , 1 , 0 ), (4 , 1 , 4 ), (5 , 2 , 0 ), (6 , 2 , 1 )]:
82
80
video_idx , clip_idx = video_clips .get_clip_location (i )
@@ -85,9 +83,8 @@ def test_video_clips(self):
85
83
86
84
@unittest .skip ("Moved to reference scripts for now" )
87
85
def test_video_sampler (self ):
88
- _backend = get_video_backend ()
89
86
with get_list_of_videos (num_videos = 3 , sizes = [25 , 25 , 25 ]) as video_list :
90
- video_clips = VideoClips (video_list , 5 , 5 , _backend = _backend )
87
+ video_clips = VideoClips (video_list , 5 , 5 )
91
88
sampler = RandomClipSampler (video_clips , 3 ) # noqa: F821
92
89
self .assertEqual (len (sampler ), 3 * 3 )
93
90
indices = torch .tensor (list (iter (sampler )))
@@ -98,9 +95,8 @@ def test_video_sampler(self):
98
95
99
96
@unittest .skip ("Moved to reference scripts for now" )
100
97
def test_video_sampler_unequal (self ):
101
- _backend = get_video_backend ()
102
98
with get_list_of_videos (num_videos = 3 , sizes = [10 , 25 , 25 ]) as video_list :
103
- video_clips = VideoClips (video_list , 5 , 5 , _backend = _backend )
99
+ video_clips = VideoClips (video_list , 5 , 5 )
104
100
sampler = RandomClipSampler (video_clips , 3 ) # noqa: F821
105
101
self .assertEqual (len (sampler ), 2 + 3 + 3 )
106
102
indices = list (iter (sampler ))
@@ -117,11 +113,10 @@ def test_video_sampler_unequal(self):
117
113
118
114
@unittest .skipIf (not io .video ._av_available (), "this test requires av" )
119
115
def test_video_clips_custom_fps (self ):
120
- _backend = get_video_backend ()
121
116
with get_list_of_videos (num_videos = 3 , sizes = [12 , 12 , 12 ], fps = [3 , 4 , 6 ]) as video_list :
122
117
num_frames = 4
123
118
for fps in [1 , 3 , 4 , 10 ]:
124
- video_clips = VideoClips (video_list , num_frames , num_frames , fps , _backend = _backend )
119
+ video_clips = VideoClips (video_list , num_frames , num_frames , fps )
125
120
for i in range (video_clips .num_clips ()):
126
121
video , audio , info , video_idx = video_clips .get_clip (i )
127
122
self .assertEqual (video .shape [0 ], num_frames )
0 commit comments