-
Notifications
You must be signed in to change notification settings - Fork 6k
conv2d support bfloat16 #32221
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
conv2d support bfloat16 #32221
Changes from 6 commits
687e28b
252dbcf
747d096
8bab3d7
74bb02d
cd612c5
4069f78
7d3a4d5
c41fe74
f3ca4b8
5a3e730
f23f1d2
15f3315
394d8d4
12cc70b
62fcd51
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -131,6 +131,27 @@ inline ActivationMode StringToActivationMode(const std::string& str) { | |||||||||||
| template <typename T> | ||||||||||||
| class CudnnDataType; | ||||||||||||
|
|
||||||||||||
| template <> | ||||||||||||
| class CudnnDataType<bfloat16> { | ||||||||||||
| public: | ||||||||||||
| // CUDNN_DATA_BFLOAT16 is not valid before cudnn8.1 | ||||||||||||
| #if CUDNN_VERSION_MIN(8, 1, 0) | ||||||||||||
| static const cudnnDataType_t type = CUDNN_DATA_BFLOAT16; | ||||||||||||
| #else | ||||||||||||
| static const cudnnDataType_t type = CUDNN_DATA_HALF; | ||||||||||||
| #endif | ||||||||||||
|
||||||||||||
| if (input_data_type == framework::proto::VarType::FP16) { | |
| PADDLE_ENFORCE_EQ(library, framework::LibraryType::kCUDNN, | |
| platform::errors::InvalidArgument( | |
| "float16 can only be used when CUDNN is used")); | |
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -167,6 +167,37 @@ def test_check_grad_no_input(self): | |
| globals()[cls_name] = TestConv2DCUDNNFp16 | ||
|
|
||
|
|
||
| def create_test_cudnn_bf16_class(parent, grad_check=True): | ||
|
||
| @unittest.skipIf( | ||
| not core.is_compiled_with_cuda() or core.cudnn_version() < 8100, | ||
| "core is not compiled with CUDA and cudnn version need larger than 8.1.0" | ||
| ) | ||
| class TestConv2DCUDNNBF16(parent): | ||
| def init_kernel_type(self): | ||
| self.use_cudnn = True | ||
| self.dtype = np.uint16 | ||
|
|
||
| def test_check_output(self): | ||
| place = core.CUDAPlace(0) | ||
| self.check_output_with_place(place, atol=1e-2) | ||
|
|
||
| def test_check_grad_no_filter(self): | ||
| place = core.CUDAPlace(0) | ||
| if grad_check: | ||
| self.check_grad_with_place( | ||
| place, ['Input'], 'Output', no_grad_set=set(['Filter'])) | ||
|
|
||
| def test_check_grad_no_input(self): | ||
| place = core.CUDAPlace(0) | ||
| if grad_check: | ||
| self.check_grad_with_place( | ||
| place, ['Filter'], 'Output', no_grad_set=set(['Input'])) | ||
|
|
||
| cls_name = "{0}_{1}".format(parent.__name__, "CUDNNBF16") | ||
| TestConv2DCUDNNBF16.__name__ = cls_name | ||
| globals()[cls_name] = TestConv2DCUDNNBF16 | ||
|
|
||
|
|
||
| def create_test_channel_last_class(parent): | ||
| class TestChannelLastCase(parent): | ||
| def init_data_format(self): | ||
|
|
@@ -554,6 +585,15 @@ def init_group(self): | |
| create_test_cudnn_fp16_class(TestWith1x1, grad_check=False) | ||
| create_test_cudnn_fp16_class(TestWithInput1x1Filter1x1, grad_check=False) | ||
|
|
||
| #----------------Conv2DCUDNN bf16---------------- | ||
|
|
||
| create_test_cudnn_bf16_class(TestConv2DOp, grad_check=False) | ||
| create_test_cudnn_bf16_class(TestWithPad, grad_check=False) | ||
| create_test_cudnn_bf16_class(TestWithStride, grad_check=False) | ||
| create_test_cudnn_bf16_class(TestWithGroup, grad_check=False) | ||
| create_test_cudnn_bf16_class(TestWith1x1, grad_check=False) | ||
| create_test_cudnn_bf16_class(TestWithInput1x1Filter1x1, grad_check=False) | ||
|
||
|
|
||
| #----------------TestDepthwiseConv ----- | ||
|
|
||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个检查能不能放到一个公共的地方,比如
CudnnDataType<bfloat16>里面?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
CudnnDataType<bfloat16>里只能做编译期检查,这里直接改为cudnn8.1以下不添加bfloat16数据类型的Kernel。