@@ -27,12 +27,12 @@ def smoke_test_torchvision_read_decode() -> None:
2727 raise RuntimeError (f"Unexpected shape of img_png: { img_png .shape } " )
2828
2929
30- def smoke_test_torchvision_resnet50_classify () -> None :
31- img = read_image (str (SCRIPT_DIR / ".." / "gallery" / "assets" / "dog2.jpg" ))
30+ def smoke_test_torchvision_resnet50_classify (device : str = "cpu" ) -> None :
31+ img = read_image (str (SCRIPT_DIR / ".." / "gallery" / "assets" / "dog2.jpg" )). to ( device )
3232
3333 # Step 1: Initialize model with the best available weights
3434 weights = ResNet50_Weights .DEFAULT
35- model = resnet50 (weights = weights )
35+ model = resnet50 (weights = weights ). to ( device )
3636 model .eval ()
3737
3838 # Step 2: Initialize the inference transforms
@@ -47,7 +47,7 @@ def smoke_test_torchvision_resnet50_classify() -> None:
4747 score = prediction [class_id ].item ()
4848 category_name = weights .meta ["categories" ][class_id ]
4949 expected_category = "German shepherd"
50- print (f"{ category_name } : { 100 * score :.1f} %" )
50+ print (f"{ category_name } ( { device } ) : { 100 * score :.1f} %" )
5151 if category_name != expected_category :
5252 raise RuntimeError (f"Failed ResNet50 classify { category_name } Expected: { expected_category } " )
5353
@@ -57,6 +57,8 @@ def main() -> None:
5757 smoke_test_torchvision ()
5858 smoke_test_torchvision_read_decode ()
5959 smoke_test_torchvision_resnet50_classify ()
60+ if torch .cuda .is_available ():
61+ smoke_test_torchvision_resnet50_classify ("cuda" )
6062
6163
6264if __name__ == "__main__" :
0 commit comments