relu supports bfloat16 data type #32542
Changes from all commits: 1099408, fb7b924, 1cbdb1b, 7a26ab6, f6cb377, acf89ba, 08ef559, f70adad, 394e394
```diff
@@ -95,6 +95,7 @@ struct CastOpFunctor<platform::CUDADeviceContext, InT> {
 
 namespace ops = paddle::operators;
 
+#ifdef PADDLE_WITH_HIP
 REGISTER_OP_CUDA_KERNEL(
     cast, ops::CastOpKernel<paddle::platform::CUDADeviceContext, float>,
     ops::CastOpKernel<paddle::platform::CUDADeviceContext, double>,
@@ -108,3 +109,20 @@ REGISTER_OP_CUDA_KERNEL(
                       paddle::platform::complex64>,
     ops::CastOpKernel<paddle::platform::CUDADeviceContext,
                       paddle::platform::complex128>);
+#else
+REGISTER_OP_CUDA_KERNEL(
+    cast, ops::CastOpKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::CastOpKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::CastOpKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::CastOpKernel<paddle::platform::CUDADeviceContext, int64_t>,
+    ops::CastOpKernel<paddle::platform::CUDADeviceContext, bool>,
+    ops::CastOpKernel<paddle::platform::CUDADeviceContext, uint8_t>,
+    ops::CastOpKernel<paddle::platform::CUDADeviceContext,
+                      paddle::platform::float16>,
+    ops::CastOpKernel<paddle::platform::CUDADeviceContext,
+                      paddle::platform::bfloat16>,
+    ops::CastOpKernel<paddle::platform::CUDADeviceContext,
+                      paddle::platform::complex64>,
+    ops::CastOpKernel<paddle::platform::CUDADeviceContext,
+                      paddle::platform::complex128>);
+#endif
```

**Contributor:** Consider simplifying this registration code in a follow-up PR.

**Contributor (Author):** OK.
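For context on what the new `bfloat16` kernel traffics in: bf16 keeps float32's sign and 8-bit exponent and truncates the mantissa to 7 bits, so a bf16 value is just the high half of a float32 bit pattern. A quick numpy emulation of the round trip (an illustration only, not Paddle code; the function names here are ours):

```python
import numpy as np

def fp32_to_bf16_bits(x):
    # Keep the high 16 bits of each float32: sign, 8-bit exponent,
    # and the top 7 mantissa bits (plain truncation; real kernels may round).
    return (np.asarray(x, dtype=np.float32).view(np.uint32) >> 16).astype(np.uint16)

def bf16_bits_to_fp32(bits):
    # Shift the 16 stored bits back into the high half of a float32.
    return (np.asarray(bits, dtype=np.uint16).astype(np.uint32) << 16).view(np.float32)

x = np.array([1.0, 3.14159, -2.5e-3], dtype=np.float32)
print(bf16_bits_to_fp32(fp32_to_bf16_bits(x)))  # values rounded toward zero to bf16 precision
```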
```diff
@@ -132,6 +132,8 @@ def product(dim):
         tensor_to_check_dtype = np.float16
         # set delta as np.float16, will automatic convert to float32, float64
         delta = np.array(delta).astype(np.float16)
+    elif tensor_to_check_dtype == core.VarDesc.VarType.BF16:
+        tensor_to_check_dtype = np.float32
     else:
         raise ValueError("Not supported data type " + str(
             tensor_to_check_dtype))
```

**Contributor:** According to L127–L132, this maps a Paddle data type to the corresponding numpy data type, so …

**Contributor (Author):** BF16 is also a floating-point type. Here, to keep the unit test readable, float32 rather than uint16 is used as the dtype for checking results.
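To make the author's point concrete, this chain maps each Paddle dtype to the numpy dtype used for numeric checking. A condensed restatement as a plain dict (the FP32 and FP64 rows are assumed from the code above the hunk, which this page does not show):

```python
import numpy as np
from paddle.fluid import core

# BF16 tensors are stored as uint16 bit patterns, but gradients are
# checked as float32 so that tolerances apply to real values.
check_dtype = {
    core.VarDesc.VarType.FP16: np.float16,
    core.VarDesc.VarType.BF16: np.float32,  # readability over raw uint16
    core.VarDesc.VarType.FP32: np.float32,
    core.VarDesc.VarType.FP64: np.float64,
}
```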
```diff
@@ -140,9 +142,10 @@ def get_output():
         sum = []
         op.run(scope, place)
         for output_name in output_names:
-            sum.append(
-                np.array(scope.find_var(output_name).get_tensor()).astype(
-                    tensor_to_check_dtype).mean())
+            output_numpy = np.array(scope.find_var(output_name).get_tensor())
+            if tensor_to_check._dtype() == core.VarDesc.VarType.BF16:
+                output_numpy = convert_uint16_to_float(output_numpy)
+            sum.append(output_numpy.astype(tensor_to_check_dtype).mean())
         return tensor_to_check_dtype(np.array(sum).sum() / len(output_names))
 
     gradient_flat = np.zeros(shape=(tensor_size, ), dtype=tensor_to_check_dtype)
```
```diff
@@ -152,6 +155,11 @@ def __get_elem__(tensor, i):
             numpy_tensor = np.array(tensor).astype(np.float16)
             numpy_tensor = numpy_tensor.flatten()
             return numpy_tensor[i]
+        elif tensor_to_check._dtype() == core.VarDesc.VarType.BF16:
+            numpy_tensor = np.array(tensor).astype(np.uint16)
+            numpy_tensor = numpy_tensor.flatten()
+            return struct.unpack('<f', struct.pack('<I', numpy_tensor[i]
+                                                   << 16))[0]
         elif tensor_to_check_dtype == np.float32:
             return tensor._get_float_element(i)
         elif tensor_to_check_dtype == np.float64:
```

**Contributor:** This converts the …

**Contributor (Author):** OK.
```diff
@@ -168,6 +176,13 @@ def __set_elem__(tensor, i, e):
             numpy_tensor[i] = e
             numpy_tensor = numpy_tensor.reshape(shape)
             tensor.set(numpy_tensor, place)
+        elif tensor_to_check._dtype() == core.VarDesc.VarType.BF16:
+            numpy_tensor = np.array(tensor).astype(np.uint16)
+            shape = numpy_tensor.shape
+            numpy_tensor = numpy_tensor.flatten()
+            numpy_tensor[i] = np.uint16(copy_bits_from_float_to_uint16(e))
+            numpy_tensor = numpy_tensor.reshape(shape)
+            tensor.set(numpy_tensor, place)
         elif tensor_to_check_dtype == np.float32:
             tensor._set_float_element(i, e)
         elif tensor_to_check_dtype == np.float64:
```
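`__get_elem__` above unpacks bf16 bits with `struct`, and `__set_elem__` packs a perturbed float back via `copy_bits_from_float_to_uint16`. A plausible sketch of that pair as standalone functions (the exact rounding behavior of the real helper is an assumption; truncation is shown):

```python
import struct

def copy_bits_from_float_to_uint16(f):
    # float32 -> bf16 bits: keep the high 16 bits (truncation).
    return struct.unpack('<I', struct.pack('<f', f))[0] >> 16

def uint16_bits_to_float(u):
    # bf16 bits -> float32: shift the payload back into the high half.
    return struct.unpack('<f', struct.pack('<I', u << 16))[0]

bits = copy_bits_from_float_to_uint16(1.5)
assert uint16_bits_to_float(bits) == 1.5  # 1.5 is exactly representable in bf16
```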
```diff
@@ -1347,6 +1362,8 @@ def _assert_is_close(self, numeric_grads, analytic_grads, names,
             abs_a[abs_a < 1e-10] = 1e-3
             abs_a[np.logical_and(abs_a > 1e-10, abs_a <= 1e-8)] *= 1e4
             abs_a[np.logical_and(abs_a > 1e-8, abs_a <= 1e-6)] *= 1e2
+        elif self.is_bfloat16_op():
+            abs_a[abs_a < 1e-2] = 1
         else:
             abs_a[abs_a < 1e-3] = 1
 
```
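`abs_a` is the denominator of the relative-error check, so flooring entries below 1e-2 at 1 effectively switches tiny gradients to an absolute-error comparison, which matches bf16's roughly two to three significant decimal digits. A toy run of the arithmetic (all values invented):

```python
import numpy as np

analytic = np.array([5e-3, 0.5, 2.0])    # invented analytic gradients
numeric = np.array([6e-3, 0.501, 2.01])  # invented numeric gradients
abs_a = np.abs(analytic)
abs_a[abs_a < 1e-2] = 1                  # the bf16 branch above
diff = np.abs(analytic - numeric) / abs_a
print(diff.max())  # ~0.005, compared against max_relative_error
```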
```diff
@@ -1494,6 +1511,13 @@ def check_grad_with_place(self,
             dygraph_grad = self._get_dygraph_grad(
                 inputs_to_check, place, output_names, user_defined_grad_outputs,
                 no_grad_set)
+            fp32_grads = []
+            for grad in dygraph_grad:
+                if grad.dtype == np.uint16:
+                    grad = convert_uint16_to_float(grad)
+                    max_relative_error = 0.03
+                fp32_grads.append(grad)
+            dygraph_grad = fp32_grads
             self._assert_is_close(numeric_grads, dygraph_grad, inputs_to_check,
                                   max_relative_error,
                                   "Gradient Check On %s" % str(place))
```

**Contributor:** Semantically this is … The same applies below.

**Contributor (Author):** OK.

**Contributor:** Does the check fail if this is set any larger?

**Contributor (Author):** For now this stays consistent with the bf16 tolerance settings on CPU.
```diff
@@ -1538,6 +1562,21 @@ def _get_dygraph_grad(self,
                 outputs=outputs,
                 attrs=attrs_outputs if hasattr(self, "attrs") else None)
 
+            if self.dtype == np.uint16:
+                cast_inputs = self._find_var_in_dygraph(outputs,
+                                                        output_names[0])
+                cast_outputs = block.create_var(
+                    dtype="float32", shape=cast_inputs[0].shape)
+                cast_op = block.append_op(
+                    inputs={"X": cast_inputs},
+                    outputs={"Out": cast_outputs},
+                    type="cast",
+                    attrs={
+                        "in_dtype": core.VarDesc.VarType.BF16,
+                        "out_dtype": core.VarDesc.VarType.FP32
+                    })
+                outputs = {output_names[0]: cast_outputs}
+
             outputs_valid = {}
             for output_name in output_names:
                 outputs_valid[output_name] = self._find_var_in_dygraph(
```
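The effect of the appended `cast` is that the scalar driving backward is built from a float32 view of the bf16 output rather than from raw bf16 bits. In plain dygraph terms the shape of the computation is roughly this (a sketch, assuming a build with the kernels from this PR; with an fp32 input the cast is simply a no-op):

```python
import paddle
import paddle.nn.functional as F

x = paddle.randn([11, 17])               # imagine a bf16 tensor in the real test
x.stop_gradient = False
y = F.relu(x)
loss = paddle.cast(y, 'float32').mean()  # cast first, then reduce in fp32
loss.backward()
print(x.grad.shape)                      # gradient w.r.t. the input
```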
```diff
@@ -1653,6 +1692,21 @@ def _get_gradient(self,
         feed_dict = self.feed_var(inputs, place)
 
         if user_defined_grad_outputs is None:
+            if self.dtype == np.uint16:
+                cast_inputs = list(map(block.var, output_names))
+                cast_outputs = block.create_var(
+                    dtype="float32", shape=cast_inputs[0].shape)
+                cast_op = block.append_op(
+                    inputs={"X": cast_inputs},
+                    outputs={"Out": cast_outputs},
+                    type="cast",
+                    attrs={
+                        "in_dtype": core.VarDesc.VarType.BF16,
+                        "out_dtype": core.VarDesc.VarType.FP32
+                    })
+                cast_op.desc.infer_var_type(block.desc)
+                cast_op.desc.infer_shape(block.desc)
+                output_names = [cast_outputs.name]
             loss = append_loss_ops(block, output_names)
             param_grad_list = append_backward(
                 loss=loss,
```
```diff
@@ -18,7 +18,7 @@
 import numpy as np
 from scipy.special import expit, erf
 
-from op_test import OpTest
+from op_test import OpTest, convert_float_to_uint16
 import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
```
```diff
@@ -1103,12 +1103,19 @@ def setUp(self):
         self.init_dtype()
 
         np.random.seed(1024)
-        x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype)
-        # The same reason with TestAbs
-        x[np.abs(x) < 0.005] = 0.02
-        out = np.maximum(x, 0)
+        if self.dtype == np.uint16:
+            x = np.random.uniform(-1, 1, [11, 17]).astype(np.float32)
+            # The same reason with TestAbs
+            x[np.abs(x) < 0.005] = 0.02
+            out = convert_float_to_uint16(np.maximum(x, 0))
+            self.inputs = {'X': convert_float_to_uint16(x)}
+        else:
+            x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype)
+            # The same reason with TestAbs
+            x[np.abs(x) < 0.005] = 0.02
+            out = np.maximum(x, 0)
+            self.inputs = {'X': x}
 
-        self.inputs = {'X': x}
         self.outputs = {'Out': out}
 
     def test_check_grad(self):
```
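Note that the reference output also goes through `convert_float_to_uint16`, so the kernel's bf16 result is compared against an identically quantized expectation rather than against full-precision fp32. A small demonstration, with the helper's behavior assumed to be bit truncation as sketched earlier:

```python
import numpy as np

def convert_float_to_uint16(x):
    # Assumed behavior of the op_test helper: float32 -> bf16 bit patterns.
    return (np.asarray(x, dtype=np.float32).view(np.uint32) >> 16).astype(np.uint16)

x = np.array([3.14159, -0.75], dtype=np.float32)
print(convert_float_to_uint16(np.maximum(x, 0)))  # e.g. [16457, 0]
```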
```diff
@@ -2739,5 +2746,32 @@ def test_check_grad(self):
 create_test_act_fp16_class(TestSwish, grad_atol=0.85)
 create_test_act_fp16_class(TestHardSwish)
 
+
+def create_test_act_bf16_class(parent,
+                               atol=1e-2,
+                               grad_check=True,
+                               grad_atol=0.80):
+    @unittest.skipIf(not paddle.is_compiled_with_cuda(),
+                     "core is not compiled with CUDA")
+    class TestActBF16(parent):
+        def init_dtype(self):
+            self.dtype = np.uint16
+
+        def test_check_output(self):
+            place = core.CUDAPlace(0)
+            self.check_output_with_place(place, atol=atol)
+
+        def test_check_grad(self):
+            place = core.CUDAPlace(0)
+            self.check_grad_with_place(
+                place, ['X'], 'Out', max_relative_error=grad_atol)
+
+    cls_name = "{0}_{1}".format(parent.__name__, "bf16")
+    TestActBF16.__name__ = cls_name
+    globals()[cls_name] = TestActBF16
+
+
+create_test_act_bf16_class(TestRelu)
+
 if __name__ == "__main__":
     unittest.main()
```

**Contributor:** Setting the dtype to …

**Contributor (Author):** Currently the test framework uses …
**Contributor:** Has double grad been verified? If not, don't register it yet.

**Contributor (Author):** It has been verified in the unit tests.

**Contributor:** Please still consider simplifying the registration code in a follow-up PR.
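For readers unfamiliar with the `create_test_act_bf16_class` pattern above: it parameterizes a `TestCase` subclass and publishes it under a derived name so unittest discovery picks it up. A minimal standalone sketch of the same idiom (independent of Paddle; all names here are ours):

```python
import unittest

class TestBase(unittest.TestCase):
    def init_value(self):
        self.value = 1  # overridden by generated variants

    def test_positive(self):
        self.init_value()
        self.assertGreater(self.value, 0)

def create_variant(parent, value):
    class Variant(parent):
        def init_value(self):
            self.value = value  # the parameterized override
    Variant.__name__ = "{0}_{1}".format(parent.__name__, value)
    globals()[Variant.__name__] = Variant  # unittest discovery finds it here

create_variant(TestBase, 2)  # registers TestBase_2

if __name__ == "__main__":
    unittest.main()
```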