Skip to content

Commit 6f0deb9

Browse files
Set masks to zero where masks overlap (#8213)
Co-authored-by: Nicolas Hug <[email protected]> Co-authored-by: Nicolas Hug <[email protected]>
1 parent 660868b commit 6f0deb9

File tree

2 files changed

+13
-10
lines changed

2 files changed

+13
-10
lines changed

test/test_utils.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -210,14 +210,11 @@ def test_draw_segmentation_masks(colors, alpha, device):
210210
num_masks, h, w = 2, 100, 100
211211
dtype = torch.uint8
212212
img = torch.randint(0, 256, size=(3, h, w), dtype=dtype, device=device)
213-
masks = torch.randint(0, 2, (num_masks, h, w), dtype=torch.bool, device=device)
213+
masks = torch.zeros((num_masks, h, w), dtype=torch.bool, device=device)
214+
masks[0, 10:20, 10:20] = True
215+
masks[1, 15:25, 15:25] = True
214216

215-
# For testing we enforce that there's no overlap between the masks. The
216-
# current behaviour is that the last mask's color will take priority when
217-
# masks overlap, but this makes testing slightly harder, so we don't really
218-
# care
219217
overlap = masks[0] & masks[1]
220-
masks[:, overlap] = False
221218

222219
out = utils.draw_segmentation_masks(img, masks, colors=colors, alpha=alpha)
223220
assert out.dtype == dtype
@@ -239,12 +236,15 @@ def test_draw_segmentation_masks(colors, alpha, device):
239236
color = torch.tensor(color, dtype=dtype, device=device)
240237

241238
if alpha == 1:
242-
assert (out[:, mask] == color[:, None]).all()
239+
assert (out[:, mask & ~overlap] == color[:, None]).all()
243240
elif alpha == 0:
244-
assert (out[:, mask] == img[:, mask]).all()
241+
assert (out[:, mask & ~overlap] == img[:, mask & ~overlap]).all()
245242

246-
interpolated_color = (img[:, mask] * (1 - alpha) + color[:, None] * alpha).to(dtype)
247-
torch.testing.assert_close(out[:, mask], interpolated_color, rtol=0.0, atol=1.0)
243+
interpolated_color = (img[:, mask & ~overlap] * (1 - alpha) + color[:, None] * alpha).to(dtype)
244+
torch.testing.assert_close(out[:, mask & ~overlap], interpolated_color, rtol=0.0, atol=1.0)
245+
246+
interpolated_overlap = (img[:, overlap] * (1 - alpha)).to(dtype)
247+
torch.testing.assert_close(out[:, overlap], interpolated_overlap, rtol=0.0, atol=1.0)
248248

249249

250250
def test_draw_segmentation_masks_dtypes():

torchvision/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,7 @@ def draw_segmentation_masks(
299299
raise ValueError("The image and the masks must have the same height and width")
300300

301301
num_masks = masks.size()[0]
302+
overlapping_masks = masks.sum(dim=0) > 1
302303

303304
if num_masks == 0:
304305
warnings.warn("masks doesn't contain any mask. No mask was drawn")
@@ -315,6 +316,8 @@ def draw_segmentation_masks(
315316
for mask, color in zip(masks, colors):
316317
img_to_draw[:, mask] = color[:, None]
317318

319+
img_to_draw[:, overlapping_masks] = 0
320+
318321
out = image * (1 - alpha) + img_to_draw * alpha
319322
# Note: at this point, out is a float tensor in [0, 1] or [0, 255] depending on original_dtype
320323
return out.to(original_dtype)

0 commit comments

Comments
 (0)