@@ -45,9 +45,6 @@ def _check_av_available() -> None:
4545 raise av
4646
4747
48-
49-
50-
5148def write_video (
5249 filename : str ,
5350 video_array : torch .Tensor ,
@@ -168,3 +165,192 @@ def write_video(
168165 for packet in stream .encode ():
169166 container .mux (packet )
170167
168+
def read_video(
    filename: str,
    start_pts: Union[float, Fraction] = 0,
    end_pts: Optional[Union[float, Fraction]] = None,
    pts_unit: str = "pts",
    output_format: str = "THWC",
) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any]]:
    """[DEPRECATED] Reads a video from a file, returning both the video frames and the audio frames

    .. warning::

        DEPRECATED: All the video decoding and encoding capabilities of torchvision
        are deprecated from version 0.22 and will be removed in version 0.24. We
        recommend that you migrate to
        `TorchCodec <https://github.com/pytorch/torchcodec>`__, where we'll
        consolidate the future decoding/encoding capabilities of PyTorch

    Decoding is best-effort: if the container cannot be opened or decoding
    fails part-way, a ``RuntimeWarning`` is emitted and whatever was read so
    far (possibly nothing) is returned.

    Args:
        filename (str): path to the video file. If using the pyav backend, this can be whatever ``av.open`` accepts.
        start_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
            The start presentation time of the video
        end_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
            The end presentation time. ``None`` (default) means "until the end".
        pts_unit (str, optional): unit in which start_pts and end_pts values will be interpreted,
            either 'pts' or 'sec'. Defaults to 'pts'.
        output_format (str, optional): The format of the output video tensors. Can be either "THWC" (default) or "TCHW".
            Matching is case-insensitive.

    Returns:
        vframes (Tensor[T, H, W, C] or Tensor[T, C, H, W]): the `T` video frames
        aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels and `L` is the number of points
        info (Dict): metadata for the video and audio. Can contain the fields video_fps (float) and audio_fps (int)

    Raises:
        ValueError: if ``output_format`` is neither "THWC" nor "TCHW", or if
            ``end_pts`` is smaller than ``start_pts``.
    """
    _raise_video_deprecation_warning()
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(read_video)

    output_format = output_format.upper()
    if output_format not in ("THWC", "TCHW"):
        raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.")

    _check_av_available()

    if end_pts is None:
        end_pts = float("inf")

    if end_pts < start_pts:
        raise ValueError(
            f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}"
        )

    info = {}
    video_frames = []
    audio_frames = []
    audio_timebase = _video_opt.default_timebase

    try:
        with av.open(filename, metadata_errors="ignore") as container:
            # Remember the audio time base up front: it is needed after the
            # container is closed to convert seconds into audio pts.
            if container.streams.audio:
                audio_timebase = container.streams.audio[0].time_base
            if container.streams.video:
                video_frames = _read_from_stream(
                    container,
                    start_pts,
                    end_pts,
                    pts_unit,
                    container.streams.video[0],
                    {"video": 0},
                )
                video_fps = container.streams.video[0].average_rate
                # guard against potentially corrupted files
                if video_fps is not None:
                    info["video_fps"] = float(video_fps)

            if container.streams.audio:
                audio_frames = _read_from_stream(
                    container,
                    start_pts,
                    end_pts,
                    pts_unit,
                    container.streams.audio[0],
                    {"audio": 0},
                )
                info["audio_fps"] = container.streams.audio[0].rate

    except FFmpegError as e:
        # Stay best-effort (return what was decoded before the failure), but
        # tell the caller instead of failing silently.
        warnings.warn(f"Failed to read video from {filename}; Caught error: {e}", RuntimeWarning)

    vframes_list = [frame.to_rgb().to_ndarray() for frame in video_frames]
    aframes_list = [frame.to_ndarray() for frame in audio_frames]

    if vframes_list:
        vframes = torch.as_tensor(np.stack(vframes_list))
    else:
        # Empty placeholder that still has 4 dims so the permute below works.
        vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)

    if aframes_list:
        # Each audio frame contributes a chunk along dim 1.
        aframes = np.concatenate(aframes_list, 1)
        aframes = torch.as_tensor(aframes)
        if pts_unit == "sec":
            # Convert the requested window from seconds to integer audio pts
            # before handing it to _align_audio_frames.
            start_pts = int(math.floor(start_pts * (1 / audio_timebase)))
            if end_pts != float("inf"):
                end_pts = int(math.ceil(end_pts * (1 / audio_timebase)))
        aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts)
    else:
        aframes = torch.empty((1, 0), dtype=torch.float32)

    if output_format == "TCHW":
        # [T,H,W,C] --> [T,C,H,W]
        vframes = vframes.permute(0, 3, 1, 2)

    return vframes, aframes, info
