Commit 8762598

fix bug in output format for pyav (#6672) (#6703)
* fix bug in output format for pyav
* add read from memory with constructor overload
* Revert "add read from memory with constructor overload"
  This reverts commit 14cbbab.
* run ufmt
1 parent dc6d86d commit 8762598
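
In short, the hunk below replaces the early return taken for non-PyAV backends with an else branch, so both backends now fall through to the shared output_format handling at the end of read_video. A minimal control-flow sketch of the change (simplified, not the actual torchvision implementation; the helper names are hypothetical):

# Sketch only: illustrates the control-flow change made by this commit.
def read_video_sketch(filename, backend, output_format="THWC"):
    if backend != "pyav":
        # before the fix this branch returned immediately, skipping the permute below
        vframes, aframes, info = read_with_video_reader_backend(filename)  # hypothetical helper
    else:
        vframes, aframes, info = read_with_pyav_backend(filename)  # hypothetical helper

    if output_format == "TCHW":
        vframes = vframes.permute(0, 3, 1, 2)  # [T,H,W,C] -> [T,C,H,W]
    return vframes, aframes, info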

File tree

1 file changed (+67 / -65 lines)

torchvision/io/video.py

Lines changed: 67 additions & 65 deletions
@@ -273,72 +273,74 @@ def read_video(
     raise RuntimeError(f"File not found: {filename}")
 
     if get_video_backend() != "pyav":
-        return _video_opt._read_video(filename, start_pts, end_pts, pts_unit)
-
-    _check_av_available()
-
-    if end_pts is None:
-        end_pts = float("inf")
-
-    if end_pts < start_pts:
-        raise ValueError(f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}")
-
-    info = {}
-    video_frames = []
-    audio_frames = []
-    audio_timebase = _video_opt.default_timebase
-
-    try:
-        with av.open(filename, metadata_errors="ignore") as container:
-            if container.streams.audio:
-                audio_timebase = container.streams.audio[0].time_base
-            if container.streams.video:
-                video_frames = _read_from_stream(
-                    container,
-                    start_pts,
-                    end_pts,
-                    pts_unit,
-                    container.streams.video[0],
-                    {"video": 0},
-                )
-                video_fps = container.streams.video[0].average_rate
-                # guard against potentially corrupted files
-                if video_fps is not None:
-                    info["video_fps"] = float(video_fps)
-
-            if container.streams.audio:
-                audio_frames = _read_from_stream(
-                    container,
-                    start_pts,
-                    end_pts,
-                    pts_unit,
-                    container.streams.audio[0],
-                    {"audio": 0},
-                )
-                info["audio_fps"] = container.streams.audio[0].rate
-
-    except av.AVError:
-        # TODO raise a warning?
-        pass
-
-    vframes_list = [frame.to_rgb().to_ndarray() for frame in video_frames]
-    aframes_list = [frame.to_ndarray() for frame in audio_frames]
-
-    if vframes_list:
-        vframes = torch.as_tensor(np.stack(vframes_list))
-    else:
-        vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)
-
-    if aframes_list:
-        aframes = np.concatenate(aframes_list, 1)
-        aframes = torch.as_tensor(aframes)
-        if pts_unit == "sec":
-            start_pts = int(math.floor(start_pts * (1 / audio_timebase)))
-            if end_pts != float("inf"):
-                end_pts = int(math.ceil(end_pts * (1 / audio_timebase)))
-        aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts)
+        vframes, aframes, info = _video_opt._read_video(filename, start_pts, end_pts, pts_unit)
     else:
-        aframes = torch.empty((1, 0), dtype=torch.float32)
+        _check_av_available()
+
+        if end_pts is None:
+            end_pts = float("inf")
+
+        if end_pts < start_pts:
+            raise ValueError(
+                f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}"
+            )
+
+        info = {}
+        video_frames = []
+        audio_frames = []
+        audio_timebase = _video_opt.default_timebase
+
+        try:
+            with av.open(filename, metadata_errors="ignore") as container:
+                if container.streams.audio:
+                    audio_timebase = container.streams.audio[0].time_base
+                if container.streams.video:
+                    video_frames = _read_from_stream(
+                        container,
+                        start_pts,
+                        end_pts,
+                        pts_unit,
+                        container.streams.video[0],
+                        {"video": 0},
+                    )
+                    video_fps = container.streams.video[0].average_rate
+                    # guard against potentially corrupted files
+                    if video_fps is not None:
+                        info["video_fps"] = float(video_fps)
+
+                if container.streams.audio:
+                    audio_frames = _read_from_stream(
+                        container,
+                        start_pts,
+                        end_pts,
+                        pts_unit,
+                        container.streams.audio[0],
+                        {"audio": 0},
+                    )
+                    info["audio_fps"] = container.streams.audio[0].rate
+
+        except av.AVError:
+            # TODO raise a warning?
+            pass
+
+        vframes_list = [frame.to_rgb().to_ndarray() for frame in video_frames]
+        aframes_list = [frame.to_ndarray() for frame in audio_frames]
+
+        if vframes_list:
+            vframes = torch.as_tensor(np.stack(vframes_list))
+        else:
+            vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)
+
+        if aframes_list:
+            aframes = np.concatenate(aframes_list, 1)
+            aframes = torch.as_tensor(aframes)
+            if pts_unit == "sec":
+                start_pts = int(math.floor(start_pts * (1 / audio_timebase)))
+                if end_pts != float("inf"):
+                    end_pts = int(math.ceil(end_pts * (1 / audio_timebase)))
+            aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts)
+        else:
+            aframes = torch.empty((1, 0), dtype=torch.float32)
 
     if output_format == "TCHW":
         # [T,H,W,C] --> [T,C,H,W]
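
After this change, a call like the following should return frames in the requested layout regardless of the configured backend. This is a hedged usage sketch only; it assumes a local file named video.mp4 and a torchvision build where the chosen backend is available.

import torchvision
from torchvision.io import read_video

torchvision.set_video_backend("pyav")  # or "video_reader" if your build includes it
vframes, aframes, info = read_video("video.mp4", pts_unit="sec", output_format="TCHW")
print(vframes.shape)  # expected [T, C, H, W] when output_format="TCHW"
print(info)           # e.g. video_fps / audio_fps entries when those streams are present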
