Skip to content

Commit f584309

Browse files
authored
Fixed audio-video synchronisation problem in read_video() when using pts as unit (#3791)
* Fixed audio-video synchronisation problem in read_video() when using as unit * Addressed review comments * Added unit test
1 parent 154283b commit f584309

File tree

3 files changed

+84
-18
lines changed

3 files changed

+84
-18
lines changed

test/test_video_reader.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1238,6 +1238,45 @@ def test_read_video_from_memory_scripted(self):
12381238
)
12391239
# FUTURE: check value of video / audio frames
12401240

1241+
def test_audio_video_sync(self):
1242+
"""Test if audio/video are synchronised with pyav output."""
1243+
for test_video, config in test_videos.items():
1244+
full_path = os.path.join(VIDEO_DIR, test_video)
1245+
container = av.open(full_path)
1246+
if not container.streams.audio:
1247+
# Skip if no audio stream
1248+
continue
1249+
start_pts_val, cutoff = 0, 1
1250+
if container.streams.video:
1251+
video = container.streams.video[0]
1252+
arr = []
1253+
for index, frame in enumerate(container.decode(video)):
1254+
if index == cutoff:
1255+
start_pts_val = frame.pts
1256+
if index >= cutoff:
1257+
arr.append(frame.to_rgb().to_ndarray())
1258+
visual, _, info = io.read_video(full_path, start_pts=start_pts_val, pts_unit='pts')
1259+
self.assertAlmostEqual(
1260+
config.video_fps, info['video_fps'], delta=0.0001
1261+
)
1262+
arr = torch.Tensor(arr)
1263+
if arr.shape == visual.shape:
1264+
self.assertGreaterEqual(
1265+
torch.mean(torch.isclose(visual.float(), arr, atol=1e-5).float()), 0.99)
1266+
1267+
container = av.open(full_path)
1268+
if container.streams.audio:
1269+
audio = container.streams.audio[0]
1270+
arr = []
1271+
for index, frame in enumerate(container.decode(audio)):
1272+
if index >= cutoff:
1273+
arr.append(frame.to_ndarray())
1274+
_, audio, _ = io.read_video(full_path, start_pts=start_pts_val, pts_unit='pts')
1275+
arr = torch.as_tensor(np.concatenate(arr, axis=1))
1276+
if arr.shape == audio.shape:
1277+
self.assertGreaterEqual(
1278+
torch.mean(torch.isclose(audio.float(), arr).float()), 0.99)
1279+
12411280

12421281
if __name__ == "__main__":
12431282
unittest.main()

torchvision/io/_video_opt.py

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,14 @@ def _probe_video_from_memory(video_data):
471471
return info
472472

473473

474+
def _convert_to_sec(start_pts, end_pts, pts_unit, time_base):
475+
if pts_unit == 'pts':
476+
start_pts = float(start_pts * time_base)
477+
end_pts = float(end_pts * time_base)
478+
pts_unit = 'sec'
479+
return start_pts, end_pts, pts_unit
480+
481+
474482
def _read_video(filename, start_pts=0, end_pts=None, pts_unit="pts"):
475483
if end_pts is None:
476484
end_pts = float("inf")
@@ -485,32 +493,43 @@ def _read_video(filename, start_pts=0, end_pts=None, pts_unit="pts"):
485493

486494
has_video = info.has_video
487495
has_audio = info.has_audio
496+
video_pts_range = (0, -1)
497+
video_timebase = default_timebase
498+
audio_pts_range = (0, -1)
499+
audio_timebase = default_timebase
500+
time_base = default_timebase
501+
502+
if has_video:
503+
video_timebase = Fraction(
504+
info.video_timebase.numerator, info.video_timebase.denominator
505+
)
506+
time_base = video_timebase
507+
508+
if has_audio:
509+
audio_timebase = Fraction(
510+
info.audio_timebase.numerator, info.audio_timebase.denominator
511+
)
512+
time_base = time_base if time_base else audio_timebase
513+
514+
# video_timebase is the default time_base
515+
start_pts_sec, end_pts_sec, pts_unit = _convert_to_sec(
516+
start_pts, end_pts, pts_unit, time_base)
488517

489518
def get_pts(time_base):
490-
start_offset = start_pts
491-
end_offset = end_pts
519+
start_offset = start_pts_sec
520+
end_offset = end_pts_sec
492521
if pts_unit == "sec":
493-
start_offset = int(math.floor(start_pts * (1 / time_base)))
522+
start_offset = int(math.floor(start_pts_sec * (1 / time_base)))
494523
if end_offset != float("inf"):
495-
end_offset = int(math.ceil(end_pts * (1 / time_base)))
524+
end_offset = int(math.ceil(end_pts_sec * (1 / time_base)))
496525
if end_offset == float("inf"):
497526
end_offset = -1
498527
return start_offset, end_offset
499528

500-
video_pts_range = (0, -1)
501-
video_timebase = default_timebase
502529
if has_video:
503-
video_timebase = Fraction(
504-
info.video_timebase.numerator, info.video_timebase.denominator
505-
)
506530
video_pts_range = get_pts(video_timebase)
507531

508-
audio_pts_range = (0, -1)
509-
audio_timebase = default_timebase
510532
if has_audio:
511-
audio_timebase = Fraction(
512-
info.audio_timebase.numerator, info.audio_timebase.denominator
513-
)
514533
audio_pts_range = get_pts(audio_timebase)
515534

516535
vframes, aframes, info = _read_video_from_file(

torchvision/io/video.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -278,11 +278,19 @@ def read_video(
278278

279279
try:
280280
with av.open(filename, metadata_errors="ignore") as container:
281+
time_base = _video_opt.default_timebase
282+
if container.streams.video:
283+
time_base = container.streams.video[0].time_base
284+
elif container.streams.audio:
285+
time_base = container.streams.audio[0].time_base
286+
# video_timebase is the default time_base
287+
start_pts_sec, end_pts_sec, pts_unit = _video_opt._convert_to_sec(
288+
start_pts, end_pts, pts_unit, time_base)
281289
if container.streams.video:
282290
video_frames = _read_from_stream(
283291
container,
284-
start_pts,
285-
end_pts,
292+
start_pts_sec,
293+
end_pts_sec,
286294
pts_unit,
287295
container.streams.video[0],
288296
{"video": 0},
@@ -295,8 +303,8 @@ def read_video(
295303
if container.streams.audio:
296304
audio_frames = _read_from_stream(
297305
container,
298-
start_pts,
299-
end_pts,
306+
start_pts_sec,
307+
end_pts_sec,
300308
pts_unit,
301309
container.streams.audio[0],
302310
{"audio": 0},

0 commit comments

Comments
 (0)