@@ -1057,18 +1057,18 @@ def __init__(self,
10571057 self ).__init__ (num_features , momentum , epsilon , weight_attr ,
10581058 bias_attr , data_format , None , name )
10591059
1060- def _check_data_format (self , input ):
1061- if input == 'NCHW' or input == 'NCDHW' or input == 'NC' or input == 'NCL' :
1060+ def _check_data_format (self ):
1061+ if self . _data_format in [ 'NCHW' , 'NCDHW' , 'NC' , 'NCL' ] :
10621062 self ._data_format = 'NCHW'
1063- elif input == "NHWC" or input == "NDHWC" or input == 'NLC' :
1063+ elif self . _data_format in [ "NHWC" , "NDHWC" , 'NLC' ] :
10641064 self ._data_format = 'NHWC'
10651065 else :
10661066 raise ValueError (
1067- 'expected NCDHW, NDHWC, NCL, NLC, NC, NCHW, NHWC or None for data_format input '
1067+ 'expected \' NCDHW\' , \' NDHWC\' , \' NCL\' , \' NLC\' , \' NC \' , \' NCHW\' , \' NHWC\' for data_format'
10681068 )
10691069
10701070 def forward (self , x ):
1071- self ._check_data_format (self . _data_format )
1071+ self ._check_data_format ()
10721072 # create output
10731073 # mean and mean_out share the same memory
10741074 mean_out = self ._mean
@@ -1158,7 +1158,7 @@ def convert_sync_batchnorm(cls, layer):
11581158 bool ) and layer ._weight_attr .name != None :
11591159 layer ._weight_attr .name = layer ._weight_attr .name + '_sync'
11601160 if layer ._bias_attr != None and not isinstance (
1161- layer ._weight_attr , bool ) and layer ._bias_attr .name != None :
1161+ layer ._bias_attr , bool ) and layer ._bias_attr .name != None :
11621162 layer ._bias_attr .name = layer ._bias_attr .name + '_sync'
11631163
11641164 layer_output = SyncBatchNorm (layer ._num_features , layer ._momentum ,
0 commit comments