@@ -248,7 +248,7 @@ def test_convert(self):
                     isinstance(model[idx], paddle.nn.SyncBatchNorm), True)
 
 
-class TestConvertSyncBatchNormCase2(unittest.TestCase):
+class TestConvertSyncBatchNormCast1(unittest.TestCase):
     def test_convert(self):
         if not core.is_compiled_with_cuda():
             return
@@ -277,5 +277,70 @@ def forward(self, x):
         self.assertEqual(len(compare_model.sublayers()), len(model.sublayers()))
 
 
+class TestConvertSyncBatchNormCase2(unittest.TestCase):
+    def test_convert(self):
+        if not core.is_compiled_with_cuda():
+            return
+
+        with fluid.dygraph.guard(fluid.CUDAPlace(0)):
+
+            class SyBNNet(paddle.nn.Layer):
+                def __init__(self, in_ch=3, out_ch=3, dirate=1):
+                    super(SyBNNet, self).__init__()
+                    self.bn_s1 = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(
+                        paddle.nn.BatchNorm3D(
+                            out_ch,
+                            weight_attr=paddle.ParamAttr(
+                                regularizer=paddle.regularizer.L2Decay(0.))))
+                    self.bn_s2 = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(
+                        paddle.nn.BatchNorm3D(
+                            out_ch, data_format='NDHWC'))
+
+                def forward(self, x):
+                    x = self.bn_s1(x)
+                    out = paddle.sum(paddle.abs(self.bn_s2(x)))
+                    return out
+
+            class BNNet(paddle.nn.Layer):
+                def __init__(self, in_ch=3, out_ch=3, dirate=1):
+                    super(BNNet, self).__init__()
+                    self.bn_s1 = paddle.nn.BatchNorm3D(
+                        out_ch,
+                        weight_attr=paddle.ParamAttr(
+                            regularizer=paddle.regularizer.L2Decay(0.)))
+                    self.bn_s2 = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(
+                        paddle.nn.BatchNorm3D(
+                            out_ch, data_format='NDHWC'))
+
+                def forward(self, x):
+                    x = self.bn_s1(x)
+                    out = paddle.sum(paddle.abs(self.bn_s2(x)))
+                    return out
+
+            bn_model = BNNet()
+            sybn_model = SyBNNet()
+            np.random.seed(10)
+            data = np.random.random([3, 3, 3, 3, 3]).astype('float32')
+            x = paddle.to_tensor(data)
+            bn_out = bn_model(x)
+            sybn_out = sybn_model(x)
+            self.assertTrue(
+                np.allclose(bn_out.numpy(), sybn_out.numpy()),
+                "Output has diff. \n" + "\nBN " + str(bn_out.numpy()) + "\n"
+                + "Sync BN " + str(sybn_out.numpy()))
+
+
+class TestDygraphSyncBatchNormDataFormatError(unittest.TestCase):
+    def test_errors(self):
+        if not core.is_compiled_with_cuda():
+            return
+
+        with fluid.dygraph.guard(fluid.CUDAPlace(0)):
+            my_sync_batch_norm = paddle.nn.SyncBatchNorm(10, data_format='CN')
+            data = np.random.random([3, 3, 3]).astype('float32')
+            x = paddle.to_tensor(data)
+            self.assertRaises(ValueError, my_sync_batch_norm, x)
+
+
 if __name__ == '__main__':
     unittest.main()
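
For reference, a minimal sketch of the API these tests exercise, outside the diff itself (the toy Sequential model and its layer sizes below are arbitrary assumptions, not taken from the PR): paddle.nn.SyncBatchNorm.convert_sync_batchnorm walks a layer recursively and replaces each BatchNorm*D sublayer with an equivalent SyncBatchNorm, carrying over the affine parameters and running statistics, so single-device outputs should match the plain model, which is exactly what TestConvertSyncBatchNormCase2 asserts above.

import paddle

# A toy model containing a plain BatchNorm layer (sizes are arbitrary).
model = paddle.nn.Sequential(
    paddle.nn.Conv2D(3, 5, 3),
    paddle.nn.BatchNorm2D(5))

# Recursively swap every BatchNorm*D sublayer for SyncBatchNorm so that
# batch statistics are synchronized across devices in multi-GPU training.
model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model)

print(type(model[1]))  # the BatchNorm2D sublayer is now a SyncBatchNorm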