diff --git a/test/test_transforms.py b/test/test_transforms.py index 052ca2d4996..b3d55ae6b57 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -235,6 +235,37 @@ def test_pad_with_tuple_of_pad_values(self): # Checking if Padding can be printed as string transforms.Pad(padding).__repr__() + def test_pad_with_non_constant_padding_modes(self): + """Unit tests for edge, reflect, symmetric padding""" + img = torch.zeros(3, 27, 27) + img[:, :, 0] = 1 # Constant value added to leftmost edge + img = transforms.ToPILImage()(img) + img = F.pad(img, 1, (200, 200, 200)) + + # Pad 3 to all sides + edge_padded_img = F.pad(img, 3, padding_mode='edge') + # First 6 elements of leftmost edge in the middle of the image, values are in order: + # edge_pad, edge_pad, edge_pad, constant_pad, constant value added to leftmost edge, 0 + edge_middle_slice = np.asarray(edge_padded_img).transpose(2, 0, 1)[0][17][:6] + assert np.all(edge_middle_slice == np.asarray([200, 200, 200, 200, 255, 0])) + assert transforms.ToTensor()(edge_padded_img).size() == (3, 35, 35) + + # Pad 3 to left/right, 2 to top/bottom + reflect_padded_img = F.pad(img, (3, 2), padding_mode='reflect') + # First 6 elements of leftmost edge in the middle of the image, values are in order: + # reflect_pad, reflect_pad, reflect_pad, constant_pad, constant value added to leftmost edge, 0 + reflect_middle_slice = np.asarray(reflect_padded_img).transpose(2, 0, 1)[0][17][:6] + assert np.all(reflect_middle_slice == np.asarray([0, 0, 255, 200, 255, 0])) + assert transforms.ToTensor()(reflect_padded_img).size() == (3, 33, 35) + + # Pad 3 to left, 2 to top, 2 to right, 1 to bottom + symmetric_padded_img = F.pad(img, (3, 2, 2, 1), padding_mode='symmetric') + # First 6 elements of leftmost edge in the middle of the image, values are in order: + # sym_pad, sym_pad, sym_pad, constant_pad, constant value added to leftmost edge, 0 + symmetric_middle_slice = np.asarray(symmetric_padded_img).transpose(2, 0, 1)[0][17][:6] + 
assert np.all(symmetric_middle_slice == np.asarray([0, 255, 200, 200, 255, 0])) + assert transforms.ToTensor()(symmetric_padded_img).size() == (3, 32, 34) + def test_pad_raises_with_invalid_pad_sequence_len(self): with self.assertRaises(ValueError): transforms.Pad(()) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index db2f5e60dc7..5d5325078be 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -210,8 +210,8 @@ def scale(*args, **kwargs): return resize(*args, **kwargs) -def pad(img, padding, fill=0): - """Pad the given PIL Image on all sides with the given "pad" value. +def pad(img, padding, fill=0, padding_mode='constant'): + """Pad the given PIL Image on all sides with the specified padding mode and fill value. Args: img (PIL Image): Image to be padded. @@ -220,8 +220,18 @@ def pad(img, padding, fill=0): on left/right and top/bottom respectively. If a tuple of length 4 is provided this is the padding for the left, top, right and bottom borders respectively. - fill: Pixel fill value. Default is 0. If a tuple of + fill: Pixel fill value for constant fill. Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively. + This value is only used when the padding_mode is constant + padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant. + constant: pads with a constant value, this value is specified with fill + edge: pads with the last value on the edge of the image + reflect: pads with reflection of image (without repeating the last value on the edge) + padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode + will result in [3, 2, 1, 2, 3, 4, 3, 2] + symmetric: pads with reflection of image (repeating the last value on the edge) + padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode + will result in [2, 1, 1, 2, 3, 4, 4, 3] Returns: PIL Image: Padded image. 
@@ -233,12 +243,39 @@ def pad(img, padding, fill=0): raise TypeError('Got inappropriate padding arg') if not isinstance(fill, (numbers.Number, str, tuple)): raise TypeError('Got inappropriate fill arg') + if not isinstance(padding_mode, str): + raise TypeError('Got inappropriate padding_mode arg') if isinstance(padding, collections.Sequence) and len(padding) not in [2, 4]: raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " + "{} element tuple".format(len(padding))) - return ImageOps.expand(img, border=padding, fill=fill) + assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric'], \ + 'Padding mode should be either constant, edge, reflect or symmetric' + + if padding_mode == 'constant': + return ImageOps.expand(img, border=padding, fill=fill) + else: + if isinstance(padding, int): + pad_left = pad_right = pad_top = pad_bottom = padding + if isinstance(padding, collections.Sequence) and len(padding) == 2: + pad_left = pad_right = padding[0] + pad_top = pad_bottom = padding[1] + if isinstance(padding, collections.Sequence) and len(padding) == 4: + pad_left = padding[0] + pad_top = padding[1] + pad_right = padding[2] + pad_bottom = padding[3] + + img = np.asarray(img) + # RGB image + if len(img.shape) == 3: + img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), padding_mode) + # Grayscale image + if len(img.shape) == 2: + img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode) + + return Image.fromarray(img) def crop(img, i, j, h, w): diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 853d7b5b10d..8385b04df10 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -227,19 +227,31 @@ class Pad(object): on left/right and top/bottom respectively. If a tuple of length 4 is provided this is the padding for the left, top, right and bottom borders respectively. - fill: Pixel fill value. Default is 0. 
If a tuple of + fill: Pixel fill value for constant fill. Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively. + This value is only used when the padding_mode is constant + padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant. + constant: pads with a constant value, this value is specified with fill + edge: pads with the last value at the edge of the image + reflect: pads with reflection of image (without repeating the last value on the edge) + padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode + will result in [3, 2, 1, 2, 3, 4, 3, 2] + symmetric: pads with reflection of image (repeating the last value on the edge) + padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode + will result in [2, 1, 1, 2, 3, 4, 4, 3] """ - def __init__(self, padding, fill=0): + def __init__(self, padding, fill=0, padding_mode='constant'): assert isinstance(padding, (numbers.Number, tuple)) assert isinstance(fill, (numbers.Number, str, tuple)) + assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric'] if isinstance(padding, collections.Sequence) and len(padding) not in [2, 4]: raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " + "{} element tuple".format(len(padding))) self.padding = padding self.fill = fill + self.padding_mode = padding_mode def __call__(self, img): """ @@ -249,10 +261,11 @@ def __call__(self, img): Returns: PIL Image: Padded image. """ - return F.pad(img, self.padding, self.fill) + return F.pad(img, self.padding, self.fill, self.padding_mode) def __repr__(self): - return self.__class__.__name__ + '(padding={0}, fill={1})'.format(self.padding, self.fill) + return self.__class__.__name__ + '(padding={0}, fill={1}, padding_mode={2})'.\ + format(self.padding, self.fill, self.padding_mode) class Lambda(object):