Skip to content

Commit a796e70

Browse files
committed
put back read_video on pyav backend
1 parent 27c4975 commit a796e70

File tree

3 files changed

+194
-44
lines changed

3 files changed

+194
-44
lines changed

torchvision/datasets/video_utils.py

Lines changed: 3 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from typing import Any, Callable, cast, Optional, TypeVar, Union
66

77
import torch
8-
from torchvision.io import _probe_video_from_file, _read_video_from_file, read_video, read_video_timestamps
8+
from torchvision.io import read_video, read_video_timestamps
99

1010
from .utils import tqdm
1111

@@ -305,11 +305,7 @@ def get_clip(self, idx: int) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any]
305305
video_path = self.video_paths[video_idx]
306306
clip_pts = self.clips[video_idx][clip_idx]
307307

308-
from torchvision import get_video_backend
309-
310-
backend = get_video_backend()
311-
312-
if backend == "pyav":
308+
if True:
313309
# check for invalid options
314310
if self._video_width != 0:
315311
raise ValueError("pyav backend doesn't support _video_width != 0")
@@ -322,43 +318,10 @@ def get_clip(self, idx: int) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any]
322318
if self._audio_samples != 0:
323319
raise ValueError("pyav backend doesn't support _audio_samples != 0")
324320

325-
if backend == "pyav":
321+
if True:
326322
start_pts = clip_pts[0].item()
327323
end_pts = clip_pts[-1].item()
328324
video, audio, info = read_video(video_path, start_pts, end_pts)
329-
else:
330-
_info = _probe_video_from_file(video_path)
331-
video_fps = _info.video_fps
332-
audio_fps = None
333-
334-
video_start_pts = cast(int, clip_pts[0].item())
335-
video_end_pts = cast(int, clip_pts[-1].item())
336-
337-
audio_start_pts, audio_end_pts = 0, -1
338-
audio_timebase = Fraction(0, 1)
339-
video_timebase = Fraction(_info.video_timebase.numerator, _info.video_timebase.denominator)
340-
if _info.has_audio:
341-
audio_timebase = Fraction(_info.audio_timebase.numerator, _info.audio_timebase.denominator)
342-
audio_start_pts = pts_convert(video_start_pts, video_timebase, audio_timebase, math.floor)
343-
audio_end_pts = pts_convert(video_end_pts, video_timebase, audio_timebase, math.ceil)
344-
audio_fps = _info.audio_sample_rate
345-
video, audio, _ = _read_video_from_file(
346-
video_path,
347-
video_width=self._video_width,
348-
video_height=self._video_height,
349-
video_min_dimension=self._video_min_dimension,
350-
video_max_dimension=self._video_max_dimension,
351-
video_pts_range=(video_start_pts, video_end_pts),
352-
video_timebase=video_timebase,
353-
audio_samples=self._audio_samples,
354-
audio_channels=self._audio_channels,
355-
audio_pts_range=(audio_start_pts, audio_end_pts),
356-
audio_timebase=audio_timebase,
357-
)
358-
359-
info = {"video_fps": video_fps}
360-
if audio_fps is not None:
361-
info["audio_fps"] = audio_fps
362325

363326
if self.frame_rate is not None:
364327
resampling_idx = self.resampling_idxs[video_idx][clip_idx]

torchvision/io/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,12 @@
1515
write_jpeg,
1616
write_png,
1717
)
18-
from .video import write_video
18+
from .video import write_video, read_video
1919

2020

