Skip to content

Commit 447a6e1

Browse files
committed
support double in deformable conv
1 parent 572bad8 commit 447a6e1

File tree

3 files changed

+19
-6
lines changed

3 files changed

+19
-6
lines changed

paddle/fluid/operators/deformable_conv_v1_op.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,8 @@ REGISTER_OPERATOR(deformable_conv_v1, ops::DeformableConvV1Op,
307307
REGISTER_OPERATOR(deformable_conv_v1_grad, ops::DeformableConvV1GradOp);
308308

309309
REGISTER_OP_CPU_KERNEL(deformable_conv_v1,
310-
ops::DeformableConvV1CPUKernel<float>);
310+
ops::DeformableConvV1CPUKernel<float>,
311+
ops::DeformableConvV1CPUKernel<double>);
311312
REGISTER_OP_CPU_KERNEL(deformable_conv_v1_grad,
312-
ops::DeformableConvV1GradCPUKernel<float>);
313+
ops::DeformableConvV1GradCPUKernel<float>,
314+
ops::DeformableConvV1GradCPUKernel<double>);

paddle/fluid/operators/deformable_conv_v1_op.cu

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,8 @@ __global__ void DeformableCol2imCUDAKernel(
9999
DmcnGetGradientWeight(cur_inv_h_data, cur_inv_w_data, cur_h + dy,
100100
cur_w + dx, height, width);
101101

102-
atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
102+
platform::CudaAtomicAdd(grad_im + cur_bottom_grad_pos,
103+
weight * cur_top_grad);
103104
}
104105
}
105106
}
@@ -604,6 +605,8 @@ class DeformableConvV1GradCUDAKernel : public framework::OpKernel<T> {
604605
namespace ops = paddle::operators;
605606

606607
REGISTER_OP_CUDA_KERNEL(deformable_conv_v1,
607-
ops::DeformableConvV1CUDAKernel<float>);
608+
ops::DeformableConvV1CUDAKernel<float>,
609+
ops::DeformableConvV1CUDAKernel<double>);
608610
REGISTER_OP_CUDA_KERNEL(deformable_conv_v1_grad,
609-
ops::DeformableConvV1GradCUDAKernel<float>);
611+
ops::DeformableConvV1GradCUDAKernel<float>,
612+
ops::DeformableConvV1GradCUDAKernel<double>);

python/paddle/fluid/tests/unittests/test_deformable_conv_v1_op.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def dconv_im2col_gemm(input, offset, filter, group, conv_param):
108108
class TestModulatedDeformableConvOp(OpTest):
109109
def setUp(self):
110110
self.op_type = "deformable_conv_v1"
111-
self.dtype = np.float32
111+
self.init_type()
112112
self.init_group()
113113
self.init_dilation()
114114
self.init_test_case()
@@ -177,6 +177,9 @@ def init_dilation(self):
177177
def init_group(self):
178178
self.groups = 1
179179

180+
def init_type(self):
181+
self.dtype = np.float32
182+
180183

181184
class TestWithStride(TestModulatedDeformableConvOp):
182185
def init_test_case(self):
@@ -253,6 +256,11 @@ def init_group(self):
253256
self.groups = 2
254257

255258

259+
class TestWithDouble(TestModulatedDeformableConvOp):
260+
def init_type(self):
261+
self.dtype = np.float64
262+
263+
256264
class TestModulatedDeformableConvV1InvalidInput(unittest.TestCase):
257265
def test_error(self):
258266
def test_invalid_input():

0 commit comments

Comments
 (0)