python/paddle/fluid/tests/unittests/test_sync_batch_norm_op.py (66 additions, 1 deletion)

@@ -248,7 +248,7 @@ def test_convert(self):
                 isinstance(model[idx], paddle.nn.SyncBatchNorm), True)
 
 
-class TestConvertSyncBatchNormCase2(unittest.TestCase):
+class TestConvertSyncBatchNormCast1(unittest.TestCase):
     def test_convert(self):
         if not core.is_compiled_with_cuda():
             return
@@ -277,5 +277,70 @@ def forward(self, x):
         self.assertEqual(len(compare_model.sublayers()), len(model.sublayers()))
 
 
+class TestConvertSyncBatchNormCase2(unittest.TestCase):
+    def test_convert(self):
+        if not core.is_compiled_with_cuda():
+            return
+
+        with fluid.dygraph.guard(fluid.CUDAPlace(0)):
+
+            class SyBNNet(paddle.nn.Layer):
+                def __init__(self, in_ch=3, out_ch=3, dirate=1):
+                    super(SyBNNet, self).__init__()
+                    self.bn_s1 = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(
+                        paddle.nn.BatchNorm3D(
+                            out_ch,
+                            weight_attr=paddle.ParamAttr(
+                                regularizer=paddle.regularizer.L2Decay(0.))))
+                    self.bn_s2 = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(
+                        paddle.nn.BatchNorm3D(
+                            out_ch, data_format='NDHWC'))
+
+                def forward(self, x):
+                    x = self.bn_s1(x)
+                    out = paddle.sum(paddle.abs(self.bn_s2(x)))
+                    return out
+
+            class BNNet(paddle.nn.Layer):
+                def __init__(self, in_ch=3, out_ch=3, dirate=1):
+                    super(BNNet, self).__init__()
+                    self.bn_s1 = paddle.nn.BatchNorm3D(
+                        out_ch,
+                        weight_attr=paddle.ParamAttr(
+                            regularizer=paddle.regularizer.L2Decay(0.)))
+                    self.bn_s2 = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(
+                        paddle.nn.BatchNorm3D(
+                            out_ch, data_format='NDHWC'))
+
+                def forward(self, x):
+                    x = self.bn_s1(x)
+                    out = paddle.sum(paddle.abs(self.bn_s2(x)))
+                    return out
+
+            bn_model = BNNet()
+            sybn_model = SyBNNet()
+            np.random.seed(10)
+            data = np.random.random([3, 3, 3, 3, 3]).astype('float32')
+            x = paddle.to_tensor(data)
+            bn_out = bn_model(x)
+            sybn_out = sybn_model(x)
+            self.assertTrue(
+                np.allclose(bn_out.numpy(), sybn_out.numpy()),
+                "Output has diff. \n" + "\nBN " + str(bn_out.numpy()) + "\n"
+                + "Sync BN " + str(sybn_out.numpy()))
+
+
+class TestDygraphSyncBatchNormDataFormatError(unittest.TestCase):
+    def test_errors(self):
+        if not core.is_compiled_with_cuda():
+            return
+
+        with fluid.dygraph.guard(fluid.CUDAPlace(0)):
+            my_sync_batch_norm = paddle.nn.SyncBatchNorm(10, data_format='CN')
+            data = np.random.random([3, 3, 3]).astype('float32')
+            x = paddle.to_tensor(data)
+            self.assertRaises(ValueError, my_sync_batch_norm, x)
+
+
 if __name__ == '__main__':
     unittest.main()
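For context before the norm.py hunk: the ValueError exercised by TestDygraphSyncBatchNormDataFormatError above comes from the layout normalization this PR adds. A minimal standalone sketch of that mapping (the helper name normalize_data_format is hypothetical; the real logic lives in the new SyncBatchNorm._check_data_format shown below):

# Hypothetical standalone mirror of the new _check_data_format logic:
# channel-first aliases collapse to 'NCHW', channel-last aliases to 'NHWC',
# and anything else (e.g. the 'CN' used in the error test above) raises.
def normalize_data_format(data_format):
    if data_format in ('NCHW', 'NCDHW', 'NC', 'NCL'):
        return 'NCHW'
    if data_format in ('NHWC', 'NDHWC', 'NLC'):
        return 'NHWC'
    raise ValueError(
        "expected 'NCDHW', 'NDHWC', 'NCL', 'NLC', 'NC', 'NCHW', "
        "'NHWC' for data_format")

assert normalize_data_format('NDHWC') == 'NHWC'
assert normalize_data_format('NCDHW') == 'NCHW'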
python/paddle/nn/layer/norm.py (16 additions, 4 deletions)

@@ -1057,7 +1057,18 @@ def __init__(self,
               self).__init__(num_features, momentum, epsilon, weight_attr,
                              bias_attr, data_format, None, name)
 
+    def _check_data_format(self):
+        if self._data_format in ['NCHW', 'NCDHW', 'NC', 'NCL']:
+            self._data_format = 'NCHW'
+        elif self._data_format in ["NHWC", "NDHWC", 'NLC']:
+            self._data_format = 'NHWC'
+        else:
+            raise ValueError(
+                'expected \'NCDHW\', \'NDHWC\', \'NCL\', \'NLC\', \'NC\', \'NCHW\', \'NHWC\' for data_format'
+            )
+
     def forward(self, x):
+        self._check_data_format()
         # create output
         # mean and mean_out share the same memory
         mean_out = self._mean
@@ -1142,11 +1153,12 @@ def convert_sync_batchnorm(cls, layer):
         """
         layer_output = layer
         if isinstance(layer, _BatchNormBase):
-            if layer._weight_attr != None and not isinstance(layer._weight_attr,
-                                                             bool):
+            if layer._weight_attr != None and not isinstance(
+                    layer._weight_attr,
+                    bool) and layer._weight_attr.name != None:
                 layer._weight_attr.name = layer._weight_attr.name + '_sync'
-            if layer._bias_attr != None and not isinstance(layer._weight_attr,
-                                                           bool):
+            if layer._bias_attr != None and not isinstance(
+                    layer._bias_attr, bool) and layer._bias_attr.name != None:
                 layer._bias_attr.name = layer._bias_attr.name + '_sync'
 
             layer_output = SyncBatchNorm(layer._num_features, layer._momentum,
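Taken together, the convert_sync_batchnorm change above means the '_sync' name suffix is only appended when a weight or bias ParamAttr actually carries a name; before this PR an unnamed attr hit `None + '_sync'` and raised TypeError. A minimal usage sketch, assuming a CUDA-enabled Paddle build (mirroring the guard in the tests above):

import paddle

# Named attr: its parameter name gains the '_sync' suffix on conversion.
named = paddle.nn.BatchNorm3D(3, weight_attr=paddle.ParamAttr(name='bn_w'))
# Unnamed attr: previously raised TypeError; now the rename is skipped.
unnamed = paddle.nn.BatchNorm3D(
    3,
    weight_attr=paddle.ParamAttr(
        regularizer=paddle.regularizer.L2Decay(0.)))

for layer in (named, unnamed):
    sync_layer = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(layer)
    assert isinstance(sync_layer, paddle.nn.SyncBatchNorm)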