Skip to content
Merged
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 25 additions & 13 deletions torchvision/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,8 @@ def to_tensor(pic):
img = 255 * torch.from_numpy(np.array(pic, np.uint8, copy=False))
else:
img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
# PIL image mode: L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK
if pic.mode == 'YCbCr':
nchannel = 3
elif pic.mode == 'I;16':
nchannel = 1
else:
nchannel = len(pic.mode)
img = img.view(pic.size[1], pic.size[0], nchannel)

img = img.view(pic.size[1], pic.size[0], len(pic.getbands()))
# put it from HWC to CHW format
# yikes, this transpose takes 80% of the loading time/CPU
img = img.transpose(0, 1).transpose(0, 2).contiguous()
Expand Down Expand Up @@ -696,7 +690,7 @@ def adjust_gamma(img, gamma, gain=1):
return img


def rotate(img, angle, resample=False, expand=False, center=None, fill=0):
def rotate(img, angle, resample=False, expand=False, center=None, fill=None):
"""Rotate the image by angle.


Expand All @@ -713,18 +707,36 @@ def rotate(img, angle, resample=False, expand=False, center=None, fill=0):
center (2-tuple, optional): Optional center of rotation.
Origin is the upper left corner.
Default is the center of the image.
fill (3-tuple or int): RGB pixel fill value for area outside the rotated image.
If int, it is used for all channels respectively.
fill (n-tuple or int or float): Pixel fill value for area outside the rotated
image. If int or float, the value is used for all bands respectively.
Defaults to 0 for all bands. This option is only available for ``pillow>=5.2.0``.

.. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters

"""
def verify_fill(fill, num_bands):
if PILLOW_VERSION < "5.2.0":
msg = ("The option to fill background area of the rotated image, "
"requires pillow>=5.2.0")
raise RuntimeError(msg)

if fill is None:
fill = 0

if isinstance(fill, (int, float)):
return tuple([fill] * num_bands)
else:
if len(fill) == num_bands:
return fill

msg = ("The number of elements in 'fill' does not match the number of "
"bands of the image ({} != {})")
raise ValueError(msg.format(len(fill), num_bands))

if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

if isinstance(fill, int):
fill = tuple([fill] * 3)
fill = verify_fill(fill, len(img.getbands()))

return img.rotate(angle, resample, expand, center, fillcolor=fill)

Expand Down