284+
285+
def _can_read_timestamps_from_packets(container: "av.container.Container") -> bool:
    """Return True when the first stream's codec extradata contains ``b"Lavc"``.

    Returns False when the codec context exposes no extradata at all.
    """
    codec_extradata = container.streams[0].codec_context.extradata
    return codec_extradata is not None and b"Lavc" in codec_extradata
293+
294+
def _decode_video_timestamps(container: "av.container.Container") -> list[int]:
    """Collect the non-None ``pts`` values of video stream 0.

    Uses ``container.demux`` (the fast path) when
    ``_can_read_timestamps_from_packets`` approves the container, and
    ``container.decode`` otherwise.
    """
    if _can_read_timestamps_from_packets(container):
        # fast path: iterate demuxed items instead of decoding
        timestamped_items = container.demux(video=0)
    else:
        timestamped_items = container.decode(video=0)
    return [item.pts for item in timestamped_items if item.pts is not None]
301+
302+
def read_video_timestamps(filename: str, pts_unit: str = "pts") -> tuple[list[int], Optional[float]]:
    """[DEPRECATED] List the video frames timestamps.

    .. warning::

        DEPRECATED: All the video decoding and encoding capabilities of torchvision
        are deprecated from version 0.22 and will be removed in version 0.25. We
        recommend that you migrate to
        `TorchCodec <https://github.com/pytorch/torchcodec>`__, where we'll
        consolidate the future decoding/encoding capabilities of PyTorch

    Note that the function decodes the whole video frame-by-frame.

    Failures are reported as warnings rather than exceptions: if the container
    cannot be opened or the frames cannot be decoded, an empty timestamp list
    is returned.

    Args:
        filename (str): path to the video file
        pts_unit (str, optional): unit in which timestamp values will be returned
            either 'pts' or 'sec'. Defaults to 'pts'.

    Returns:
        pts (List[int] if pts_unit = 'pts', List[Fraction] if pts_unit = 'sec'):
            presentation timestamps for each one of the frames in the video.
        video_fps (float, optional): the frame rate for the video. ``None`` when
            there is no video stream, the container could not be opened, or the
            stream reports no average rate.

    """
    _raise_video_deprecation_warning()
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(read_video_timestamps)

    _check_av_available()

    video_fps = None
    video_time_base = None
    pts = []

    try:
        with av.open(filename, metadata_errors="ignore") as container:
            if container.streams.video:
                video_stream = container.streams.video[0]
                video_time_base = video_stream.time_base
                try:
                    pts = _decode_video_timestamps(container)
                except FFmpegError:
                    warnings.warn(f"Failed decoding frames for file {filename}")
                # guard against potentially corrupted files: float(None)
                # would raise TypeError (read_video applies the same guard)
                average_rate = video_stream.average_rate
                if average_rate is not None:
                    video_fps = float(average_rate)
    except FFmpegError as e:
        msg = f"Failed to open container for {filename}; Caught error: {e}"
        warnings.warn(msg, RuntimeWarning)

    pts.sort()

    if pts_unit == "sec":
        # pts is only non-empty when a video stream was found, in which case
        # video_time_base has been set above.
        pts = [x * video_time_base for x in pts]

    return pts, video_fps
0 commit comments