From 10994084333d9ed1228f39594ac2a9cfe929740f Mon Sep 17 00:00:00 2001
From: Avin0323
Date: Sun, 25 Apr 2021 10:20:51 +0000
Subject: [PATCH 1/7] relu supports bfloat16 data type, test=develop

---
 paddle/fluid/operators/activation_op.cu      | 24 ++++++++++-
 .../tests/unittests/test_activation_op.py    | 43 +++++++++++++++++++
 2 files changed, 65 insertions(+), 2 deletions(-)

diff --git a/paddle/fluid/operators/activation_op.cu b/paddle/fluid/operators/activation_op.cu
index 04f329088fafe8..9b934b414eeb65 100644
--- a/paddle/fluid/operators/activation_op.cu
+++ b/paddle/fluid/operators/activation_op.cu
@@ -11,6 +11,7 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/activation_op.h"
 #include "paddle/fluid/operators/math/math_cuda_utils.h"
+#include "paddle/fluid/platform/bfloat16.h"
 #include "paddle/fluid/platform/cuda_device_function.h"
 #include "paddle/fluid/platform/float16.h"
 
@@ -456,7 +457,24 @@ REGISTER_OP_CUDA_KERNEL(
 /* ========================================================================== */
 
 /* ===========================    relu register  ============================ */
-REGISTER_ACTIVATION_GPU_KERNEL(relu, Relu, ReluGPUFunctor, ReluGradGPUFunctor);
+REGISTER_OP_CUDA_KERNEL(
+    relu, ops::ActivationGPUKernel<plat::CUDADeviceContext,
+                                   ops::ReluGPUFunctor<float>>,
+    ops::ActivationGPUKernel<plat::CUDADeviceContext,
+                             ops::ReluGPUFunctor<double>>,
+    ops::ActivationGPUKernel<plat::CUDADeviceContext,
+                             ops::ReluGPUFunctor<plat::float16>>,
+    ops::ActivationGPUKernel<plat::CUDADeviceContext,
+                             ops::ReluGPUFunctor<plat::bfloat16>>);
+REGISTER_OP_CUDA_KERNEL(
+    relu_grad, ops::ActivationGradGPUKernel<plat::CUDADeviceContext,
+                                            ops::ReluGradGPUFunctor<float>>,
+    ops::ActivationGradGPUKernel<plat::CUDADeviceContext,
+                                 ops::ReluGradGPUFunctor<double>>,
+    ops::ActivationGradGPUKernel<plat::CUDADeviceContext,
+                                 ops::ReluGradGPUFunctor<plat::float16>>,
+    ops::ActivationGradGPUKernel<plat::CUDADeviceContext,
+                                 ops::ReluGradGPUFunctor<plat::bfloat16>>);
 
 REGISTER_OP_CUDA_KERNEL(
     relu_grad_grad,
@@ -465,7 +483,9 @@ REGISTER_OP_CUDA_KERNEL(
     ops::ActivationDoubleGradKernel<plat::CUDADeviceContext,
                                     ops::ReluGradGradFunctor<double>>,
     ops::ActivationDoubleGradKernel<plat::CUDADeviceContext,
-                                    ops::ReluGradGradFunctor<plat::float16>>);
+                                    ops::ReluGradGradFunctor<plat::float16>>,
+    ops::ActivationDoubleGradKernel<plat::CUDADeviceContext,
+                                    ops::ReluGradGradFunctor<plat::bfloat16>>);
 /* ========================================================================== */
 
 /* ===========================   sqrt register  ============================= */
diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py
index ea183e9444878d..53bf04ef407b46 100755
--- a/python/paddle/fluid/tests/unittests/test_activation_op.py
+++ b/python/paddle/fluid/tests/unittests/test_activation_op.py
@@ -1051,6 +1051,24 @@ def test_check_grad(self):
         self.check_grad(['X'], 'Out')
 
 
+class TestReluBF16(TestActivation):
+    def setUp(self):
+        self.op_type = "relu"
+        self.dtype = core.VarDesc.VarType.BF16
+
+        np.random.seed(1024)
+        x = np.random.uniform(-1, 1, [11, 17]).astype(np.float32)
+        # The same reason with TestAbs
+        x[np.abs(x) < 0.005] = 0.02
+        out = np.maximum(x, 0)
+
+        self.inputs = {'X': x}
+        self.outputs = {'Out': out}
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out')
+
+
 class TestReluAPI(unittest.TestCase):
     # test paddle.nn.ReLU, paddle.nn.functional.relu
     def setUp(self):
@@ -2672,5 +2690,30 @@ def test_check_grad(self):
 create_test_act_fp16_class(TestSwish)
 create_test_act_fp16_class(TestHardSwish)
 
+
+#------------------ Test BF16 ----------------------
+def create_test_act_bf16_class(parent,
+                               atol=1e-2,
+                               grad_check=True,
+                               grad_atol=0.80):
+    @unittest.skipIf(not paddle.is_compiled_with_cuda(),
+                     "core is not compiled with CUDA")
+    class TestActBF16(parent):
+        def test_check_output(self):
+            place = core.CUDAPlace(0)
+            self.check_output_with_place(place, atol=atol)
+
+        def test_check_grad(self):
+            place = core.CUDAPlace(0)
+            self.check_grad_with_place(
+                place, ['X'], 'Out', max_relative_error=grad_atol)
+
+    cls_name = "{0}_{1}".format(parent.__name__, "bf16")
+    TestActBF16.__name__ = cls_name
+    globals()[cls_name] = 
TestActBF16 + + +create_test_act_bf16_class(TestReluBF16) + if __name__ == "__main__": unittest.main() From 1cbdb1b4ab054d1d331259cf18ace236ba38f69f Mon Sep 17 00:00:00 2001 From: Avin0323 Date: Sun, 25 Apr 2021 17:12:32 +0000 Subject: [PATCH 2/7] fix relu unittest with bfloat16 data type, test=develop --- paddle/fluid/operators/fill_constant_op.cu.cc | 1 + paddle/fluid/operators/mean_op.cu | 6 +++-- .../paddle/fluid/tests/unittests/op_test.py | 23 ++++++++++++++++--- .../tests/unittests/test_activation_op.py | 23 ++++--------------- 4 files changed, 29 insertions(+), 24 deletions(-) diff --git a/paddle/fluid/operators/fill_constant_op.cu.cc b/paddle/fluid/operators/fill_constant_op.cu.cc index e784c20b8b8b4f..e1b3fb14d794b1 100644 --- a/paddle/fluid/operators/fill_constant_op.cu.cc +++ b/paddle/fluid/operators/fill_constant_op.cu.cc @@ -22,5 +22,6 @@ REGISTER_OP_CUDA_KERNEL(fill_constant, ops::FillConstantKernel, ops::FillConstantKernel, ops::FillConstantKernel, ops::FillConstantKernel, + ops::FillConstantKernel, ops::FillConstantKernel, ops::FillConstantKernel); diff --git a/paddle/fluid/operators/mean_op.cu b/paddle/fluid/operators/mean_op.cu index 430036bc67de70..a0504009fd41b0 100644 --- a/paddle/fluid/operators/mean_op.cu +++ b/paddle/fluid/operators/mean_op.cu @@ -107,10 +107,12 @@ namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( mean, ops::MeanCUDAKernel, ops::MeanCUDAKernel, - ops::MeanCUDAKernel); + ops::MeanCUDAKernel, + ops::MeanCUDAKernel); REGISTER_OP_CUDA_KERNEL( mean_grad, ops::MeanCUDAGradKernel, ops::MeanCUDAGradKernel, + ops::MeanCUDAGradKernel, ops::MeanCUDAGradKernel); + plat::bfloat16>); diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index 25717b79677128..48f9670e9db56d 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -132,6 +132,8 @@ def product(dim): tensor_to_check_dtype = np.float16 # set delta as np.float16, will automatic convert to float32, float64 delta = np.array(delta).astype(np.float16) + elif tensor_to_check_dtype == core.VarDesc.VarType.BF16: + tensor_to_check_dtype = np.float32 else: raise ValueError("Not supported data type " + str( tensor_to_check_dtype)) @@ -140,9 +142,10 @@ def get_output(): sum = [] op.run(scope, place) for output_name in output_names: - sum.append( - np.array(scope.find_var(output_name).get_tensor()).astype( - tensor_to_check_dtype).mean()) + output_numpy = np.array(scope.find_var(output_name).get_tensor()) + if tensor_to_check._dtype() == core.VarDesc.VarType.BF16: + output_numpy = convert_uint16_to_float(output_numpy) + sum.append(output_numpy.astype(tensor_to_check_dtype).mean()) return tensor_to_check_dtype(np.array(sum).sum() / len(output_names)) gradient_flat = np.zeros(shape=(tensor_size, ), dtype=tensor_to_check_dtype) @@ -152,6 +155,11 @@ def __get_elem__(tensor, i): numpy_tensor = np.array(tensor).astype(np.float16) numpy_tensor = numpy_tensor.flatten() return numpy_tensor[i] + elif tensor_to_check._dtype() == core.VarDesc.VarType.BF16: + numpy_tensor = np.array(tensor).astype(np.uint16) + numpy_tensor = numpy_tensor.flatten() + return struct.unpack(' 1e-10, abs_a <= 1e-8)] *= 1e4 abs_a[np.logical_and(abs_a > 1e-8, abs_a <= 1e-6)] *= 1e2 + elif self.is_bfloat16_op(): + abs_a[abs_a < 1e-2] = 1 else: abs_a[abs_a < 1e-3] = 1 diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py index 
20c2bd7afce04e..f6c6defe79b684 100755 --- a/python/paddle/fluid/tests/unittests/test_activation_op.py +++ b/python/paddle/fluid/tests/unittests/test_activation_op.py @@ -1117,24 +1117,6 @@ def test_check_grad(self): self.check_grad(['X'], 'Out') -class TestReluBF16(TestActivation): - def setUp(self): - self.op_type = "relu" - self.dtype = core.VarDesc.VarType.BF16 - - np.random.seed(1024) - x = np.random.uniform(-1, 1, [11, 17]).astype(np.float32) - # The same reason with TestAbs - x[np.abs(x) < 0.005] = 0.02 - out = np.maximum(x, 0) - - self.inputs = {'X': x} - self.outputs = {'Out': out} - - def test_check_grad(self): - self.check_grad(['X'], 'Out') - - class TestReluAPI(unittest.TestCase): # test paddle.nn.ReLU, paddle.nn.functional.relu def setUp(self): @@ -2766,6 +2748,9 @@ def create_test_act_bf16_class(parent, @unittest.skipIf(not paddle.is_compiled_with_cuda(), "core is not compiled with CUDA") class TestActBF16(parent): + def init_dtype(self): + self.dtype = np.uint16 + def test_check_output(self): place = core.CUDAPlace(0) self.check_output_with_place(place, atol=atol) @@ -2780,7 +2765,7 @@ def test_check_grad(self): globals()[cls_name] = TestActBF16 -create_test_act_bf16_class(TestReluBF16) +create_test_act_bf16_class(TestRelu) if __name__ == "__main__": unittest.main() From 7a26ab6f88a94111c086a960e7cbcb90189d7aa4 Mon Sep 17 00:00:00 2001 From: Avin0323 Date: Mon, 26 Apr 2021 03:11:50 +0000 Subject: [PATCH 3/7] fix compilation on ROCM, test=develop --- paddle/fluid/operators/activation_op.cu | 4 ++++ python/paddle/fluid/tests/unittests/op_test.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/activation_op.cu b/paddle/fluid/operators/activation_op.cu index 9b5548d9173e98..56707ae4ecd7a1 100644 --- a/paddle/fluid/operators/activation_op.cu +++ b/paddle/fluid/operators/activation_op.cu @@ -457,6 +457,9 @@ REGISTER_OP_CUDA_KERNEL( /* ========================================================================== */ /* =========================== relu register ============================ */ +#ifdef PADDLE_WITH_HIP +REGISTER_ACTIVATION_GPU_KERNEL(relu, Relu, ReluGPUFunctor, ReluGradGPUFunctor); +#else REGISTER_OP_CUDA_KERNEL( relu, ops::ActivationGPUKernel>, @@ -475,6 +478,7 @@ REGISTER_OP_CUDA_KERNEL( ops::ReluGradGPUFunctor>, ops::ActivationGradGPUKernel>); +#endif REGISTER_OP_CUDA_KERNEL( relu_grad_grad, diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index 48f9670e9db56d..dbed41c951e917 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -1363,7 +1363,7 @@ def _assert_is_close(self, numeric_grads, analytic_grads, names, abs_a[np.logical_and(abs_a > 1e-10, abs_a <= 1e-8)] *= 1e4 abs_a[np.logical_and(abs_a > 1e-8, abs_a <= 1e-6)] *= 1e2 elif self.is_bfloat16_op(): - abs_a[abs_a < 1e-2] = 1 + abs_a[abs_a < 3e-3] = 1 else: abs_a[abs_a < 1e-3] = 1 From acf89ba371f2c5d935c2ba2d545bf3e845cb61f0 Mon Sep 17 00:00:00 2001 From: Avin0323 Date: Mon, 10 May 2021 12:05:30 +0000 Subject: [PATCH 4/7] fix tests, test=develop --- paddle/fluid/operators/activation_op.cu | 32 ++++++------- paddle/fluid/operators/cast_op.cu | 14 +++++- paddle/fluid/operators/fill_constant_op.cu.cc | 1 - paddle/fluid/operators/mean_op.cu | 6 +-- .../paddle/fluid/tests/unittests/op_test.py | 46 +++++++++++++++---- .../tests/unittests/test_activation_op.py | 20 +++++--- 6 files changed, 81 insertions(+), 38 deletions(-) diff --git 
a/paddle/fluid/operators/activation_op.cu b/paddle/fluid/operators/activation_op.cu index d7567f1cf258f0..78427c460feceb 100644 --- a/paddle/fluid/operators/activation_op.cu +++ b/paddle/fluid/operators/activation_op.cu @@ -1443,23 +1443,23 @@ REGISTER_ACTIVATION_CUDA_KERNEL(relu, Relu, CudaReluFunctor, CudaReluGradFunctor); #else REGISTER_OP_CUDA_KERNEL( - relu, ops::ActivationGPUKernel>, - ops::ActivationGPUKernel>, - ops::ActivationGPUKernel>, - ops::ActivationGPUKernel>); + relu, ops::ActivationCudaKernel>, + ops::ActivationCudaKernel>, + ops::ActivationCudaKernel>, + ops::ActivationCudaKernel>); REGISTER_OP_CUDA_KERNEL( - relu_grad, ops::ActivationGradGPUKernel>, - ops::ActivationGradGPUKernel>, - ops::ActivationGradGPUKernel>, - ops::ActivationGradGPUKernel>); + relu_grad, ops::ActivationGradCudaKernel>, + ops::ActivationGradCudaKernel>, + ops::ActivationGradCudaKernel>, + ops::ActivationGradCudaKernel>); #endif REGISTER_OP_CUDA_KERNEL( diff --git a/paddle/fluid/operators/cast_op.cu b/paddle/fluid/operators/cast_op.cu index 13759633d0168a..9cc1f8a841e370 100644 --- a/paddle/fluid/operators/cast_op.cu +++ b/paddle/fluid/operators/cast_op.cu @@ -95,6 +95,12 @@ struct CastOpFunctor { namespace ops = paddle::operators; +#ifdef PADDLE_WITH_HIP +#define HIDDEN(...) +#else +#define HIDDNE(...) __VA_ARGS__, +#endif + REGISTER_OP_CUDA_KERNEL( cast, ops::CastOpKernel, ops::CastOpKernel, @@ -104,7 +110,11 @@ REGISTER_OP_CUDA_KERNEL( ops::CastOpKernel, ops::CastOpKernel, - ops::CastOpKernel, + HIDDNE(ops::CastOpKernel) + ops::CastOpKernel, ops::CastOpKernel); + +#undef HIDDNE diff --git a/paddle/fluid/operators/fill_constant_op.cu.cc b/paddle/fluid/operators/fill_constant_op.cu.cc index e1b3fb14d794b1..e784c20b8b8b4f 100644 --- a/paddle/fluid/operators/fill_constant_op.cu.cc +++ b/paddle/fluid/operators/fill_constant_op.cu.cc @@ -22,6 +22,5 @@ REGISTER_OP_CUDA_KERNEL(fill_constant, ops::FillConstantKernel, ops::FillConstantKernel, ops::FillConstantKernel, ops::FillConstantKernel, - ops::FillConstantKernel, ops::FillConstantKernel, ops::FillConstantKernel); diff --git a/paddle/fluid/operators/mean_op.cu b/paddle/fluid/operators/mean_op.cu index a0504009fd41b0..430036bc67de70 100644 --- a/paddle/fluid/operators/mean_op.cu +++ b/paddle/fluid/operators/mean_op.cu @@ -107,12 +107,10 @@ namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( mean, ops::MeanCUDAKernel, ops::MeanCUDAKernel, - ops::MeanCUDAKernel, - ops::MeanCUDAKernel); + ops::MeanCUDAKernel); REGISTER_OP_CUDA_KERNEL( mean_grad, ops::MeanCUDAGradKernel, ops::MeanCUDAGradKernel, - ops::MeanCUDAGradKernel, ops::MeanCUDAGradKernel); + plat::float16>); diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index dbed41c951e917..afcfea984252ef 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -1363,7 +1363,7 @@ def _assert_is_close(self, numeric_grads, analytic_grads, names, abs_a[np.logical_and(abs_a > 1e-10, abs_a <= 1e-8)] *= 1e4 abs_a[np.logical_and(abs_a > 1e-8, abs_a <= 1e-6)] *= 1e2 elif self.is_bfloat16_op(): - abs_a[abs_a < 3e-3] = 1 + abs_a[abs_a < 1e-2] = 1 else: abs_a[abs_a < 1e-3] = 1 @@ -1495,13 +1495,9 @@ def check_grad_with_place(self, # comparison of bf16 results will happen as fp32 # loop over list of grads and convert bf16 to fp32 - fp32_grads = [] - for grad in analytic_grads: - if grad.dtype == np.uint16: - grad = convert_uint16_to_float(grad) - max_relative_error = 0.03 - fp32_grads.append(grad) - 
analytic_grads = fp32_grads + if self.is_bfloat16_op(): + max_relative_error = 0.03 + analytic_grads = list(map(convert_uint16_to_float, analytic_grads)) self._assert_is_close(numeric_grads, analytic_grads, inputs_to_check, max_relative_error, @@ -1511,6 +1507,9 @@ def check_grad_with_place(self, dygraph_grad = self._get_dygraph_grad( inputs_to_check, place, output_names, user_defined_grad_outputs, no_grad_set) + if self.is_bfloat16_op(): + max_relative_error = 0.03 + dygraph_grad = list(map(convert_uint16_to_float, dygraph_grad)) self._assert_is_close(numeric_grads, dygraph_grad, inputs_to_check, max_relative_error, "Gradient Check On %s" % str(place)) @@ -1555,6 +1554,21 @@ def _get_dygraph_grad(self, outputs=outputs, attrs=attrs_outputs if hasattr(self, "attrs") else None) + if self.is_bfloat16_op(): + cast_inputs = self._find_var_in_dygraph(outputs, + output_names[0]) + cast_outputs = block.create_var( + dtype="float32", shape=cast_inputs[0].shape) + cast_op = block.append_op( + inputs={"X": cast_inputs}, + outputs={"Out": cast_outputs}, + type="cast", + attrs={ + "in_dtype": core.VarDesc.VarType.BF16, + "out_dtype": core.VarDesc.VarType.FP32 + }) + outputs = {output_names[0]: cast_outputs} + outputs_valid = {} for output_name in output_names: outputs_valid[output_name] = self._find_var_in_dygraph( @@ -1669,6 +1683,22 @@ def _get_gradient(self, outputs = self._get_outputs(block) feed_dict = self.feed_var(inputs, place) + if self.is_bfloat16_op(): + cast_inputs = list(map(block.var, output_names)) + cast_outputs = block.create_var( + dtype="float32", shape=cast_inputs[0].shape) + cast_op = block.append_op( + inputs={"X": cast_inputs}, + outputs={"Out": cast_outputs}, + type="cast", + attrs={ + "in_dtype": core.VarDesc.VarType.BF16, + "out_dtype": core.VarDesc.VarType.FP32 + }) + cast_op.desc.infer_var_type(block.desc) + cast_op.desc.infer_shape(block.desc) + output_names = [cast_outputs.name] + if user_defined_grad_outputs is None: loss = append_loss_ops(block, output_names) param_grad_list = append_backward( diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py index f4c4f8b55fad61..ef5ac46cede42c 100755 --- a/python/paddle/fluid/tests/unittests/test_activation_op.py +++ b/python/paddle/fluid/tests/unittests/test_activation_op.py @@ -18,7 +18,7 @@ import numpy as np from scipy.special import expit, erf -from op_test import OpTest +from op_test import OpTest, convert_float_to_uint16 import paddle import paddle.nn as nn import paddle.nn.functional as F @@ -1103,12 +1103,19 @@ def setUp(self): self.init_dtype() np.random.seed(1024) - x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype) - # The same reason with TestAbs - x[np.abs(x) < 0.005] = 0.02 - out = np.maximum(x, 0) + if self.dtype == np.uint16: + x = np.random.uniform(-1, 1, [11, 17]).astype(np.float32) + # The same reason with TestAbs + x[np.abs(x) < 0.005] = 0.02 + out = convert_float_to_uint16(np.maximum(x, 0)) + self.inputs = {'X': convert_float_to_uint16(x)} + else: + x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype) + # The same reason with TestAbs + x[np.abs(x) < 0.005] = 0.02 + out = np.maximum(x, 0) + self.inputs = {'X': x} - self.inputs = {'X': x} self.outputs = {'Out': out} def test_check_grad(self): @@ -2740,7 +2747,6 @@ def test_check_grad(self): create_test_act_fp16_class(TestHardSwish) -#------------------ Test BF16 ---------------------- def create_test_act_bf16_class(parent, atol=1e-2, grad_check=True, From 
08ef559e6614d088bb1149ea8be4d2c7116ddd59 Mon Sep 17 00:00:00 2001 From: Avin0323 Date: Mon, 10 May 2021 14:46:41 +0000 Subject: [PATCH 5/7] fix compilation error, test=develop --- paddle/fluid/operators/activation_op.cu | 11 ++++++-- paddle/fluid/operators/cast_op.cu | 28 ++++++++++++------- .../paddle/fluid/tests/unittests/op_test.py | 8 +++--- 3 files changed, 31 insertions(+), 16 deletions(-) diff --git a/paddle/fluid/operators/activation_op.cu b/paddle/fluid/operators/activation_op.cu index 78427c460feceb..b3a421244a690b 100644 --- a/paddle/fluid/operators/activation_op.cu +++ b/paddle/fluid/operators/activation_op.cu @@ -1441,6 +1441,14 @@ REGISTER_OP_CUDA_KERNEL( #ifdef PADDLE_WITH_HIP REGISTER_ACTIVATION_CUDA_KERNEL(relu, Relu, CudaReluFunctor, CudaReluGradFunctor); +REGISTER_OP_CUDA_KERNEL( + relu_grad_grad, + ops::ActivationDoubleGradKernel>, + ops::ActivationDoubleGradKernel>, + ops::ActivationDoubleGradKernel>); #else REGISTER_OP_CUDA_KERNEL( relu, ops::ActivationCudaKernel>, ops::ActivationGradCudaKernel>); -#endif - REGISTER_OP_CUDA_KERNEL( relu_grad_grad, ops::ActivationDoubleGradKernel>, ops::ActivationDoubleGradKernel>); +#endif /* ========================================================================== */ /* =========================== tanh register ============================ */ diff --git a/paddle/fluid/operators/cast_op.cu b/paddle/fluid/operators/cast_op.cu index 9cc1f8a841e370..2ef5b9ae3ac373 100644 --- a/paddle/fluid/operators/cast_op.cu +++ b/paddle/fluid/operators/cast_op.cu @@ -96,11 +96,20 @@ struct CastOpFunctor { namespace ops = paddle::operators; #ifdef PADDLE_WITH_HIP -#define HIDDEN(...) +REGISTER_OP_CUDA_KERNEL( + cast, ops::CastOpKernel, + ops::CastOpKernel, + ops::CastOpKernel, + ops::CastOpKernel, + ops::CastOpKernel, + ops::CastOpKernel, + ops::CastOpKernel, + ops::CastOpKernel, + ops::CastOpKernel); #else -#define HIDDNE(...) 
__VA_ARGS__, -#endif - REGISTER_OP_CUDA_KERNEL( cast, ops::CastOpKernel, ops::CastOpKernel, @@ -110,11 +119,10 @@ REGISTER_OP_CUDA_KERNEL( ops::CastOpKernel, ops::CastOpKernel, - HIDDNE(ops::CastOpKernel) - ops::CastOpKernel, + ops::CastOpKernel, + ops::CastOpKernel, ops::CastOpKernel); - -#undef HIDDNE +#endif diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index afcfea984252ef..c7e76a4a4d030d 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -1495,7 +1495,7 @@ def check_grad_with_place(self, # comparison of bf16 results will happen as fp32 # loop over list of grads and convert bf16 to fp32 - if self.is_bfloat16_op(): + if self.dtype == np.uint16: max_relative_error = 0.03 analytic_grads = list(map(convert_uint16_to_float, analytic_grads)) @@ -1507,7 +1507,7 @@ def check_grad_with_place(self, dygraph_grad = self._get_dygraph_grad( inputs_to_check, place, output_names, user_defined_grad_outputs, no_grad_set) - if self.is_bfloat16_op(): + if self.dtype == np.uint16: max_relative_error = 0.03 dygraph_grad = list(map(convert_uint16_to_float, dygraph_grad)) self._assert_is_close(numeric_grads, dygraph_grad, inputs_to_check, @@ -1554,7 +1554,7 @@ def _get_dygraph_grad(self, outputs=outputs, attrs=attrs_outputs if hasattr(self, "attrs") else None) - if self.is_bfloat16_op(): + if self.dtype == np.uint16: cast_inputs = self._find_var_in_dygraph(outputs, output_names[0]) cast_outputs = block.create_var( @@ -1683,7 +1683,7 @@ def _get_gradient(self, outputs = self._get_outputs(block) feed_dict = self.feed_var(inputs, place) - if self.is_bfloat16_op(): + if self.dtype == np.uint16: cast_inputs = list(map(block.var, output_names)) cast_outputs = block.create_var( dtype="float32", shape=cast_inputs[0].shape) From f70adad88af44b84afe4f0a35dc64a6b1526bfb8 Mon Sep 17 00:00:00 2001 From: Avin0323 Date: Tue, 11 May 2021 08:07:15 +0000 Subject: [PATCH 6/7] fix bf16 cpu unittests, test=develop --- .../paddle/fluid/tests/unittests/op_test.py | 47 +++++++++++-------- 1 file changed, 27 insertions(+), 20 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index c7e76a4a4d030d..3222658b246e3d 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -1495,9 +1495,13 @@ def check_grad_with_place(self, # comparison of bf16 results will happen as fp32 # loop over list of grads and convert bf16 to fp32 - if self.dtype == np.uint16: - max_relative_error = 0.03 - analytic_grads = list(map(convert_uint16_to_float, analytic_grads)) + fp32_grads = [] + for grad in analytic_grads: + if grad.dtype == np.uint16: + grad = convert_uint16_to_float(grad) + max_relative_error = 0.03 + fp32_grads.append(grad) + analytic_grads = fp32_grads self._assert_is_close(numeric_grads, analytic_grads, inputs_to_check, max_relative_error, @@ -1507,9 +1511,13 @@ def check_grad_with_place(self, dygraph_grad = self._get_dygraph_grad( inputs_to_check, place, output_names, user_defined_grad_outputs, no_grad_set) - if self.dtype == np.uint16: - max_relative_error = 0.03 - dygraph_grad = list(map(convert_uint16_to_float, dygraph_grad)) + fp32_grads = [] + for grad in dygraph_grad: + if grad.dtype == np.uint16: + grad = convert_uint16_to_float(grad) + max_relative_error = 0.03 + fp32_grads.append(grad) + dygraph_grad = fp32_grads self._assert_is_close(numeric_grads, dygraph_grad, inputs_to_check, 
max_relative_error, "Gradient Check On %s" % str(place)) @@ -1683,23 +1691,22 @@ def _get_gradient(self, outputs = self._get_outputs(block) feed_dict = self.feed_var(inputs, place) - if self.dtype == np.uint16: - cast_inputs = list(map(block.var, output_names)) - cast_outputs = block.create_var( - dtype="float32", shape=cast_inputs[0].shape) - cast_op = block.append_op( - inputs={"X": cast_inputs}, - outputs={"Out": cast_outputs}, - type="cast", - attrs={ - "in_dtype": core.VarDesc.VarType.BF16, - "out_dtype": core.VarDesc.VarType.FP32 - }) + if user_defined_grad_outputs is None: + if self.dtype == np.uint16: + cast_inputs = list(map(block.var, output_names)) + cast_outputs = block.create_var( + dtype="float32", shape=cast_inputs[0].shape) + cast_op = block.append_op( + inputs={"X": cast_inputs}, + outputs={"Out": cast_outputs}, + type="cast", + attrs={ + "in_dtype": core.VarDesc.VarType.BF16, + "out_dtype": core.VarDesc.VarType.FP32 + }) cast_op.desc.infer_var_type(block.desc) cast_op.desc.infer_shape(block.desc) output_names = [cast_outputs.name] - - if user_defined_grad_outputs is None: loss = append_loss_ops(block, output_names) param_grad_list = append_backward( loss=loss, From 394e394eee5b1dd5e085f26526fd0647599dbe81 Mon Sep 17 00:00:00 2001 From: Avin0323 Date: Tue, 11 May 2021 11:31:13 +0000 Subject: [PATCH 7/7] fix unittests error, test=develop --- python/paddle/fluid/tests/unittests/op_test.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index 3222658b246e3d..6451a36c03b89d 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -1704,9 +1704,9 @@ def _get_gradient(self, "in_dtype": core.VarDesc.VarType.BF16, "out_dtype": core.VarDesc.VarType.FP32 }) - cast_op.desc.infer_var_type(block.desc) - cast_op.desc.infer_shape(block.desc) - output_names = [cast_outputs.name] + cast_op.desc.infer_var_type(block.desc) + cast_op.desc.infer_shape(block.desc) + output_names = [cast_outputs.name] loss = append_loss_ops(block, output_names) param_grad_list = append_backward( loss=loss,