55from torchvision .io import (
66 _read_video_timestamps_from_file ,
77 _read_video_from_file ,
8+ _probe_video_from_file
89)
910from torchvision .io import read_video_timestamps , read_video
1011
@@ -71,11 +72,11 @@ def __init__(self, video_paths, clip_length_in_frames=16, frames_between_clips=1
7172 frame_rate = None , _precomputed_metadata = None , num_workers = 0 ,
7273 _video_width = 0 , _video_height = 0 , _video_min_dimension = 0 ,
7374 _audio_samples = 0 ):
74- from torchvision import get_video_backend
7575
7676 self .video_paths = video_paths
7777 self .num_workers = num_workers
78- self ._backend = get_video_backend ()
78+
79+ # these options are not supported by the pyav backend: get_clip raises ValueError if any of them is non-zero while pyav is the active backend
7980 self ._video_width = _video_width
8081 self ._video_height = _video_height
8182 self ._video_min_dimension = _video_min_dimension
@@ -89,87 +90,60 @@ def __init__(self, video_paths, clip_length_in_frames=16, frames_between_clips=1
8990
9091 def _compute_frame_pts (self ):
9192 self .video_pts = []
92- if self ._backend == "pyav" :
93- self .video_fps = []
94- else :
95- self .info = []
93+ self .video_fps = []
9694
9795 # strategy: use a DataLoader to parallelize read_video_timestamps
9896 # so need to create a dummy dataset first
9997 class DS (object ):
100- def __init__ (self , x , _backend ):
98+ def __init__ (self , x ):
10199 self .x = x
102- self ._backend = _backend
103100
104101 def __len__ (self ):
105102 return len (self .x )
106103
107104 def __getitem__ (self , idx ):
108- if self ._backend == "pyav" :
109- return read_video_timestamps (self .x [idx ])
110- else :
111- return _read_video_timestamps_from_file (self .x [idx ])
105+ return read_video_timestamps (self .x [idx ])
112106
113107 import torch .utils .data
114108 dl = torch .utils .data .DataLoader (
115- DS (self .video_paths , self . _backend ),
109+ DS (self .video_paths ),
116110 batch_size = 16 ,
117111 num_workers = self .num_workers ,
118112 collate_fn = lambda x : x )
119113
120114 with tqdm (total = len (dl )) as pbar :
121115 for batch in dl :
122116 pbar .update (1 )
123- if self ._backend == "pyav" :
124- clips , fps = list (zip (* batch ))
125- clips = [torch .as_tensor (c ) for c in clips ]
126- self .video_pts .extend (clips )
127- self .video_fps .extend (fps )
128- else :
129- video_pts , _audio_pts , info = list (zip (* batch ))
130- video_pts = [torch .as_tensor (c ) for c in video_pts ]
131- self .video_pts .extend (video_pts )
132- self .info .extend (info )
117+ clips , fps = list (zip (* batch ))
118+ clips = [torch .as_tensor (c ) for c in clips ]
119+ self .video_pts .extend (clips )
120+ self .video_fps .extend (fps )
133121
134122 def _init_from_metadata (self , metadata ):
135123 self .video_paths = metadata ["video_paths" ]
136124 assert len (self .video_paths ) == len (metadata ["video_pts" ])
137125 self .video_pts = metadata ["video_pts" ]
138-
139- if self ._backend == "pyav" :
140- assert len (self .video_paths ) == len (metadata ["video_fps" ])
141- self .video_fps = metadata ["video_fps" ]
142- else :
143- assert len (self .video_paths ) == len (metadata ["info" ])
144- self .info = metadata ["info" ]
126+ assert len (self .video_paths ) == len (metadata ["video_fps" ])
127+ self .video_fps = metadata ["video_fps" ]
145128
146129 @property
147130 def metadata (self ):
148131 _metadata = {
149132 "video_paths" : self .video_paths ,
150133 "video_pts" : self .video_pts ,
134+ "video_fps" : self .video_fps
151135 }
152- if self ._backend == "pyav" :
153- _metadata .update ({"video_fps" : self .video_fps })
154- else :
155- _metadata .update ({"info" : self .info })
156136 return _metadata
157137
158138 def subset (self , indices ):
159139 video_paths = [self .video_paths [i ] for i in indices ]
160140 video_pts = [self .video_pts [i ] for i in indices ]
161- if self ._backend == "pyav" :
162- video_fps = [self .video_fps [i ] for i in indices ]
163- else :
164- info = [self .info [i ] for i in indices ]
141+ video_fps = [self .video_fps [i ] for i in indices ]
165142 metadata = {
166143 "video_paths" : video_paths ,
167144 "video_pts" : video_pts ,
145+ "video_fps" : video_fps
168146 }
169- if self ._backend == "pyav" :
170- metadata .update ({"video_fps" : video_fps })
171- else :
172- metadata .update ({"info" : info })
173147 return type (self )(video_paths , self .num_frames , self .step , self .frame_rate ,
174148 _precomputed_metadata = metadata , num_workers = self .num_workers ,
175149 _video_width = self ._video_width ,
@@ -212,22 +186,10 @@ def compute_clips(self, num_frames, step, frame_rate=None):
212186 self .frame_rate = frame_rate
213187 self .clips = []
214188 self .resampling_idxs = []
215- if self ._backend == "pyav" :
216- for video_pts , fps in zip (self .video_pts , self .video_fps ):
217- clips , idxs = self .compute_clips_for_video (video_pts , num_frames , step , fps , frame_rate )
218- self .clips .append (clips )
219- self .resampling_idxs .append (idxs )
220- else :
221- for video_pts , info in zip (self .video_pts , self .info ):
222- if "video_fps" in info :
223- clips , idxs = self .compute_clips_for_video (
224- video_pts , num_frames , step , info ["video_fps" ], frame_rate )
225- self .clips .append (clips )
226- self .resampling_idxs .append (idxs )
227- else :
228- # properly handle the cases where video decoding fails
229- self .clips .append (torch .zeros (0 , num_frames , dtype = torch .int64 ))
230- self .resampling_idxs .append (torch .zeros (0 , dtype = torch .int64 ))
189+ for video_pts , fps in zip (self .video_pts , self .video_fps ):
190+ clips , idxs = self .compute_clips_for_video (video_pts , num_frames , step , fps , frame_rate )
191+ self .clips .append (clips )
192+ self .resampling_idxs .append (idxs )
231193 clip_lengths = torch .as_tensor ([len (v ) for v in self .clips ])
232194 self .cumulative_sizes = clip_lengths .cumsum (0 ).tolist ()
233195
@@ -287,12 +249,28 @@ def get_clip(self, idx):
287249 video_path = self .video_paths [video_idx ]
288250 clip_pts = self .clips [video_idx ][clip_idx ]
289251
290- if self ._backend == "pyav" :
252+ from torchvision import get_video_backend
253+ backend = get_video_backend ()
254+
255+ if backend == "pyav" :
256+ # check for invalid options
257+ if self ._video_width != 0 :
258+ raise ValueError ("pyav backend doesn't support _video_width != 0" )
259+ if self ._video_height != 0 :
260+ raise ValueError ("pyav backend doesn't support _video_height != 0" )
261+ if self ._video_min_dimension != 0 :
262+ raise ValueError ("pyav backend doesn't support _video_min_dimension != 0" )
263+ if self ._audio_samples != 0 :
264+ raise ValueError ("pyav backend doesn't support _audio_samples != 0" )
265+
266+ if backend == "pyav" :
291267 start_pts = clip_pts [0 ].item ()
292268 end_pts = clip_pts [- 1 ].item ()
293269 video , audio , info = read_video (video_path , start_pts , end_pts )
294270 else :
295- info = self .info [video_idx ]
271+ info = _probe_video_from_file (video_path )
272+ video_fps = info ["video_fps" ]
273+ audio_fps = None
296274
297275 video_start_pts = clip_pts [0 ].item ()
298276 video_end_pts = clip_pts [- 1 ].item ()
@@ -313,6 +291,7 @@ def get_clip(self, idx):
313291 info ["audio_timebase" ],
314292 math .ceil ,
315293 )
294+ audio_fps = info ["audio_sample_rate" ]
316295 video , audio , info = _read_video_from_file (
317296 video_path ,
318297 video_width = self ._video_width ,
@@ -324,6 +303,11 @@ def get_clip(self, idx):
324303 audio_pts_range = (audio_start_pts , audio_end_pts ),
325304 audio_timebase = audio_timebase ,
326305 )
306+
307+ info = {"video_fps" : video_fps }
308+ if audio_fps is not None :
309+ info ["audio_fps" ] = audio_fps
310+
327311 if self .frame_rate is not None :
328312 resampling_idx = self .resampling_idxs [video_idx ][clip_idx ]
329313 if isinstance (resampling_idx , torch .Tensor ):
0 commit comments