From 20b0e4f950d2e6c762306bbcb7206bd6c9623e0c Mon Sep 17 00:00:00 2001 From: haaris Date: Thu, 23 May 2024 15:05:49 -0700 Subject: [PATCH 1/3] Remove squeeze on `visibility` in `draw_keypoints` --- torchvision/utils.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/torchvision/utils.py b/torchvision/utils.py index 94b3ec65c87..1d1b9e187ab 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -392,10 +392,6 @@ def draw_keypoints( # validate visibility if visibility is None: # set default visibility = torch.ones(keypoints.shape[:-1], dtype=torch.bool) - # If the last dimension is 1, e.g., after calling split([2, 1], dim=-1) on the output of a keypoint-prediction - # model, make sure visibility has shape (num_instances, K). - # Iff K = 1, this has unwanted behavior, but K=1 does not really make sense in the first place. - visibility = visibility.squeeze(-1) if visibility.ndim != 2: raise ValueError(f"visibility must be of shape (num_instances, K). Got ndim={visibility.ndim}") if visibility.shape != keypoints.shape[:-1]: From 52d4874d0ad6afd612d189c2a5ce02d94f59a2cf Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 28 May 2024 10:32:35 +0100 Subject: [PATCH 2/3] Only check if ndim == 3 --- torchvision/utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torchvision/utils.py b/torchvision/utils.py index 1d1b9e187ab..6b2d19ec3dd 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -392,6 +392,10 @@ def draw_keypoints( # validate visibility if visibility is None: # set default visibility = torch.ones(keypoints.shape[:-1], dtype=torch.bool) + if visibility.ndim == 3: + # If visibility was passed as pred.split([2, 1], dim=-1), it will be of shape (num_instances, K, 1). + # We make sure it is of shape (num_instances, K). This isn't documented, we're just being nice. + visibility = visibility.squeeze(-1) if visibility.ndim != 2: raise ValueError(f"visibility must be of shape (num_instances, K). Got ndim={visibility.ndim}") if visibility.shape != keypoints.shape[:-1]: From f7896514c07f380baa673a165f488bd5b407780d Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 28 May 2024 10:36:35 +0100 Subject: [PATCH 3/3] Add test --- test/test_utils.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/test/test_utils.py b/test/test_utils.py index ac394b51d63..e89bef4a6d9 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -355,6 +355,13 @@ def test_draw_keypoints_vanilla(): assert_equal(img, img_cp) +def test_draw_keypoins_K_equals_one(): + # Non-regression test for https://github.com/pytorch/vision/pull/8439 + img = torch.full((3, 100, 100), 0, dtype=torch.uint8) + keypoints = torch.tensor([[[10, 10]]], dtype=torch.float) + utils.draw_keypoints(img, keypoints) + + @pytest.mark.parametrize("colors", ["red", "#FF00FF", (1, 34, 122)]) def test_draw_keypoints_colored(colors): # Keypoints is declared on top as global variable