Skip to content

Commit 3e60dbd

Browse files
bmmtstbNicolasHug
andauthored
Automatically send video to CPU in io.write_video (#8537)
Co-authored-by: Nicolas Hug <[email protected]>
1 parent 4a1cb63 commit 3e60dbd

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

test/test_io.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,10 +255,14 @@ def test_read_video_partially_corrupted_file(self):
255255
assert_equal(video, data)
256256

257257
@pytest.mark.skipif(sys.platform == "win32", reason="temporarily disabled on Windows")
258-
def test_write_video_with_audio(self, tmpdir):
258+
@pytest.mark.parametrize("device", ["cpu", "cuda"])
259+
def test_write_video_with_audio(self, device, tmpdir):
259260
f_name = os.path.join(VIDEO_DIR, "R6llTwEh07w.mp4")
260261
video_tensor, audio_tensor, info = io.read_video(f_name, pts_unit="sec")
261262

263+
video_tensor = video_tensor.to(device)
264+
audio_tensor = audio_tensor.to(device)
265+
262266
out_f_name = os.path.join(tmpdir, "testing.mp4")
263267
io.video.write_video(
264268
out_f_name,

torchvision/io/video.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def write_video(
8080
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
8181
_log_api_usage_once(write_video)
8282
_check_av_available()
83-
video_array = torch.as_tensor(video_array, dtype=torch.uint8).numpy()
83+
video_array = torch.as_tensor(video_array, dtype=torch.uint8).numpy(force=True)
8484

8585
# PyAV does not support floating point numbers with decimal point
8686
# and will throw OverflowException in case this is not the case

0 commit comments

Comments
 (0)