Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions torchvision/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,11 +181,11 @@ def to_pil_image(pic, mode=None):
return Image.fromarray(npimg, mode=mode)


def normalize(tensor, mean, std):
def normalize(tensor, mean, std, inplace=False):
"""Normalize a tensor image with mean and standard deviation.

.. note::
This transform acts in-place, i.e., it mutates the input tensor.
This transform acts out of place by default, i.e., it does not mutates the input tensor.

See :class:`~torchvision.transforms.Normalize` for more details.

Expand All @@ -200,9 +200,12 @@ def normalize(tensor, mean, std):
if not _is_tensor_image(tensor):
raise TypeError('tensor is not a torch image.')

# This is faster than using broadcasting, don't change without benchmarking
for t, m, s in zip(tensor, mean, std):
t.sub_(m).div_(s)
if not inplace:
tensor = tensor.clone()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we just have a single clone here, to avoid duplicated code?
like

if not inplace:
    tensor = tensor.clone()

mean = ...


mean = torch.tensor(mean, dtype=torch.float32)
std = torch.tensor(std, dtype=torch.float32)
tensor.sub_(mean[:, None, None]).div_(std[:, None, None])
return tensor


Expand Down
7 changes: 4 additions & 3 deletions torchvision/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,16 +136,17 @@ class Normalize(object):
``input[channel] = (input[channel] - mean[channel]) / std[channel]``

.. note::
This transform acts in-place, i.e., it mutates the input tensor.
This transform acts out of place, i.e., it does not mutates the input tensor.

Args:
mean (sequence): Sequence of means for each channel.
std (sequence): Sequence of standard deviations for each channel.
"""

def __init__(self, mean, std):
def __init__(self, mean, std, inplace=False):
self.mean = mean
self.std = std
self.inplace = inplace

def __call__(self, tensor):
"""
Expand All @@ -155,7 +156,7 @@ def __call__(self, tensor):
Returns:
Tensor: Normalized Tensor image.
"""
return F.normalize(tensor, self.mean, self.std)
return F.normalize(tensor, self.mean, self.std, self.inplace)

def __repr__(self):
return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
Expand Down