Skip to content

Commit 42aa9b2

Browse files
authored
Refactoring to use context managers, list comprehensions where more idiomatic, and minor renaming to help reader clarity (#2335)
* Refactoring to use context managers, list comprehensions where more idiomatic, and minor renaming to help reader clarity. * Fix flake8 warning in video_utils.py
1 parent 32f21da commit 42aa9b2

File tree

3 files changed

+72
-88
lines changed

3 files changed

+72
-88
lines changed

torchvision/datasets/utils.py

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -95,16 +95,9 @@ def list_dir(root, prefix=False):
9595
only returns the name of the directories found
9696
"""
9797
root = os.path.expanduser(root)
98-
directories = list(
99-
filter(
100-
lambda p: os.path.isdir(os.path.join(root, p)),
101-
os.listdir(root)
102-
)
103-
)
104-
98+
directories = [p for p in os.listdir(root) if os.path.isdir(os.path.join(root, p))]
10599
if prefix is True:
106100
directories = [os.path.join(root, d) for d in directories]
107-
108101
return directories
109102

110103

@@ -119,16 +112,9 @@ def list_files(root, suffix, prefix=False):
119112
only returns the name of the files found
120113
"""
121114
root = os.path.expanduser(root)
122-
files = list(
123-
filter(
124-
lambda p: os.path.isfile(os.path.join(root, p)) and p.endswith(suffix),
125-
os.listdir(root)
126-
)
127-
)
128-
115+
files = [p for p in os.listdir(root) if os.path.isfile(os.path.join(root, p)) and p.endswith(suffix)]
129116
if prefix is True:
130117
files = [os.path.join(root, d) for d in files]
131-
132118
return files
133119

134120

torchvision/datasets/video_utils.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import bisect
22
import math
33
from fractions import Fraction
4+
from typing import List
45

56
import torch
67
from torchvision.io import (
@@ -45,20 +46,23 @@ def unfold(tensor, size, step, dilation=1):
4546
return torch.as_strided(tensor, new_size, new_stride)
4647

4748

48-
class _DummyDataset(object):
49+
class _VideoTimestampsDataset(object):
4950
"""
50-
Dummy dataset used for DataLoader in VideoClips.
51-
Defined at top level so it can be pickled when forking.
51+
Dataset used to parallelize the reading of the timestamps
52+
of a list of videos, given their paths in the filesystem.
53+
54+
Used in VideoClips and defined at top level so it can be
55+
pickled when forking.
5256
"""
5357

54-
def __init__(self, x):
55-
self.x = x
58+
def __init__(self, video_paths: List[str]):
59+
self.video_paths = video_paths
5660

5761
def __len__(self):
58-
return len(self.x)
62+
return len(self.video_paths)
5963

6064
def __getitem__(self, idx):
61-
return read_video_timestamps(self.x[idx])
65+
return read_video_timestamps(self.video_paths[idx])
6266

6367

6468
class VideoClips(object):
@@ -132,7 +136,7 @@ def _compute_frame_pts(self):
132136
import torch.utils.data
133137

134138
dl = torch.utils.data.DataLoader(
135-
_DummyDataset(self.video_paths),
139+
_VideoTimestampsDataset(self.video_paths),
136140
batch_size=16,
137141
num_workers=self.num_workers,
138142
collate_fn=self._collate_fn,

torchvision/io/video.py

Lines changed: 58 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -70,27 +70,23 @@ def write_video(filename, video_array, fps: Union[int, float], video_codec="libx
7070
if isinstance(fps, float):
7171
fps = np.round(fps)
7272

73-
container = av.open(filename, mode="w")
74-
75-
stream = container.add_stream(video_codec, rate=fps)
76-
stream.width = video_array.shape[2]
77-
stream.height = video_array.shape[1]
78-
stream.pix_fmt = "yuv420p" if video_codec != "libx264rgb" else "rgb24"
79-
stream.options = options or {}
80-
81-
for img in video_array:
82-
frame = av.VideoFrame.from_ndarray(img, format="rgb24")
83-
frame.pict_type = "NONE"
84-
for packet in stream.encode(frame):
73+
with av.open(filename, mode="w") as container:
74+
stream = container.add_stream(video_codec, rate=fps)
75+
stream.width = video_array.shape[2]
76+
stream.height = video_array.shape[1]
77+
stream.pix_fmt = "yuv420p" if video_codec != "libx264rgb" else "rgb24"
78+
stream.options = options or {}
79+
80+
for img in video_array:
81+
frame = av.VideoFrame.from_ndarray(img, format="rgb24")
82+
frame.pict_type = "NONE"
83+
for packet in stream.encode(frame):
84+
container.mux(packet)
85+
86+
# Flush stream
87+
for packet in stream.encode():
8588
container.mux(packet)
8689

87-
# Flush stream
88-
for packet in stream.encode():
89-
container.mux(packet)
90-
91-
# Close the file
92-
container.close()
93-
9490

9591
def _read_from_stream(
9692
container, start_offset, end_offset, pts_unit, stream, stream_name
@@ -234,37 +230,35 @@ def read_video(filename, start_pts=0, end_pts=None, pts_unit="pts"):
234230
audio_frames = []
235231

236232
try:
237-
container = av.open(filename, metadata_errors="ignore")
233+
with av.open(filename, metadata_errors="ignore") as container:
234+
if container.streams.video:
235+
video_frames = _read_from_stream(
236+
container,
237+
start_pts,
238+
end_pts,
239+
pts_unit,
240+
container.streams.video[0],
241+
{"video": 0},
242+
)
243+
video_fps = container.streams.video[0].average_rate
244+
# guard against potentially corrupted files
245+
if video_fps is not None:
246+
info["video_fps"] = float(video_fps)
247+
248+
if container.streams.audio:
249+
audio_frames = _read_from_stream(
250+
container,
251+
start_pts,
252+
end_pts,
253+
pts_unit,
254+
container.streams.audio[0],
255+
{"audio": 0},
256+
)
257+
info["audio_fps"] = container.streams.audio[0].rate
258+
238259
except av.AVError:
239260
# TODO raise a warning?
240261
pass
241-
else:
242-
if container.streams.video:
243-
video_frames = _read_from_stream(
244-
container,
245-
start_pts,
246-
end_pts,
247-
pts_unit,
248-
container.streams.video[0],
249-
{"video": 0},
250-
)
251-
video_fps = container.streams.video[0].average_rate
252-
# guard against potentially corrupted files
253-
if video_fps is not None:
254-
info["video_fps"] = float(video_fps)
255-
256-
if container.streams.audio:
257-
audio_frames = _read_from_stream(
258-
container,
259-
start_pts,
260-
end_pts,
261-
pts_unit,
262-
container.streams.audio[0],
263-
{"audio": 0},
264-
)
265-
info["audio_fps"] = container.streams.audio[0].rate
266-
267-
container.close()
268262

269263
vframes = [frame.to_rgb().to_ndarray() for frame in video_frames]
270264
aframes = [frame.to_ndarray() for frame in audio_frames]
@@ -293,6 +287,14 @@ def _can_read_timestamps_from_packets(container):
293287
return False
294288

295289

290+
def _decode_video_timestamps(container):
291+
if _can_read_timestamps_from_packets(container):
292+
# fast path
293+
return [x.pts for x in container.demux(video=0) if x.pts is not None]
294+
else:
295+
return [x.pts for x in container.decode(video=0) if x.pts is not None]
296+
297+
296298
def read_video_timestamps(filename, pts_unit="pts"):
297299
"""
298300
List the video frames timestamps.
@@ -326,26 +328,18 @@ def read_video_timestamps(filename, pts_unit="pts"):
326328
pts = []
327329

328330
try:
329-
container = av.open(filename, metadata_errors="ignore")
331+
with av.open(filename, metadata_errors="ignore") as container:
332+
if container.streams.video:
333+
video_stream = container.streams.video[0]
334+
video_time_base = video_stream.time_base
335+
try:
336+
pts = _decode_video_timestamps(container)
337+
except av.AVError:
338+
warnings.warn(f"Failed decoding frames for file {filename}")
339+
video_fps = float(video_stream.average_rate)
330340
except av.AVError:
331341
# TODO add a warning
332342
pass
333-
else:
334-
if container.streams.video:
335-
video_stream = container.streams.video[0]
336-
video_time_base = video_stream.time_base
337-
try:
338-
if _can_read_timestamps_from_packets(container):
339-
# fast path
340-
pts = [x.pts for x in container.demux(video=0) if x.pts is not None]
341-
else:
342-
pts = [
343-
x.pts for x in container.decode(video=0) if x.pts is not None
344-
]
345-
except av.AVError:
346-
warnings.warn(f"Failed decoding frames for file {filename}")
347-
video_fps = float(video_stream.average_rate)
348-
container.close()
349343

350344
pts.sort()
351345

0 commit comments

Comments
 (0)