Skip to content

Commit 1b189fa

Browse files
committed
fix
1 parent 28d8291 commit 1b189fa

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

python/paddle/nn/layer/norm.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1057,18 +1057,18 @@ def __init__(self,
10571057
self).__init__(num_features, momentum, epsilon, weight_attr,
10581058
bias_attr, data_format, None, name)
10591059

1060-
def _check_data_format(self, input):
1061-
if input == 'NCHW' or input == 'NCDHW' or input == 'NC' or input == 'NCL':
1060+
def _check_data_format(self):
1061+
if self._data_format in ['NCHW', 'NCDHW', 'NC', 'NCL']:
10621062
self._data_format = 'NCHW'
1063-
elif input == "NHWC" or input == "NDHWC" or input == 'NLC':
1063+
elif self._data_format in ["NHWC", "NDHWC", 'NLC']:
10641064
self._data_format = 'NHWC'
10651065
else:
10661066
raise ValueError(
1067-
'expected NCDHW, NDHWC, NCL, NLC, NC, NCHW, NHWC or None for data_format input'
1067+
'expected \'NCDHW\', \'NDHWC\', \'NCL\', \'NLC\', \'NC\', \'NCHW\', \'NHWC\' for data_format'
10681068
)
10691069

10701070
def forward(self, x):
1071-
self._check_data_format(self._data_format)
1071+
self._check_data_format()
10721072
# create output
10731073
# mean and mean_out share the same memory
10741074
mean_out = self._mean
@@ -1158,7 +1158,7 @@ def convert_sync_batchnorm(cls, layer):
11581158
bool) and layer._weight_attr.name != None:
11591159
layer._weight_attr.name = layer._weight_attr.name + '_sync'
11601160
if layer._bias_attr != None and not isinstance(
1161-
layer._weight_attr, bool) and layer._bias_attr.name != None:
1161+
layer._bias_attr, bool) and layer._bias_attr.name != None:
11621162
layer._bias_attr.name = layer._bias_attr.name + '_sync'
11631163

11641164
layer_output = SyncBatchNorm(layer._num_features, layer._momentum,

0 commit comments

Comments
 (0)