
Commit 17c6d39

Fix syncbn (#32989)
* fix syncbn
1 parent b751a80 commit 17c6d39

File tree

2 files changed: +82, -5 lines changed


python/paddle/fluid/tests/unittests/test_sync_batch_norm_op.py

Lines changed: 66 additions & 1 deletion

@@ -248,7 +248,7 @@ def test_convert(self):
                 isinstance(model[idx], paddle.nn.SyncBatchNorm), True)


-class TestConvertSyncBatchNormCase2(unittest.TestCase):
+class TestConvertSyncBatchNormCast1(unittest.TestCase):
     def test_convert(self):
         if not core.is_compiled_with_cuda():
             return
@@ -277,5 +277,70 @@ def forward(self, x):
         self.assertEqual(len(compare_model.sublayers()), len(model.sublayers()))


+class TestConvertSyncBatchNormCase2(unittest.TestCase):
+    def test_convert(self):
+        if not core.is_compiled_with_cuda():
+            return
+
+        with fluid.dygraph.guard(fluid.CUDAPlace(0)):
+
+            class SyBNNet(paddle.nn.Layer):
+                def __init__(self, in_ch=3, out_ch=3, dirate=1):
+                    super(SyBNNet, self).__init__()
+                    self.bn_s1 = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(
+                        paddle.nn.BatchNorm3D(
+                            out_ch,
+                            weight_attr=paddle.ParamAttr(
+                                regularizer=paddle.regularizer.L2Decay(0.))))
+                    self.bn_s2 = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(
+                        paddle.nn.BatchNorm3D(
+                            out_ch, data_format='NDHWC'))
+
+                def forward(self, x):
+                    x = self.bn_s1(x)
+                    out = paddle.sum(paddle.abs(self.bn_s2(x)))
+                    return out
+
+            class BNNet(paddle.nn.Layer):
+                def __init__(self, in_ch=3, out_ch=3, dirate=1):
+                    super(BNNet, self).__init__()
+                    self.bn_s1 = paddle.nn.BatchNorm3D(
+                        out_ch,
+                        weight_attr=paddle.ParamAttr(
+                            regularizer=paddle.regularizer.L2Decay(0.)))
+                    self.bn_s2 = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(
+                        paddle.nn.BatchNorm3D(
+                            out_ch, data_format='NDHWC'))
+
+                def forward(self, x):
+                    x = self.bn_s1(x)
+                    out = paddle.sum(paddle.abs(self.bn_s2(x)))
+                    return out
+
+            bn_model = BNNet()
+            sybn_model = SyBNNet()
+            np.random.seed(10)
+            data = np.random.random([3, 3, 3, 3, 3]).astype('float32')
+            x = paddle.to_tensor(data)
+            bn_out = bn_model(x)
+            sybn_out = sybn_model(x)
+            self.assertTrue(
+                np.allclose(bn_out.numpy(), sybn_out.numpy()),
+                "Output has diff. \n" + "\nBN " + str(bn_out.numpy()) + "\n"
+                + "Sync BN " + str(sybn_out.numpy()))
+
+
+class TestDygraphSyncBatchNormDataFormatError(unittest.TestCase):
+    def test_errors(self):
+        if not core.is_compiled_with_cuda():
+            return
+
+        with fluid.dygraph.guard(fluid.CUDAPlace(0)):
+            my_sync_batch_norm = paddle.nn.SyncBatchNorm(10, data_format='CN')
+            data = np.random.random([3, 3, 3]).astype('float32')
+            x = paddle.to_tensor(data)
+            self.assertRaises(ValueError, my_sync_batch_norm, x)
+
+
 if __name__ == '__main__':
     unittest.main()
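For orientation, a minimal sketch of the conversion path these new tests exercise, assuming a CUDA build of Paddle (the tests themselves return early when core.is_compiled_with_cuda() is false). The Sequential model and tensor shape below are illustrative and not taken from the commit.

    # Illustrative sketch: convert an ordinary model to SyncBatchNorm, in the
    # spirit of TestConvertSyncBatchNormCase2's BNNet / SyBNNet pair.
    import numpy as np
    import paddle

    # A model containing plain batch norm layers.
    model = paddle.nn.Sequential(
        paddle.nn.Conv3D(3, 3, kernel_size=1),
        paddle.nn.BatchNorm3D(3))

    # Replace every BatchNorm* sublayer with SyncBatchNorm in place.
    sync_model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    x = paddle.to_tensor(np.random.random([3, 3, 3, 3, 3]).astype('float32'))
    out = sync_model(x)  # on a single device this matches the unconverted model;
                         # statistics are synchronized only under data-parallel training

The new Case2 test additionally covers BatchNorm3D layers created with data_format='NDHWC' and with a ParamAttr that sets only a regularizer (so its name is None), which is the combination the norm.py changes below address.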

python/paddle/nn/layer/norm.py

Lines changed: 16 additions & 4 deletions

@@ -1057,7 +1057,18 @@ def __init__(self,
                           self).__init__(num_features, momentum, epsilon, weight_attr,
                                          bias_attr, data_format, None, name)

+    def _check_data_format(self):
+        if self._data_format in ['NCHW', 'NCDHW', 'NC', 'NCL']:
+            self._data_format = 'NCHW'
+        elif self._data_format in ["NHWC", "NDHWC", 'NLC']:
+            self._data_format = 'NHWC'
+        else:
+            raise ValueError(
+                'expected \'NCDHW\', \'NDHWC\', \'NCL\', \'NLC\', \'NC\', \'NCHW\', \'NHWC\' for data_format'
+            )
+
     def forward(self, x):
+        self._check_data_format()
         # create output
         # mean and mean_out share the same memory
         mean_out = self._mean
@@ -1142,11 +1153,12 @@ def convert_sync_batchnorm(cls, layer):
         """
         layer_output = layer
         if isinstance(layer, _BatchNormBase):
-            if layer._weight_attr != None and not isinstance(layer._weight_attr,
-                                                              bool):
+            if layer._weight_attr != None and not isinstance(
+                    layer._weight_attr,
+                    bool) and layer._weight_attr.name != None:
                 layer._weight_attr.name = layer._weight_attr.name + '_sync'
-            if layer._bias_attr != None and not isinstance(layer._weight_attr,
-                                                            bool):
+            if layer._bias_attr != None and not isinstance(
+                    layer._bias_attr, bool) and layer._bias_attr.name != None:
                 layer._bias_attr.name = layer._bias_attr.name + '_sync'

             layer_output = SyncBatchNorm(layer._num_features, layer._momentum,
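Taken together, the norm.py changes do two things: SyncBatchNorm.forward now normalizes and validates data_format up front, and convert_sync_batchnorm appends the '_sync' suffix only when a ParamAttr actually has a name set (and checks _bias_attr rather than _weight_attr in the bias branch). Below is a standalone sketch of the layout normalization, written as a free function purely for illustration; the real method mutates self._data_format on the layer.

    # Free-function version of the new _check_data_format logic (illustrative).
    def normalize_data_format(data_format):
        if data_format in ['NCHW', 'NCDHW', 'NC', 'NCL']:
            return 'NCHW'   # channels-first layouts collapse to NCHW
        elif data_format in ['NHWC', 'NDHWC', 'NLC']:
            return 'NHWC'   # channels-last layouts collapse to NHWC
        raise ValueError(
            "expected 'NCDHW', 'NDHWC', 'NCL', 'NLC', 'NC', 'NCHW', 'NHWC' "
            "for data_format")

    print(normalize_data_format('NDHWC'))   # -> 'NHWC'
    # normalize_data_format('CN') raises ValueError, which is what
    # TestDygraphSyncBatchNormDataFormatError asserts via assertRaises.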
