Skip to content

Commit 2115380

Browse files
surgan12fmassa
authored andcommitted
normalise updates (#699)
* normalise * some changes * Update functional.py * Update functional.py * code changes
1 parent 885e3c2 commit 2115380

File tree

2 files changed

+12
-8
lines changed

2 files changed

+12
-8
lines changed

torchvision/transforms/functional.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -181,11 +181,11 @@ def to_pil_image(pic, mode=None):
181181
return Image.fromarray(npimg, mode=mode)
182182

183183

184-
def normalize(tensor, mean, std):
184+
def normalize(tensor, mean, std, inplace=False):
185185
"""Normalize a tensor image with mean and standard deviation.
186186
187187
.. note::
188-
This transform acts in-place, i.e., it mutates the input tensor.
188+
This transform acts out of place by default, i.e., it does not mutates the input tensor.
189189
190190
See :class:`~torchvision.transforms.Normalize` for more details.
191191
@@ -200,9 +200,12 @@ def normalize(tensor, mean, std):
200200
if not _is_tensor_image(tensor):
201201
raise TypeError('tensor is not a torch image.')
202202

203-
# This is faster than using broadcasting, don't change without benchmarking
204-
for t, m, s in zip(tensor, mean, std):
205-
t.sub_(m).div_(s)
203+
if not inplace:
204+
tensor = tensor.clone()
205+
206+
mean = torch.tensor(mean, dtype=torch.float32)
207+
std = torch.tensor(std, dtype=torch.float32)
208+
tensor.sub_(mean[:, None, None]).div_(std[:, None, None])
206209
return tensor
207210

208211

torchvision/transforms/transforms.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,16 +136,17 @@ class Normalize(object):
136136
``input[channel] = (input[channel] - mean[channel]) / std[channel]``
137137
138138
.. note::
139-
This transform acts in-place, i.e., it mutates the input tensor.
139+
This transform acts out of place, i.e., it does not mutates the input tensor.
140140
141141
Args:
142142
mean (sequence): Sequence of means for each channel.
143143
std (sequence): Sequence of standard deviations for each channel.
144144
"""
145145

146-
def __init__(self, mean, std):
146+
def __init__(self, mean, std, inplace=False):
147147
self.mean = mean
148148
self.std = std
149+
self.inplace = inplace
149150

150151
def __call__(self, tensor):
151152
"""
@@ -155,7 +156,7 @@ def __call__(self, tensor):
155156
Returns:
156157
Tensor: Normalized Tensor image.
157158
"""
158-
return F.normalize(tensor, self.mean, self.std)
159+
return F.normalize(tensor, self.mean, self.std, self.inplace)
159160

160161
def __repr__(self):
161162
return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)

0 commit comments

Comments
 (0)