@@ -273,72 +273,74 @@ def read_video(
         raise RuntimeError(f"File not found: {filename}")
 
     if get_video_backend() != "pyav":
-        return _video_opt._read_video(filename, start_pts, end_pts, pts_unit)
-
-    _check_av_available()
-
-    if end_pts is None:
-        end_pts = float("inf")
-
-    if end_pts < start_pts:
-        raise ValueError(f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}")
-
-    info = {}
-    video_frames = []
-    audio_frames = []
-    audio_timebase = _video_opt.default_timebase
-
-    try:
-        with av.open(filename, metadata_errors="ignore") as container:
-            if container.streams.audio:
-                audio_timebase = container.streams.audio[0].time_base
-            if container.streams.video:
-                video_frames = _read_from_stream(
-                    container,
-                    start_pts,
-                    end_pts,
-                    pts_unit,
-                    container.streams.video[0],
-                    {"video": 0},
-                )
-                video_fps = container.streams.video[0].average_rate
-                # guard against potentially corrupted files
-                if video_fps is not None:
-                    info["video_fps"] = float(video_fps)
-
-            if container.streams.audio:
-                audio_frames = _read_from_stream(
-                    container,
-                    start_pts,
-                    end_pts,
-                    pts_unit,
-                    container.streams.audio[0],
-                    {"audio": 0},
-                )
-                info["audio_fps"] = container.streams.audio[0].rate
-
-    except av.AVError:
-        # TODO raise a warning?
-        pass
-
-    vframes_list = [frame.to_rgb().to_ndarray() for frame in video_frames]
-    aframes_list = [frame.to_ndarray() for frame in audio_frames]
-
-    if vframes_list:
-        vframes = torch.as_tensor(np.stack(vframes_list))
-    else:
-        vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)
-
-    if aframes_list:
-        aframes = np.concatenate(aframes_list, 1)
-        aframes = torch.as_tensor(aframes)
-        if pts_unit == "sec":
-            start_pts = int(math.floor(start_pts * (1 / audio_timebase)))
-            if end_pts != float("inf"):
-                end_pts = int(math.ceil(end_pts * (1 / audio_timebase)))
-        aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts)
+        vframes, aframes, info = _video_opt._read_video(filename, start_pts, end_pts, pts_unit)
     else:
-        aframes = torch.empty((1, 0), dtype=torch.float32)
+        _check_av_available()
+
+        if end_pts is None:
+            end_pts = float("inf")
+
+        if end_pts < start_pts:
+            raise ValueError(
+                f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}"
+            )
+
+        info = {}
+        video_frames = []
+        audio_frames = []
+        audio_timebase = _video_opt.default_timebase
+
+        try:
+            with av.open(filename, metadata_errors="ignore") as container:
+                if container.streams.audio:
+                    audio_timebase = container.streams.audio[0].time_base
+                if container.streams.video:
+                    video_frames = _read_from_stream(
+                        container,
+                        start_pts,
+                        end_pts,
+                        pts_unit,
+                        container.streams.video[0],
+                        {"video": 0},
+                    )
+                    video_fps = container.streams.video[0].average_rate
+                    # guard against potentially corrupted files
+                    if video_fps is not None:
+                        info["video_fps"] = float(video_fps)
+
+                if container.streams.audio:
+                    audio_frames = _read_from_stream(
+                        container,
+                        start_pts,
+                        end_pts,
+                        pts_unit,
+                        container.streams.audio[0],
+                        {"audio": 0},
+                    )
+                    info["audio_fps"] = container.streams.audio[0].rate
+
+        except av.AVError:
+            # TODO raise a warning?
+            pass
+
+        vframes_list = [frame.to_rgb().to_ndarray() for frame in video_frames]
+        aframes_list = [frame.to_ndarray() for frame in audio_frames]
+
+        if vframes_list:
+            vframes = torch.as_tensor(np.stack(vframes_list))
+        else:
+            vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)
+
+        if aframes_list:
+            aframes = np.concatenate(aframes_list, 1)
+            aframes = torch.as_tensor(aframes)
+            if pts_unit == "sec":
+                start_pts = int(math.floor(start_pts * (1 / audio_timebase)))
+                if end_pts != float("inf"):
+                    end_pts = int(math.ceil(end_pts * (1 / audio_timebase)))
+            aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts)
+        else:
+            aframes = torch.empty((1, 0), dtype=torch.float32)
 
     if output_format == "TCHW":
         # [T,H,W,C] --> [T,C,H,W]
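For readers following the change: after this hunk, both branches (the non-pyav backend and the PyAV path) leave vframes, aframes, and info defined, so the shared output_format conversion below applies to either backend. A minimal usage sketch of the public entry point, assuming a torchvision build whose read_video accepts output_format and using a hypothetical input file:

import torchvision
from torchvision.io import read_video

# Select the decoding backend; "pyav" exercises the else-branch above,
# while "video_reader" (if compiled in) takes the _video_opt path.
torchvision.set_video_backend("pyav")

# "example.mp4" is a placeholder path; pts_unit="sec" interprets the
# start/end bounds as seconds rather than stream timestamps.
vframes, aframes, info = read_video(
    "example.mp4",
    start_pts=0,
    end_pts=2.0,
    pts_unit="sec",
    output_format="TCHW",  # request [T, C, H, W] instead of the default [T, H, W, C]
)

print(vframes.shape)  # uint8 video tensor, e.g. torch.Size([T, 3, H, W])
print(aframes.shape)  # audio tensor [channels, samples] (empty if the file has no audio)
print(info)           # e.g. {"video_fps": 29.97, "audio_fps": 44100}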