@@ -210,14 +210,11 @@ def test_draw_segmentation_masks(colors, alpha, device):
210210 num_masks , h , w = 2 , 100 , 100
211211 dtype = torch .uint8
212212 img = torch .randint (0 , 256 , size = (3 , h , w ), dtype = dtype , device = device )
213- masks = torch .randint (0 , 2 , (num_masks , h , w ), dtype = torch .bool , device = device )
213+ masks = torch .zeros ((num_masks , h , w ), dtype = torch .bool , device = device )
214+ masks [0 , 10 :20 , 10 :20 ] = True
215+ masks [1 , 15 :25 , 15 :25 ] = True
214216
215- # For testing we enforce that there's no overlap between the masks. The
216- # current behaviour is that the last mask's color will take priority when
217- # masks overlap, but this makes testing slightly harder, so we don't really
218- # care
219217 overlap = masks [0 ] & masks [1 ]
220- masks [:, overlap ] = False
221218
222219 out = utils .draw_segmentation_masks (img , masks , colors = colors , alpha = alpha )
223220 assert out .dtype == dtype
@@ -239,12 +236,15 @@ def test_draw_segmentation_masks(colors, alpha, device):
239236 color = torch .tensor (color , dtype = dtype , device = device )
240237
241238 if alpha == 1 :
242- assert (out [:, mask ] == color [:, None ]).all ()
239+ assert (out [:, mask & ~ overlap ] == color [:, None ]).all ()
243240 elif alpha == 0 :
244- assert (out [:, mask ] == img [:, mask ]).all ()
241+ assert (out [:, mask & ~ overlap ] == img [:, mask & ~ overlap ]).all ()
245242
246- interpolated_color = (img [:, mask ] * (1 - alpha ) + color [:, None ] * alpha ).to (dtype )
247- torch .testing .assert_close (out [:, mask ], interpolated_color , rtol = 0.0 , atol = 1.0 )
243+ interpolated_color = (img [:, mask & ~ overlap ] * (1 - alpha ) + color [:, None ] * alpha ).to (dtype )
244+ torch .testing .assert_close (out [:, mask & ~ overlap ], interpolated_color , rtol = 0.0 , atol = 1.0 )
245+
246+ interpolated_overlap = (img [:, overlap ] * (1 - alpha )).to (dtype )
247+ torch .testing .assert_close (out [:, overlap ], interpolated_overlap , rtol = 0.0 , atol = 1.0 )
248248
249249
250250def test_draw_segmentation_masks_dtypes ():
0 commit comments