|
6 | 6 | import pytest |
7 | 7 | import torch |
8 | 8 | import torchvision.io as io |
9 | | -from common_utils import assert_equal |
| 9 | +from common_utils import assert_equal, cpu_and_cuda |
10 | 10 | from torchvision import get_video_backend |
11 | 11 |
|
12 | 12 |
|
@@ -255,22 +255,19 @@ def test_read_video_partially_corrupted_file(self): |
255 | 255 | assert_equal(video, data) |
256 | 256 |
|
257 | 257 | @pytest.mark.skipif(sys.platform == "win32", reason="temporarily disabled on Windows") |
258 | | - @pytest.mark.parametrize("device", ["cpu", "cuda"]) |
| 258 | + @pytest.mark.parametrize("device", cpu_and_cuda()) |
259 | 259 | def test_write_video_with_audio(self, device, tmpdir): |
260 | 260 | f_name = os.path.join(VIDEO_DIR, "R6llTwEh07w.mp4") |
261 | 261 | video_tensor, audio_tensor, info = io.read_video(f_name, pts_unit="sec") |
262 | 262 |
|
263 | | - video_tensor = video_tensor.to(device) |
264 | | - audio_tensor = audio_tensor.to(device) |
265 | | - |
266 | 263 | out_f_name = os.path.join(tmpdir, "testing.mp4") |
267 | 264 | io.video.write_video( |
268 | 265 | out_f_name, |
269 | | - video_tensor, |
| 266 | + video_tensor.to(device), |
270 | 267 | round(info["video_fps"]), |
271 | 268 | video_codec="libx264rgb", |
272 | 269 | options={"crf": "0"}, |
273 | | - audio_array=audio_tensor, |
| 270 | + audio_array=audio_tensor.to(device), |
274 | 271 | audio_fps=info["audio_fps"], |
275 | 272 | audio_codec="aac", |
276 | 273 | ) |
|
0 commit comments