19 changes: 15 additions & 4 deletions paddle/fluid/operators/conv_cudnn_helper.h
@@ -209,20 +209,31 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> {

#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetConvolutionMathType(
args.cdesc.desc(), CUDNN_DEFAULT_MATH));
VLOG(5) << "NOT use cudnn_tensor_op_math";
if (dev_ctx.GetComputeCapability() >= 70 && dtype == CUDNN_DATA_HALF) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnSetConvolutionMathType(args.cdesc.desc(),
CUDNN_TENSOR_OP_MATH));
VLOG(5) << "use cudnn_tensor_op_math";
} else if (dtype == CUDNN_DATA_FLOAT && !args.cdesc.allow_tf32_) {
#if CUDA_VERSION >= 11000
#if CUDNN_VERSION_MIN(8, 1, 0)
} else if (dev_ctx.GetComputeCapability() >= 80 &&
dtype == CUDNN_DATA_BFLOAT16) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnSetConvolutionMathType(args.cdesc.desc(),
CUDNN_TENSOR_OP_MATH));
VLOG(5) << "use cudnn_tensor_op_math";
#endif // CUDNN_VERSION >= 8100
} else if (dtype == CUDNN_DATA_FLOAT && !args.cdesc.allow_tf32_) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnSetConvolutionMathType(args.cdesc.desc(),
CUDNN_FMA_MATH));
VLOG(5) << "use cudnn_fma_math";
#endif // CUDA_VERSION >= 11000
} else {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnSetConvolutionMathType(args.cdesc.desc(),
CUDNN_DEFAULT_MATH));
VLOG(5) << "use cudnn_default_math";
}
#endif

26 changes: 26 additions & 0 deletions paddle/fluid/operators/conv_cudnn_op.cu
@@ -1413,6 +1413,31 @@ REGISTER_OP_KERNEL(
paddle::operators::CUDNNConvDoubleGradOpKernel<float>,
paddle::operators::CUDNNConvDoubleGradOpKernel<plat::float16>);
#else
#if CUDNN_VERSION_MIN(8, 1, 0)
Contributor:

The registration code now runs to 100+ lines and could be simplified. These registrations fall into just three cases:

  • CUDA, cuDNN < 8.1: supports float, double, float16
  • CUDA, cuDNN >= 8.1: supports float, double, float16, bfloat16
  • ROCm: supports float, float16

We could define registration macros for these, e.g. REGISTER_CONV_CUDNN_KERNEL_WITH_FP64_BF16, REGISTER_CONV_CUDNN_KERNEL_WITH_FP64, REGISTER_CONV_CUDNN_KERNEL_WITH_BF16?

Contributor Author:

OK, will follow up.

REGISTER_OP_KERNEL(conv2d, CUDNN, plat::CUDAPlace,
paddle::operators::CUDNNConvOpKernel<float>,
paddle::operators::CUDNNConvOpKernel<double>,
paddle::operators::CUDNNConvOpKernel<plat::float16>,
paddle::operators::CUDNNConvOpKernel<plat::bfloat16>);
REGISTER_OP_KERNEL(conv2d_grad, CUDNN, plat::CUDAPlace,
paddle::operators::CUDNNConvGradOpKernel<float>,
paddle::operators::CUDNNConvGradOpKernel<double>,
paddle::operators::CUDNNConvGradOpKernel<plat::float16>,
paddle::operators::CUDNNConvGradOpKernel<plat::bfloat16>);
REGISTER_OP_KERNEL(
conv2d_grad_grad, CUDNN, plat::CUDAPlace,
paddle::operators::CUDNNConvDoubleGradOpKernel<float>,
paddle::operators::CUDNNConvDoubleGradOpKernel<double>,
paddle::operators::CUDNNConvDoubleGradOpKernel<plat::float16>,
paddle::operators::CUDNNConvDoubleGradOpKernel<plat::bfloat16>);

REGISTER_OP_CUDA_KERNEL(
depthwise_conv2d_grad_grad,
paddle::operators::CUDNNConvDoubleGradOpKernel<float>,
paddle::operators::CUDNNConvDoubleGradOpKernel<double>,
paddle::operators::CUDNNConvDoubleGradOpKernel<plat::float16>,
paddle::operators::CUDNNConvDoubleGradOpKernel<plat::bfloat16>);
#else
REGISTER_OP_KERNEL(conv2d, CUDNN, plat::CUDAPlace,
paddle::operators::CUDNNConvOpKernel<float>,
paddle::operators::CUDNNConvOpKernel<double>,
@@ -1432,6 +1457,7 @@ REGISTER_OP_CUDA_KERNEL(
paddle::operators::CUDNNConvDoubleGradOpKernel<float>,
paddle::operators::CUDNNConvDoubleGradOpKernel<double>,
paddle::operators::CUDNNConvDoubleGradOpKernel<plat::float16>);
#endif

REGISTER_OP_KERNEL(conv3d, CUDNN, plat::CUDAPlace,
paddle::operators::CUDNNConvOpKernel<float>,
9 changes: 9 additions & 0 deletions paddle/fluid/operators/conv_op.cc
@@ -189,6 +189,15 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
platform::errors::InvalidArgument(
"float16 can only be used when CUDNN is used"));
}
#if PADDLE_WITH_CUDA
if (input_data_type == framework::proto::VarType::BF16 &&
library == framework::LibraryType::kCUDNN) {
PADDLE_ENFORCE_GE(
platform::CudnnVersion(), 8100,
platform::errors::InvalidArgument(
"bfloat16 can only be used when CUDNN_VERSION >= 8100"));
}
#endif // PADDLE_WITH_CUDA

auto type = framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
library, customized_type_value);
5 changes: 5 additions & 0 deletions paddle/fluid/platform/cudnn_desc.h
@@ -79,6 +79,11 @@ inline cudnnDataType_t ToCudnnDataType(
case framework::proto::VarType::FP64:
type = CUDNN_DATA_DOUBLE;
break;
#if CUDNN_VERSION_MIN(8, 1, 0)
case framework::proto::VarType::BF16:
type = CUDNN_DATA_BFLOAT16;
break;
#endif
default:
break;
}
19 changes: 19 additions & 0 deletions paddle/fluid/platform/cudnn_helper.h
@@ -131,6 +131,25 @@ inline ActivationMode StringToActivationMode(const std::string& str) {
template <typename T>
class CudnnDataType;

// CUDNN_DATA_BFLOAT16 is not valid before cudnn8.1
#if CUDNN_VERSION_MIN(8, 1, 0)
template <>
class CudnnDataType<bfloat16> {
public:
static const cudnnDataType_t type = CUDNN_DATA_BFLOAT16;
using ScalingParamType = const float;
using BatchNormParamType = float;
static ScalingParamType* kOne() {
static ScalingParamType v = 1.0;
return &v;
}
static ScalingParamType* kZero() {
static ScalingParamType v = 0.0;
return &v;
}
};
#endif

template <>
class CudnnDataType<float16> {
public:
@@ -32,7 +32,8 @@ def set_confs(self):
def test_check_output(self):
for use_seq in {True, False}:
self.attrs['use_seq'] = use_seq
self.check_output(check_dygraph=False, no_check_set=["Cell"])
self.check_output(
check_dygraph=False, no_check_set=["Cell"], atol=2e-2)
Contributor:

atol is specified here because you changed the atol value in op_test.py. That will still affect the unit tests of other ops, won't it? I think it would be better not to modify op_test.py; overriding the OpTest method in this test would avoid affecting other ops' unit tests.

Contributor Author:

This only affects the forward-accuracy check for bfloat16. The test framework previously hard-coded 0.03; this PR just removes that fixed value, and each unit test that needs a tolerance specifies it itself.

Contributor:

This changes the accuracy-check logic in the op unit tests and affects the MKL-DNN unit tests; please @luotao1 take a look.

Contributor Author:

In python/paddle/fluid/tests/unittests/op_test.py, atol = 0.03 is not a good way to check forward accuracy. This PR switches the bfloat16 accuracy check to a relative error and removes the hard-coded 0.03 limit; atol = 2e-2 is added here to keep the same accuracy bound as before so the test still passes.
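
The tolerance discussion above relies on bfloat16 tensors being stored as uint16 bit patterns and converted to float32 before comparison. A minimal sketch of that conversion idea (bf16_bits_to_float32 / float32_to_bf16_bits are illustrative names; the real helpers are convert_uint16_to_float / convert_float_to_uint16 in op_test.py, which may round rather than truncate):

import numpy as np

# bfloat16 is the upper 16 bits of an IEEE-754 float32, so a bf16 tensor can
# be stored as uint16 bit patterns and round-tripped like this.
def bf16_bits_to_float32(bits):
    # widen to uint32, shift into the high half, reinterpret as float32
    bits = np.asarray(bits, dtype=np.uint16)
    return (bits.astype(np.uint32) << 16).view(np.float32)

def float32_to_bf16_bits(x):
    # keep only the high 16 bits of the float32 representation (truncation)
    x = np.asarray(x, dtype=np.float32)
    return (x.view(np.uint32) >> 16).astype(np.uint16)

x = np.array([0.15625, 3.1415926], dtype=np.float32)
print(bf16_bits_to_float32(float32_to_bf16_bits(x)))  # [0.15625, 3.140625]: bf16 keeps roughly 3 decimal digits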


def setUp(self):
self.op_type = 'fusion_lstm'
10 changes: 8 additions & 2 deletions python/paddle/fluid/tests/unittests/op_test.py
@@ -1191,7 +1191,9 @@ def find_actual(target_name, fetch_list):
np.float32, np.float64
]:
actual_t = convert_uint16_to_float(actual_t)
atol = max(atol, 0.03)
rtol = 1.e-2
else:
rtol = 1.e-5

if expect_t.dtype == np.uint16 and actual_t.dtype == np.uint16:
expect_t = convert_uint16_to_float(expect_t)
@@ -1204,7 +1206,11 @@

self.assertTrue(
np.allclose(
actual_t, expect_t, atol=atol, equal_nan=equal_nan),
actual_t,
expect_t,
rtol=rtol,
atol=atol,
equal_nan=equal_nan),
"Output (" + out_name + ") has diff at " + str(place) +
"\nExpect " + str(expect_t) + "\n" + "But Got" +
str(actual_t) + " in class " + self.__class__.__name__)
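
For context on the rtol change above: np.allclose(actual, expect, rtol, atol) passes elementwise when |actual - expect| <= atol + rtol * |expect|, so a relative tolerance scales with the magnitude of the reference value, while the removed fixed atol = 0.03 did not. A small illustration:

import numpy as np

expect = np.array([0.001, 10.0], dtype=np.float32)
actual = expect * 1.005                # 0.5% relative error on every element

# rtol scales the allowed error with |expect| ...
print(np.allclose(actual, expect, rtol=1e-2, atol=0.0))  # True
# ... while a fixed absolute bound rejects the larger element (0.05 > 0.03)
print(np.allclose(actual, expect, rtol=0.0, atol=0.03))  # False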
91 changes: 83 additions & 8 deletions python/paddle/fluid/tests/unittests/test_conv2d_op.py
@@ -20,7 +20,8 @@
import paddle
import paddle.fluid.core as core
import paddle.fluid as fluid
from op_test import OpTest
from op_test import OpTest, convert_float_to_uint16, get_numeric_gradient
from paddle.fluid.tests.unittests.testsuite import create_op
from paddle.fluid import Program, program_guard


@@ -167,6 +168,52 @@ def test_check_grad_no_input(self):
globals()[cls_name] = TestConv2DCUDNNFp16


def create_test_cudnn_bf16_class(parent):
@unittest.skipIf(
not core.is_compiled_with_cuda() or core.cudnn_version() < 8100,
"core is not compiled with CUDA and cudnn version need larger than 8.1.0"
)
class TestConv2DCUDNNBF16(parent):
def get_numeric_grad(self, place, check_name):
scope = core.Scope()
self._check_grad_helper()
op = create_op(scope, self.op_type, self.inputs, self.outputs,
self.attrs)
return get_numeric_gradient(place, scope, op, self.inputs_fp32,
check_name, ['Output'])

def init_kernel_type(self):
self.use_cudnn = True
self.no_need_check_grad = True
Contributor:

self.no_need_check_grad = True is still kept here — does it have any effect?

Contributor Author:

It is mainly there to prevent the gradient tests in the parent class from being executed (a sketch of that guard follows this factory function).

self.dtype = np.uint16

def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place, atol=1e-2)

def test_check_grad_no_filter(self):
place = core.CUDAPlace(0)
numeric_grads = self.get_numeric_grad(place, 'Input')
self.check_grad_with_place(
place, ['Input'],
'Output',
no_grad_set=set(['Filter']),
user_defined_grads=[numeric_grads])

def test_check_grad_no_input(self):
place = core.CUDAPlace(0)
numeric_grads = self.get_numeric_grad(place, 'Filter')
self.check_grad_with_place(
place, ['Filter'],
'Output',
no_grad_set=set(['Input']),
user_defined_grads=[numeric_grads])

cls_name = "{0}_{1}".format(parent.__name__, "CUDNNBF16")
TestConv2DCUDNNBF16.__name__ = cls_name
globals()[cls_name] = TestConv2DCUDNNBF16
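
Regarding the self.no_need_check_grad flag discussed in the thread inside the factory above, here is a minimal, assumed sketch of how such a flag keeps the parent's gradient tests from running (ParentOpTest, Bf16Case and the guard itself are illustrative, not the actual TestConv2DOp code; the BF16 cases above run their own checks via user_defined_grads instead):

import unittest

class ParentOpTest(unittest.TestCase):
    def test_check_grad(self):
        # Assumed guard: bail out when a subclass marks grads as checked elsewhere.
        if getattr(self, "no_need_check_grad", False):
            self.skipTest("gradients checked by the subclass with user_defined_grads")
        # ... the normal float32/float64 gradient check would go here ...

class Bf16Case(ParentOpTest):
    no_need_check_grad = True  # mirrors self.no_need_check_grad = True in init_kernel_type

if __name__ == "__main__":
    unittest.main()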


def create_test_channel_last_class(parent):
class TestChannelLastCase(parent):
def init_data_format(self):
@@ -319,7 +366,15 @@ def setUp(self):
'dilation': self.dilations
}

input = np.random.random(self.input_size).astype(self.dtype)
if self.is_bfloat16_op():
input = np.random.random(self.input_size).astype(np.float32)
filter = np.random.uniform(-1, 1,
self.filter_size).astype(np.float32)
else:
input = np.random.random(self.input_size).astype(self.dtype)
filter = np.random.uniform(-1, 1,
self.filter_size).astype(self.dtype)

if not self.has_cuda():
self.fuse_relu_before_depthwise_conv = False
if self.fuse_relu_before_depthwise_conv:
@@ -329,16 +384,27 @@ def setUp(self):
input2 = np.maximum(input, 0.0)
else:
input2 = input
filter = np.random.uniform(-1, 1, self.filter_size).astype(self.dtype)

output, _, _, _, _ = conv2d_forward_naive(input2, filter, self.groups,
conv2d_param)
output = output.astype(self.dtype)

self.inputs = {
'Input': OpTest.np_dtype_to_fluid_dtype(input),
'Filter': OpTest.np_dtype_to_fluid_dtype(filter)
}
if self.is_bfloat16_op():
output = output.astype(np.float32)
self.inputs = {
'Input': convert_float_to_uint16(input),
'Filter': convert_float_to_uint16(filter)
}
self.inputs_fp32 = {
'Input': OpTest.np_dtype_to_fluid_dtype(input),
'Filter': OpTest.np_dtype_to_fluid_dtype(filter)
Contributor:

So an fp32 conv2d is also constructed here? Please describe the unit-test check logic in the PR description.

Contributor Author:

Done.

}
else:
output = output.astype(self.dtype)
self.inputs = {
'Input': OpTest.np_dtype_to_fluid_dtype(input),
'Filter': OpTest.np_dtype_to_fluid_dtype(filter)
}

self.attrs = {
'strides': self.stride,
'paddings': self.pad,
@@ -554,6 +620,15 @@ def init_group(self):
create_test_cudnn_fp16_class(TestWith1x1, grad_check=False)
create_test_cudnn_fp16_class(TestWithInput1x1Filter1x1, grad_check=False)

#----------------Conv2DCUDNN bf16----------------

create_test_cudnn_bf16_class(TestConv2DOp)
create_test_cudnn_bf16_class(TestWithPad)
create_test_cudnn_bf16_class(TestWithStride)
create_test_cudnn_bf16_class(TestWithGroup)
create_test_cudnn_bf16_class(TestWith1x1)
create_test_cudnn_bf16_class(TestWithInput1x1Filter1x1)

#----------------TestDepthwiseConv -----


Expand Down