Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 32 additions & 4 deletions torchvision/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def scale(*args, **kwargs):
return resize(*args, **kwargs)


def pad(img, padding, fill=0):
def pad(img, padding, fill=0, padding_mode='constant'):
"""Pad the given PIL Image on all sides with the given "pad" value.

Args:
Expand All @@ -220,9 +220,10 @@ def pad(img, padding, fill=0):
on left/right and top/bottom respectively. If a tuple of length 4 is provided
this is the padding for the left, top, right and bottom borders
respectively.
fill: Pixel fill value. Default is 0. If a tuple of
fill: Pixel fill value for constant fill. Default is 0. If a tuple of
length 3, it is used to fill R, G, B channels respectively.

padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric.

This comment was marked as off-topic.

This comment was marked as off-topic.

default value is constant fill.
Returns:
PIL Image: Padded image.
"""
Expand All @@ -233,12 +234,39 @@ def pad(img, padding, fill=0):
raise TypeError('Got inappropriate padding arg')
if not isinstance(fill, (numbers.Number, str, tuple)):
raise TypeError('Got inappropriate fill arg')
if not isinstance(padding_mode, str):
raise TypeError('Got inappropriate padding_mode arg')

if isinstance(padding, collections.Sequence) and len(padding) not in [2, 4]:
raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
"{} element tuple".format(len(padding)))

return ImageOps.expand(img, border=padding, fill=fill)
assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric'], \
'Padding mode should be either constant, edge, reflect or symmetric'

if padding_mode == 'constant':
return ImageOps.expand(img, border=padding, fill=fill)
else:
if isinstance(padding, int):
pad_left = pad_right = pad_top = pad_bottom = padding
if isinstance(padding, collections.Sequence) and len(padding) == 2:
pad_left = pad_right = padding[0]
pad_top = pad_bottom = padding[1]
if isinstance(padding, collections.Sequence) and len(padding) == 4:
pad_left = padding[0]
pad_top = padding[1]
pad_right = padding[2]
pad_bottom = padding[3]

img = np.asarray(img)
# RGB image
if len(img.shape) == 3:
img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), padding_mode)
# Grayscale image
if len(img.shape) == 2:
img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode)

return Image.fromarray(img)


def crop(img, i, j, h, w):
Expand Down
9 changes: 6 additions & 3 deletions torchvision/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,15 +231,17 @@ class Pad(object):
length 3, it is used to fill R, G, B channels respectively.

This comment was marked as off-topic.

This comment was marked as off-topic.

"""

def __init__(self, padding, fill=0):
def __init__(self, padding, fill=0, padding_mode='constant'):
assert isinstance(padding, (numbers.Number, tuple))
assert isinstance(fill, (numbers.Number, str, tuple))
assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric']
if isinstance(padding, collections.Sequence) and len(padding) not in [2, 4]:
raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
"{} element tuple".format(len(padding)))

self.padding = padding
self.fill = fill
self.padding_mode = padding_mode

def __call__(self, img):
"""
Expand All @@ -249,10 +251,11 @@ def __call__(self, img):
Returns:
PIL Image: Padded image.
"""
return F.pad(img, self.padding, self.fill)
return F.pad(img, self.padding, self.fill, self.padding_mode)

def __repr__(self):
return self.__class__.__name__ + '(padding={0}, fill={1})'.format(self.padding, self.fill)
return self.__class__.__name__ + '(padding={0}, fill={1}, padding_mode={2})'.\
format(self.padding, self.fill, self.padding_mode)


class Lambda(object):
Expand Down