Skip to content

Commit 0620f59

Browse files
authored
Merge pull request PaddlePaddle#33 from carlushuang/0.15.0_1b7e38
add fp16 support for conv/bn/pool/softmax
2 parents 119022a + 2aa8260 commit 0620f59

File tree

6 files changed

+20
-8
lines changed

6 files changed

+20
-8
lines changed

paddle/fluid/operators/batch_norm_op.cu.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,7 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
317317
namespace ops = paddle::operators;
318318
namespace plat = paddle::platform;
319319
REGISTER_OP_CUDA_KERNEL(
320-
batch_norm, ops::BatchNormKernel<plat::CUDADeviceContext, float>);
320+
batch_norm, ops::BatchNormKernel<plat::CUDADeviceContext, float>,
321+
ops::BatchNormKernel<plat::CUDADeviceContext, plat::float16>);
321322
REGISTER_OP_CUDA_KERNEL(
322323
batch_norm_grad, ops::BatchNormGradKernel<plat::CUDADeviceContext, float>);

paddle/fluid/operators/conv_cudnn_op.cu.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,8 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
326326

327327
namespace plat = paddle::platform;
328328
REGISTER_OP_KERNEL(conv2d, CUDNN, plat::CUDAPlace,
329-
paddle::operators::CUDNNConvOpKernel<float>);
329+
paddle::operators::CUDNNConvOpKernel<float>,
330+
paddle::operators::CUDNNConvOpKernel<plat::float16>);
330331
REGISTER_OP_KERNEL(conv2d_grad, CUDNN, plat::CUDAPlace,
331332
paddle::operators::CUDNNConvGradOpKernel<float>);
332333

paddle/fluid/operators/elementwise_add_op.cu

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,12 @@ namespace plat = paddle::platform;
2222
REGISTER_OP_CUDA_KERNEL(
2323
elementwise_add, ops::ElementwiseAddKernel<plat::CUDADeviceContext, float>,
2424
ops::ElementwiseAddKernel<plat::CUDADeviceContext, double>,
25-
ops::ElementwiseAddKernel<plat::CUDADeviceContext, int>);
25+
ops::ElementwiseAddKernel<plat::CUDADeviceContext, int>,
26+
ops::ElementwiseAddKernel<plat::CUDADeviceContext, int64_t>,
27+
ops::ElementwiseAddKernel<plat::CUDADeviceContext, paddle::platform::float16>);
2628
REGISTER_OP_CUDA_KERNEL(
2729
elementwise_add_grad,
2830
ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, float>,
2931
ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, double>,
30-
ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, int>);
32+
ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, int>,
33+
ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, int64_t>);

paddle/fluid/operators/math/softmax.cu

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,13 +90,16 @@ void SoftmaxGradCUDNNFunctor<T>::operator()(
9090
}
9191

9292
template class SoftmaxCUDNNFunctor<float>;
93+
template class SoftmaxCUDNNFunctor<platform::float16>;
9394
template class SoftmaxGradCUDNNFunctor<float>;
95+
template class SoftmaxGradCUDNNFunctor<platform::float16>;
9496

9597
template class SoftmaxFunctor<platform::CUDADeviceContext, platform::float16>;
9698
template class SoftmaxFunctor<platform::CUDADeviceContext, float>;
9799
template class SoftmaxFunctor<platform::CUDADeviceContext, double>;
98100
template class SoftmaxGradFunctor<platform::CUDADeviceContext, float>;
99101
template class SoftmaxGradFunctor<platform::CUDADeviceContext, double>;
102+
template class SoftmaxGradFunctor<platform::CUDADeviceContext, platform::float16>;
100103

101104
} // namespace math
102105
} // namespace operators

paddle/fluid/operators/pool_cudnn_op.cu.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -188,9 +188,11 @@ namespace ops = paddle::operators;
188188
namespace plat = paddle::platform;
189189

190190
REGISTER_OP_KERNEL(pool2d, CUDNN, plat::CUDAPlace,
191-
ops::PoolCUDNNOpKernel<float>);
191+
ops::PoolCUDNNOpKernel<float>,
192+
ops::PoolCUDNNOpKernel<plat::float16>);
192193
REGISTER_OP_KERNEL(pool2d_grad, CUDNN, plat::CUDAPlace,
193-
ops::PoolCUDNNGradOpKernel<float>);
194+
ops::PoolCUDNNGradOpKernel<float>,
195+
ops::PoolCUDNNGradOpKernel<plat::float16>);
194196

195197
REGISTER_OP_KERNEL(pool3d, CUDNN, plat::CUDAPlace,
196198
ops::PoolCUDNNOpKernel<float>);

paddle/fluid/operators/softmax_cudnn_op.cu.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@ class SoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
7575
namespace ops = paddle::operators;
7676
namespace plat = paddle::platform;
7777
REGISTER_OP_KERNEL(softmax, CUDNN, plat::CUDAPlace,
78-
ops::SoftmaxCUDNNKernel<float>);
78+
ops::SoftmaxCUDNNKernel<float>,
79+
ops::SoftmaxCUDNNKernel<plat::float16>);
7980
REGISTER_OP_KERNEL(softmax_grad, CUDNN, plat::CUDAPlace,
80-
ops::SoftmaxGradCUDNNKernel<float>);
81+
ops::SoftmaxGradCUDNNKernel<float>,
82+
ops::SoftmaxCUDNNKernel<plat::float16>);

0 commit comments

Comments
 (0)