Skip to content

Commit ee91b35

Browse files
committed
Put back more stuff (fix)
1 parent e2eade2 commit ee91b35

File tree

1 file changed

+84
-0
lines changed

1 file changed

+84
-0
lines changed

torchvision/io/video.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,11 @@ def _check_av_available() -> None:
4747
def _av_available() -> bool:
4848
return not isinstance(av, Exception)
4949

50+
# PyAV has some reference cycles
51+
_CALLED_TIMES = 0
52+
_GC_COLLECTION_INTERVAL = 10
53+
54+
5055
def write_video(
5156
filename: str,
5257
video_array: torch.Tensor,
@@ -167,6 +172,85 @@ def write_video(
167172
for packet in stream.encode():
168173
container.mux(packet)
169174

175+
def _read_from_stream(
176+
container: "av.container.Container",
177+
start_offset: float,
178+
end_offset: float,
179+
pts_unit: str,
180+
stream: "av.stream.Stream",
181+
stream_name: dict[str, Optional[Union[int, tuple[int, ...], list[int]]]],
182+
) -> list["av.frame.Frame"]:
183+
global _CALLED_TIMES, _GC_COLLECTION_INTERVAL
184+
_CALLED_TIMES += 1
185+
if _CALLED_TIMES % _GC_COLLECTION_INTERVAL == _GC_COLLECTION_INTERVAL - 1:
186+
gc.collect()
187+
188+
if pts_unit == "sec":
189+
# TODO: we should change all of this from ground up to simply take
190+
# sec and convert to MS in C++
191+
start_offset = int(math.floor(start_offset * (1 / stream.time_base)))
192+
if end_offset != float("inf"):
193+
end_offset = int(math.ceil(end_offset * (1 / stream.time_base)))
194+
else:
195+
warnings.warn("The pts_unit 'pts' gives wrong results. Please use pts_unit 'sec'.")
196+
197+
frames = {}
198+
should_buffer = True
199+
max_buffer_size = 5
200+
if stream.type == "video":
201+
# DivX-style packed B-frames can have out-of-order pts (2 frames in a single pkt)
202+
# so need to buffer some extra frames to sort everything
203+
# properly
204+
extradata = stream.codec_context.extradata
205+
# overly complicated way of finding if `divx_packed` is set, following
206+
# https://github.com/FFmpeg/FFmpeg/commit/d5a21172283572af587b3d939eba0091484d3263
207+
if extradata and b"DivX" in extradata:
208+
# can't use regex directly because of some weird characters sometimes...
209+
pos = extradata.find(b"DivX")
210+
d = extradata[pos:]
211+
o = re.search(rb"DivX(\d+)Build(\d+)(\w)", d)
212+
if o is None:
213+
o = re.search(rb"DivX(\d+)b(\d+)(\w)", d)
214+
if o is not None:
215+
should_buffer = o.group(3) == b"p"
216+
seek_offset = start_offset
217+
# some files don't seek to the right location, so better be safe here
218+
seek_offset = max(seek_offset - 1, 0)
219+
if should_buffer:
220+
# FIXME this is kind of a hack, but we will jump to the previous keyframe
221+
# so this will be safe
222+
seek_offset = max(seek_offset - max_buffer_size, 0)
223+
try:
224+
# TODO check if stream needs to always be the video stream here or not
225+
container.seek(seek_offset, any_frame=False, backward=True, stream=stream)
226+
except FFmpegError:
227+
# TODO add some warnings in this case
228+
# print("Corrupted file?", container.name)
229+
return []
230+
buffer_count = 0
231+
try:
232+
for _idx, frame in enumerate(container.decode(**stream_name)):
233+
frames[frame.pts] = frame
234+
if frame.pts >= end_offset:
235+
if should_buffer and buffer_count < max_buffer_size:
236+
buffer_count += 1
237+
continue
238+
break
239+
except FFmpegError:
240+
# TODO add a warning
241+
pass
242+
# ensure that the results are sorted wrt the pts
243+
result = [frames[i] for i in sorted(frames) if start_offset <= frames[i].pts <= end_offset]
244+
if len(frames) > 0 and start_offset > 0 and start_offset not in frames:
245+
# if there is no frame that exactly matches the pts of start_offset
246+
# add the last frame smaller than start_offset, to guarantee that
247+
# we will have all the necessary data. This is most useful for audio
248+
preceding_frames = [i for i in frames if i < start_offset]
249+
if len(preceding_frames) > 0:
250+
first_frame_pts = max(preceding_frames)
251+
result.insert(0, frames[first_frame_pts])
252+
return result
253+
170254

171255
def read_video(
172256
filename: str,

0 commit comments

Comments
 (0)