2121
__all__ = [
2222
"write_video",
23+
"read_video",
2324
"ImageReadMode",
2425
"decode_image",
2526
"decode_jpeg",

torchvision/io/video.py

Lines changed: 189 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,6 @@ def _check_av_available() -> None:
4545
raise av
4646

4747

48-
49-
50-
5148
def write_video(
5249
filename: str,
5350
video_array: torch.Tensor,
@@ -168,3 +165,192 @@ def write_video(
168165
for packet in stream.encode():
169166
container.mux(packet)
170167

168+
169+
def read_video(
170+
filename: str,
171+
start_pts: Union[float, Fraction] = 0,
172+
end_pts: Optional[Union[float, Fraction]] = None,
173+
pts_unit: str = "pts",
174+
output_format: str = "THWC",
175+
) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any]]:
176+
"""[DEPRECATED] Reads a video from a file, returning both the video frames and the audio frames
177+
178+
.. warning::
179+
180+
DEPRECATED: All the video decoding and encoding capabilities of torchvision
181+
are deprecated from version 0.22 and will be removed in version 0.24. We
182+
recommend that you migrate to
183+
`TorchCodec <https://github.com/pytorch/torchcodec>`__, where we'll
184+
consolidate the future decoding/encoding capabilities of PyTorch
185+
186+
Args:
187+
filename (str): path to the video file. If using the pyav backend, this can be whatever ``av.open`` accepts.
188+
start_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
189+
The start presentation time of the video
190+
end_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
191+
The end presentation time
192+
pts_unit (str, optional): unit in which start_pts and end_pts values will be interpreted,
193+
either 'pts' or 'sec'. Defaults to 'pts'.
194+
output_format (str, optional): The format of the output video tensors. Can be either "THWC" (default) or "TCHW".
195+
196+
Returns:
197+
vframes (Tensor[T, H, W, C] or Tensor[T, C, H, W]): the `T` video frames
198+
aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels and `L` is the number of points
199+
info (Dict): metadata for the video and audio. Can contain the fields video_fps (float) and audio_fps (int)
200+
"""
201+
_raise_video_deprecation_warning()
202+
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
203+
_log_api_usage_once(read_video)
204+
205+
output_format = output_format.upper()
206+
if output_format not in ("THWC", "TCHW"):
207+
raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.")
208+
209+
from torchvision import get_video_backend
210+
211+
if True: # ignore, this is to avoid a bigger diff in https://github.com/pytorch/vision/pull/9189
212+
_check_av_available()
213+
214+
if end_pts is None:
215+
end_pts = float("inf")
216+
217+
if end_pts < start_pts:
218+
raise ValueError(
219+
f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}"
220+
)
221+
222+
info = {}
223+
video_frames = []
224+
audio_frames = []
225+
audio_timebase = _video_opt.default_timebase
226+
227+
try:
228+
with av.open(filename, metadata_errors="ignore") as container:
229+
if container.streams.audio:
230+
audio_timebase = container.streams.audio[0].time_base
231+
if container.streams.video:
232+
video_frames = _read_from_stream(
233+
container,
234+
start_pts,
235+
end_pts,
236+
pts_unit,
237+
container.streams.video[0],
238+
{"video": 0},
239+
)
240+
video_fps = container.streams.video[0].average_rate
241+
# guard against potentially corrupted files
242+
if video_fps is not None:
243+
info["video_fps"] = float(video_fps)
244+
245+
if container.streams.audio:
246+
audio_frames = _read_from_stream(
247+
container,
248+
start_pts,
249+
end_pts,
250+
pts_unit,
251+
container.streams.audio[0],
252+
{"audio": 0},
253+
)
254+
info["audio_fps"] = container.streams.audio[0].rate
255+
256+
except FFmpegError:
257+
# TODO raise a warning?
258+
pass
259+
260+
vframes_list = [frame.to_rgb().to_ndarray() for frame in video_frames]
261+
aframes_list = [frame.to_ndarray() for frame in audio_frames]
262+
263+
if vframes_list:
264+
vframes = torch.as_tensor(np.stack(vframes_list))
265+
else:
266+
vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)
267+
268+
if aframes_list:
269+
aframes = np.concatenate(aframes_list, 1)
270+
aframes = torch.as_tensor(aframes)
271+
if pts_unit == "sec":
272+
start_pts = int(math.floor(start_pts * (1 / audio_timebase)))
273+
if end_pts != float("inf"):
274+
end_pts = int(math.ceil(end_pts * (1 / audio_timebase)))
275+
aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts)
276+
else:
277+
aframes = torch.empty((1, 0), dtype=torch.float32)
278+
279+
if output_format == "TCHW":
280+
# [T,H,W,C] --> [T,C,H,W]
281+
vframes = vframes.permute(0, 3, 1, 2)
282+
283+
return vframes, aframes, info
284+
285+
286+
def _can_read_timestamps_from_packets(container: "av.container.Container") -> bool:
287+
extradata = container.streams[0].codec_context.extradata
288+
if extradata is None:
289+
return False
290+
if b"Lavc" in extradata:
291+
return True
292+
return False
293+
294+
295+
def _decode_video_timestamps(container: "av.container.Container") -> list[int]:
296+
if _can_read_timestamps_from_packets(container):
297+
# fast path
298+
return [x.pts for x in container.demux(video=0) if x.pts is not None]
299+
else:
300+
return [x.pts for x in container.decode(video=0) if x.pts is not None]
301+
302+
303+
def read_video_timestamps(filename: str, pts_unit: str = "pts") -> tuple[list[int], Optional[float]]:
304+
"""[DEPREACTED] List the video frames timestamps.
305+
306+
.. warning::
307+
308+
DEPRECATED: All the video decoding and encoding capabilities of torchvision
309+
are deprecated from version 0.22 and will be removed in version 0.25. We
310+
recommend that you migrate to
311+
`TorchCodec <https://github.com/pytorch/torchcodec>`__, where we'll
312+
consolidate the future decoding/encoding capabilities of PyTorch
313+
314+
Note that the function decodes the whole video frame-by-frame.
315+
316+
Args:
317+
filename (str): path to the video file
318+
pts_unit (str, optional): unit in which timestamp values will be returned
319+
either 'pts' or 'sec'. Defaults to 'pts'.
320+
321+
Returns:
322+
pts (List[int] if pts_unit = 'pts', List[Fraction] if pts_unit = 'sec'):
323+
presentation timestamps for each one of the frames in the video.
324+
video_fps (float, optional): the frame rate for the video
325+
326+
"""
327+
_raise_video_deprecation_warning()
328+
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
329+
_log_api_usage_once(read_video_timestamps)
330+
from torchvision import get_video_backend
331+
332+
_check_av_available()
333+
334+
video_fps = None
335+
pts = []
336+
337+
try:
338+
with av.open(filename, metadata_errors="ignore") as container:
339+
if container.streams.video:
340+
video_stream = container.streams.video[0]
341+
video_time_base = video_stream.time_base
342+
try:
343+
pts = _decode_video_timestamps(container)
344+
except FFmpegError:
345+
warnings.warn(f"Failed decoding frames for file {filename}")
346+
video_fps = float(video_stream.average_rate)
347+
except FFmpegError as e:
348+
msg = f"Failed to open container for {filename}; Caught error: {e}"
349+
warnings.warn(msg, RuntimeWarning)
350+
351+
pts.sort()
352+
353+
if pts_unit == "sec":
354+
pts = [x * video_time_base for x in pts]
355+
356+
return pts, video_fps

0 commit comments

Comments
 (0)