Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1379,6 +1379,11 @@ def test_random_erasing(self):
img_re = transforms.RandomErasing(value=(0.2), inplace=True)(img)
assert torch.equal(img_re, img)

# Test Set 6: Checking when no erased region is selected
img = torch.rand([3, 300, 1])
img_re = transforms.RandomErasing(ratio=(0.1, 0.2), value='random')(img)
assert torch.equal(img_re, img)


if __name__ == '__main__':
unittest.main()
2 changes: 1 addition & 1 deletion torchvision/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,7 +828,7 @@ def erase(img, i, j, h, w, v, inplace=False):
h (int): Height of the erased region.
w (int): Width of the erased region.
v: Erasing value.
inplace(bool,optional): For in-place operations. By default is set False.
inplace(bool, optional): For in-place operations. By default is set False.

Returns:
Tensor Image: Erased image.
Expand Down
13 changes: 8 additions & 5 deletions torchvision/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1210,7 +1210,7 @@ class RandomErasing(object):
erase all pixels. If a tuple of length 3, it is used to erase
R, G, B channels respectively.
If a str of 'random', erasing each pixel with random values.
inplace: boolean to make this transform inplace.Default set to False.
inplace: boolean to make this transform inplace. Default set to False.

Returns:
Erased Image.
Expand All @@ -1223,7 +1223,7 @@ class RandomErasing(object):
>>> ])
"""

def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 1. / 0.3), value=0, inplace=False):
def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False):
assert isinstance(value, (numbers.Number, str, tuple, list))
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
warnings.warn("range should be of kind (min, max)")
Expand All @@ -1250,10 +1250,10 @@ def get_params(img, scale, ratio, value=0):
Returns:
tuple: params (i, j, h, w, v) to be passed to ``erase`` for random erasing.
"""
img_b, img_h, img_w = img.shape
img_c, img_h, img_w = img.shape
area = img_h * img_w

while True:
for attempt in range(10):
erase_area = random.uniform(scale[0], scale[1]) * area
aspect_ratio = random.uniform(ratio[0], ratio[1])

Expand All @@ -1266,11 +1266,14 @@ def get_params(img, scale, ratio, value=0):
if isinstance(value, numbers.Number):
v = value
elif isinstance(value, torch._six.string_classes):
v = torch.rand(img_b, h, w)
v = torch.empty([img_c, h, w], dtype=torch.float32).normal_()
elif isinstance(value, (list, tuple)):
v = torch.tensor(value, dtype=torch.float32).view(-1, 1, 1).expand(-1, h, w)
return i, j, h, w, v

# Return original image
return 0, 0, img_h, img_w, img

def __call__(self, img):
"""
Args:
Expand Down