File tree Expand file tree Collapse file tree 2 files changed +6
-2
lines changed Expand file tree Collapse file tree 2 files changed +6
-2
lines changed Original file line number Diff line number Diff line change @@ -255,10 +255,14 @@ def test_read_video_partially_corrupted_file(self):
255255 assert_equal (video , data )
256256
257257 @pytest .mark .skipif (sys .platform == "win32" , reason = "temporarily disabled on Windows" )
258- def test_write_video_with_audio (self , tmpdir ):
258+ @pytest .mark .parametrize ("device" , ["cpu" , "cuda" ])
259+ def test_write_video_with_audio (self , device , tmpdir ):
259260 f_name = os .path .join (VIDEO_DIR , "R6llTwEh07w.mp4" )
260261 video_tensor , audio_tensor , info = io .read_video (f_name , pts_unit = "sec" )
261262
263+ video_tensor = video_tensor .to (device )
264+ audio_tensor = audio_tensor .to (device )
265+
262266 out_f_name = os .path .join (tmpdir , "testing.mp4" )
263267 io .video .write_video (
264268 out_f_name ,
Original file line number Diff line number Diff line change @@ -80,7 +80,7 @@ def write_video(
8080 if not torch .jit .is_scripting () and not torch .jit .is_tracing ():
8181 _log_api_usage_once (write_video )
8282 _check_av_available ()
83- video_array = torch .as_tensor (video_array , dtype = torch .uint8 ).numpy ()
83+ video_array = torch .as_tensor (video_array , dtype = torch .uint8 ).numpy (force = True )
8484
8585 # PyAV does not support floating point numbers with decimal point
8686 # and will throw OverflowException in case this is not the case
You can’t perform that action at this time.
0 commit comments