diff --git a/test/smoke_test.py b/test/smoke_test.py index f80aba1d19f..9c58add739e 100644 --- a/test/smoke_test.py +++ b/test/smoke_test.py @@ -17,7 +17,6 @@ def smoke_test_torchvision() -> None: all(x is not None for x in [torch.ops.image.decode_png, torch.ops.torchvision.roi_align]), ) - def smoke_test_torchvision_read_decode() -> None: img_jpg = read_image(str(SCRIPT_DIR / "assets" / "encode_jpeg" / "grace_hopper_517x606.jpg")) if img_jpg.ndim != 3 or img_jpg.numel() < 100: @@ -26,13 +25,12 @@ def smoke_test_torchvision_read_decode() -> None: if img_png.ndim != 3 or img_png.numel() < 100: raise RuntimeError(f"Unexpected shape of img_png: {img_png.shape}") - -def smoke_test_torchvision_resnet50_classify() -> None: - img = read_image(str(SCRIPT_DIR / ".." / "gallery" / "assets" / "dog2.jpg")) +def smoke_test_torchvision_resnet50_classify(device: str = "cpu") -> None: + img = read_image(str(SCRIPT_DIR / ".." / "gallery" / "assets" / "dog2.jpg")).to(device) # Step 1: Initialize model with the best available weights weights = ResNet50_Weights.DEFAULT - model = resnet50(weights=weights) + model = resnet50(weights=weights).to(device) model.eval() # Step 2: Initialize the inference transforms @@ -47,17 +45,19 @@ def smoke_test_torchvision_resnet50_classify() -> None: score = prediction[class_id].item() category_name = weights.meta["categories"][class_id] expected_category = "German shepherd" - print(f"{category_name}: {100 * score:.1f}%") + print(f"{category_name} ({device}): {100 * score:.1f}%") if category_name != expected_category: - raise RuntimeError(f"Failed ResNet50 classify {category_name} Expected: {expected_category}") - + raise RuntimeError( + f"Failed ResNet50 classify {category_name} Expected: {expected_category}" + ) def main() -> None: print(f"torchvision: {torchvision.__version__}") smoke_test_torchvision() smoke_test_torchvision_read_decode() smoke_test_torchvision_resnet50_classify() - + if torch.cuda.is_available(): + smoke_test_torchvision_resnet50_classify("cuda") if __name__ == "__main__": main()