Skip to content

Commit 4672c1a

Browse files
authored
Merge branch 'main' into xpu_build
2 parents 45709bb + 3e60dbd commit 4672c1a

File tree

3 files changed

+7
-2
lines changed

3 files changed

+7
-2
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ versions.
2121
| `torch` | `torchvision` | Python |
2222
| ------------------ | ------------------ | ------------------- |
2323
| `main` / `nightly` | `main` / `nightly` | `>=3.8`, `<=3.12` |
24+
| `2.4` | `0.19` | `>=3.8`, `<=3.12` |
2425
| `2.3` | `0.18` | `>=3.8`, `<=3.12` |
2526
| `2.2` | `0.17` | `>=3.8`, `<=3.11` |
2627
| `2.1` | `0.16` | `>=3.8`, `<=3.11` |

test/test_io.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,10 +255,14 @@ def test_read_video_partially_corrupted_file(self):
255255
assert_equal(video, data)
256256

257257
@pytest.mark.skipif(sys.platform == "win32", reason="temporarily disabled on Windows")
258-
def test_write_video_with_audio(self, tmpdir):
258+
@pytest.mark.parametrize("device", ["cpu", "cuda"])
259+
def test_write_video_with_audio(self, device, tmpdir):
259260
f_name = os.path.join(VIDEO_DIR, "R6llTwEh07w.mp4")
260261
video_tensor, audio_tensor, info = io.read_video(f_name, pts_unit="sec")
261262

263+
video_tensor = video_tensor.to(device)
264+
audio_tensor = audio_tensor.to(device)
265+
262266
out_f_name = os.path.join(tmpdir, "testing.mp4")
263267
io.video.write_video(
264268
out_f_name,

torchvision/io/video.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def write_video(
8080
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
8181
_log_api_usage_once(write_video)
8282
_check_av_available()
83-
video_array = torch.as_tensor(video_array, dtype=torch.uint8).numpy()
83+
video_array = torch.as_tensor(video_array, dtype=torch.uint8).numpy(force=True)
8484

8585
# PyAV does not support floating point numbers with decimal point
8686
# and will throw OverflowException in case this is not the case

0 commit comments

Comments
 (0)