@@ -70,27 +70,23 @@ def write_video(filename, video_array, fps: Union[int, float], video_codec="libx
70
70
if isinstance (fps , float ):
71
71
fps = np .round (fps )
72
72
73
- container = av .open (filename , mode = "w" )
74
-
75
- stream = container .add_stream (video_codec , rate = fps )
76
- stream .width = video_array .shape [2 ]
77
- stream .height = video_array .shape [1 ]
78
- stream .pix_fmt = "yuv420p" if video_codec != "libx264rgb" else "rgb24"
79
- stream .options = options or {}
80
-
81
- for img in video_array :
82
- frame = av .VideoFrame .from_ndarray (img , format = "rgb24" )
83
- frame .pict_type = "NONE"
84
- for packet in stream .encode (frame ):
73
+ with av .open (filename , mode = "w" ) as container :
74
+ stream = container .add_stream (video_codec , rate = fps )
75
+ stream .width = video_array .shape [2 ]
76
+ stream .height = video_array .shape [1 ]
77
+ stream .pix_fmt = "yuv420p" if video_codec != "libx264rgb" else "rgb24"
78
+ stream .options = options or {}
79
+
80
+ for img in video_array :
81
+ frame = av .VideoFrame .from_ndarray (img , format = "rgb24" )
82
+ frame .pict_type = "NONE"
83
+ for packet in stream .encode (frame ):
84
+ container .mux (packet )
85
+
86
+ # Flush stream
87
+ for packet in stream .encode ():
85
88
container .mux (packet )
86
89
87
- # Flush stream
88
- for packet in stream .encode ():
89
- container .mux (packet )
90
-
91
- # Close the file
92
- container .close ()
93
-
94
90
95
91
def _read_from_stream (
96
92
container , start_offset , end_offset , pts_unit , stream , stream_name
@@ -234,37 +230,35 @@ def read_video(filename, start_pts=0, end_pts=None, pts_unit="pts"):
234
230
audio_frames = []
235
231
236
232
try :
237
- container = av .open (filename , metadata_errors = "ignore" )
233
+ with av .open (filename , metadata_errors = "ignore" ) as container :
234
+ if container .streams .video :
235
+ video_frames = _read_from_stream (
236
+ container ,
237
+ start_pts ,
238
+ end_pts ,
239
+ pts_unit ,
240
+ container .streams .video [0 ],
241
+ {"video" : 0 },
242
+ )
243
+ video_fps = container .streams .video [0 ].average_rate
244
+ # guard against potentially corrupted files
245
+ if video_fps is not None :
246
+ info ["video_fps" ] = float (video_fps )
247
+
248
+ if container .streams .audio :
249
+ audio_frames = _read_from_stream (
250
+ container ,
251
+ start_pts ,
252
+ end_pts ,
253
+ pts_unit ,
254
+ container .streams .audio [0 ],
255
+ {"audio" : 0 },
256
+ )
257
+ info ["audio_fps" ] = container .streams .audio [0 ].rate
258
+
238
259
except av .AVError :
239
260
# TODO raise a warning?
240
261
pass
241
- else :
242
- if container .streams .video :
243
- video_frames = _read_from_stream (
244
- container ,
245
- start_pts ,
246
- end_pts ,
247
- pts_unit ,
248
- container .streams .video [0 ],
249
- {"video" : 0 },
250
- )
251
- video_fps = container .streams .video [0 ].average_rate
252
- # guard against potentially corrupted files
253
- if video_fps is not None :
254
- info ["video_fps" ] = float (video_fps )
255
-
256
- if container .streams .audio :
257
- audio_frames = _read_from_stream (
258
- container ,
259
- start_pts ,
260
- end_pts ,
261
- pts_unit ,
262
- container .streams .audio [0 ],
263
- {"audio" : 0 },
264
- )
265
- info ["audio_fps" ] = container .streams .audio [0 ].rate
266
-
267
- container .close ()
268
262
269
263
vframes = [frame .to_rgb ().to_ndarray () for frame in video_frames ]
270
264
aframes = [frame .to_ndarray () for frame in audio_frames ]
@@ -293,6 +287,14 @@ def _can_read_timestamps_from_packets(container):
293
287
return False
294
288
295
289
290
+ def _decode_video_timestamps (container ):
291
+ if _can_read_timestamps_from_packets (container ):
292
+ # fast path
293
+ return [x .pts for x in container .demux (video = 0 ) if x .pts is not None ]
294
+ else :
295
+ return [x .pts for x in container .decode (video = 0 ) if x .pts is not None ]
296
+
297
+
296
298
def read_video_timestamps (filename , pts_unit = "pts" ):
297
299
"""
298
300
List the video frames timestamps.
@@ -326,26 +328,18 @@ def read_video_timestamps(filename, pts_unit="pts"):
326
328
pts = []
327
329
328
330
try :
329
- container = av .open (filename , metadata_errors = "ignore" )
331
+ with av .open (filename , metadata_errors = "ignore" ) as container :
332
+ if container .streams .video :
333
+ video_stream = container .streams .video [0 ]
334
+ video_time_base = video_stream .time_base
335
+ try :
336
+ pts = _decode_video_timestamps (container )
337
+ except av .AVError :
338
+ warnings .warn (f"Failed decoding frames for file { filename } " )
339
+ video_fps = float (video_stream .average_rate )
330
340
except av .AVError :
331
341
# TODO add a warning
332
342
pass
333
- else :
334
- if container .streams .video :
335
- video_stream = container .streams .video [0 ]
336
- video_time_base = video_stream .time_base
337
- try :
338
- if _can_read_timestamps_from_packets (container ):
339
- # fast path
340
- pts = [x .pts for x in container .demux (video = 0 ) if x .pts is not None ]
341
- else :
342
- pts = [
343
- x .pts for x in container .decode (video = 0 ) if x .pts is not None
344
- ]
345
- except av .AVError :
346
- warnings .warn (f"Failed decoding frames for file { filename } " )
347
- video_fps = float (video_stream .average_rate )
348
- container .close ()
349
343
350
344
pts .sort ()
351
345
0 commit comments