diff --git a/paddle/fluid/inference/tensorrt/plugin/deformable_conv_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/deformable_conv_op_plugin.cu index 93c9c221b26392..978bb4f01e9f83 100644 --- a/paddle/fluid/inference/tensorrt/plugin/deformable_conv_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/deformable_conv_op_plugin.cu @@ -34,9 +34,9 @@ namespace plugin { static constexpr int kNumCUDAThreads = 512; static constexpr int kNumMaximumNumBlocks = 4096; -static inline int NumBlocks(const int N) { - return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads, - kNumMaximumNumBlocks); +static inline int NumBlocks(const int64_t N) { + return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads, + static_cast(kNumMaximumNumBlocks)); } static inline int ConvOutputSize( @@ -367,66 +367,66 @@ __device__ half DmcnIm2colBilinear(const half* bottom_data, template __global__ void ModulatedDeformableIm2colGpuKernel( - const int nthreads, + const int64_t nthreads, const T* data_im, const T* data_offset, const T* data_mask, - const int height, - const int width, - const int kernel_h, - const int kernel_w, - const int pad_h, - const int pad_w, - const int stride_h, - const int stride_w, - const int dilation_h, - const int dilation_w, - const int channel_per_deformable_group, - const int batch_size, - const int num_channels, - const int deformable_group, - const int height_col, - const int width_col, + const int64_t height, + const int64_t width, + const int64_t kernel_h, + const int64_t kernel_w, + const int64_t pad_h, + const int64_t pad_w, + const int64_t stride_h, + const int64_t stride_w, + const int64_t dilation_h, + const int64_t dilation_w, + const int64_t channel_per_deformable_group, + const int64_t batch_size, + const int64_t num_channels, + const int64_t deformable_group, + const int64_t height_col, + const int64_t width_col, T* data_col); template <> __global__ void ModulatedDeformableIm2colGpuKernel( - const int nthreads, + const int64_t nthreads, const float* data_im, const float* data_offset, const float* data_mask, - const int height, - const int width, - const int kernel_h, - const int kernel_w, - const int pad_h, - const int pad_w, - const int stride_h, - const int stride_w, - const int dilation_h, - const int dilation_w, - const int channel_per_deformable_group, - const int batch_size, - const int num_channels, - const int deformable_group, - const int height_col, - const int width_col, + const int64_t height, + const int64_t width, + const int64_t kernel_h, + const int64_t kernel_w, + const int64_t pad_h, + const int64_t pad_w, + const int64_t stride_h, + const int64_t stride_w, + const int64_t dilation_h, + const int64_t dilation_w, + const int64_t channel_per_deformable_group, + const int64_t batch_size, + const int64_t num_channels, + const int64_t deformable_group, + const int64_t height_col, + const int64_t width_col, float* data_col) { - int index = blockIdx.x * blockDim.x + threadIdx.x; - int offset = blockDim.x * gridDim.x; + int64_t index = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + int64_t offset = blockDim.x * static_cast(gridDim.x); float minus_one = -1.0f, height_t = height, width_t = width; - for (size_t i = index; i < nthreads; i += offset) { - const int w_col = i % width_col; - const int h_col = (i / width_col) % height_col; - const int b_col = (i / width_col) / height_col % batch_size; - const int c_im = (i / width_col / height_col) / batch_size; - const int c_col = c_im * kernel_h * kernel_w; + for (int64_t i = index; i < nthreads; i += offset) { + const int64_t w_col = i % width_col; + const int64_t h_col = (i / width_col) % height_col; + const int64_t b_col = (i / width_col) / height_col % batch_size; + const int64_t c_im = (i / width_col / height_col) / batch_size; + const int64_t c_col = c_im * kernel_h * kernel_w; - const int deformable_group_index = c_im / channel_per_deformable_group; + const int64_t deformable_group_index = c_im / channel_per_deformable_group; - const int h_in = h_col * stride_h - pad_h; - const int w_in = w_col * stride_w - pad_w; + const int64_t h_in = h_col * stride_h - pad_h; + const int64_t w_in = w_col * stride_w - pad_w; float* data_col_ptr = data_col + @@ -440,14 +440,14 @@ __global__ void ModulatedDeformableIm2colGpuKernel( data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; - for (int i = 0; i < kernel_h; ++i) { - for (int j = 0; j < kernel_w; ++j) { - const int data_offset_h_ptr = + for (int64_t i = 0; i < kernel_h; ++i) { + for (int64_t j = 0; j < kernel_w; ++j) { + const int64_t data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; - const int data_offset_w_ptr = + const int64_t data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; - const int data_mask_hw_ptr = + const int64_t data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col; const float offset_h = data_offset_ptr[data_offset_h_ptr]; @@ -471,43 +471,45 @@ __global__ void ModulatedDeformableIm2colGpuKernel( template <> __global__ void ModulatedDeformableIm2colGpuKernel( - const int nthreads, + const int64_t nthreads, const half* data_im, const half* data_offset, const half* data_mask, - const int height, - const int width, - const int kernel_h, - const int kernel_w, - const int pad_h, - const int pad_w, - const int stride_h, - const int stride_w, - const int dilation_h, - const int dilation_w, - const int channel_per_deformable_group, - const int batch_size, - const int num_channels, - const int deformable_group, - const int height_col, - const int width_col, + const int64_t height, + const int64_t width, + const int64_t kernel_h, + const int64_t kernel_w, + const int64_t pad_h, + const int64_t pad_w, + const int64_t stride_h, + const int64_t stride_w, + const int64_t dilation_h, + const int64_t dilation_w, + const int64_t channel_per_deformable_group, + const int64_t batch_size, + const int64_t num_channels, + const int64_t deformable_group, + const int64_t height_col, + const int64_t width_col, half* data_col) { #if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) - int index = blockIdx.x * blockDim.x + threadIdx.x; - int offset = blockDim.x * gridDim.x; + int64_t index = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + int64_t offset = blockDim.x * static_cast(gridDim.x); - half minus_one = -1.0f, height_t = height, width_t = width; + half minus_one = -1.0f, + height_t = static_cast(static_cast(height)), + width_t = static_cast(static_cast(width)); for (size_t i = index; i < nthreads; i += offset) { - const int w_col = i % width_col; - const int h_col = (i / width_col) % height_col; - const int b_col = (i / width_col) / height_col % batch_size; - const int c_im = (i / width_col / height_col) / batch_size; - const int c_col = c_im * kernel_h * kernel_w; + const int64_t w_col = i % width_col; + const int64_t h_col = (i / width_col) % height_col; + const int64_t b_col = (i / width_col) / height_col % batch_size; + const int64_t c_im = (i / width_col / height_col) / batch_size; + const int64_t c_col = c_im * kernel_h * kernel_w; - const int deformable_group_index = c_im / channel_per_deformable_group; + const int64_t deformable_group_index = c_im / channel_per_deformable_group; - const int h_in = h_col * stride_h - pad_h; - const int w_in = w_col * stride_w - pad_w; + const int64_t h_in = h_col * stride_h - pad_h; + const int64_t w_in = w_col * stride_w - pad_w; half* data_col_ptr = data_col + @@ -521,21 +523,22 @@ __global__ void ModulatedDeformableIm2colGpuKernel( data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; - for (int i = 0; i < kernel_h; ++i) { - for (int j = 0; j < kernel_w; ++j) { - const int data_offset_h_ptr = + for (int64_t i = 0; i < kernel_h; ++i) { + for (int64_t j = 0; j < kernel_w; ++j) { + const int64_t data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; - const int data_offset_w_ptr = + const int64_t data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; - const int data_mask_hw_ptr = + const int64_t data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col; const half offset_h = data_offset_ptr[data_offset_h_ptr]; const half offset_w = data_offset_ptr[data_offset_w_ptr]; const half mask = data_mask_ptr[data_mask_hw_ptr]; half val = 0; - half h_im_t = h_in + i * dilation_h, w_im_t = w_in + j * dilation_w; + half h_im_t = static_cast(h_in) + i * dilation_h, + w_im_t = static_cast(w_in) + j * dilation_w; const half h_im = h_im_t + offset_h; const half w_im = w_im_t + offset_w; if (h_im > minus_one && w_im > minus_one && h_im < height_t && diff --git a/paddle/phi/kernels/cpu/deformable_conv_grad_kernel.cc b/paddle/phi/kernels/cpu/deformable_conv_grad_kernel.cc index a8889d09aa757e..12df9d18d11465 100644 --- a/paddle/phi/kernels/cpu/deformable_conv_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/deformable_conv_grad_kernel.cc @@ -22,50 +22,50 @@ namespace phi { template inline void ModulatedDeformableCol2imCPUKernel( - const int num_kernels, + const int64_t num_kernels, const T* data_col, const T* data_offset, const T* data_mask, - const int channels, - const int height, - const int width, - const int kernel_h, - const int kernel_w, - const int pad_h, - const int pad_w, - const int stride_h, - const int stride_w, - const int dilation_h, - const int dilation_w, - const int channel_per_deformable_group, - const int batch_size, - const int deformable_group, - const int height_col, - const int width_col, + const int64_t channels, + const int64_t height, + const int64_t width, + const int64_t kernel_h, + const int64_t kernel_w, + const int64_t pad_h, + const int64_t pad_w, + const int64_t stride_h, + const int64_t stride_w, + const int64_t dilation_h, + const int64_t dilation_w, + const int64_t channel_per_deformable_group, + const int64_t batch_size, + const int64_t deformable_group, + const int64_t height_col, + const int64_t width_col, T* grad_im) { - for (int thread = 0; thread < num_kernels; thread++) { - const int j = (thread / width_col / height_col / batch_size) % kernel_w; - const int i = + for (int64_t thread = 0; thread < num_kernels; thread++) { + const int64_t j = (thread / width_col / height_col / batch_size) % kernel_w; + const int64_t i = (thread / width_col / height_col / batch_size / kernel_w) % kernel_h; - const int c = + const int64_t c = thread / width_col / height_col / batch_size / kernel_w / kernel_h; - const int deformable_group_index = c / channel_per_deformable_group; + const int64_t deformable_group_index = c / channel_per_deformable_group; - int w_out = thread % width_col; - int h_out = (thread / width_col) % height_col; - int b = (thread / width_col / height_col) % batch_size; - int w_in = w_out * stride_w - pad_w; - int h_in = h_out * stride_h - pad_h; + int64_t w_out = thread % width_col; + int64_t h_out = (thread / width_col) % height_col; + int64_t b = (thread / width_col / height_col) % batch_size; + int64_t w_in = w_out * stride_w - pad_w; + int64_t h_in = h_out * stride_h - pad_h; const T* data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; - const int data_offset_h_ptr = + const int64_t data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; - const int data_offset_w_ptr = + const int64_t data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; - const int data_mask_hw_ptr = + const int64_t data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out; const T offset_h = data_offset_ptr[data_offset_h_ptr]; const T offset_w = data_offset_ptr[data_offset_w_ptr]; @@ -80,14 +80,14 @@ inline void ModulatedDeformableCol2imCPUKernel( const T mask = data_mask_ptr[data_mask_hw_ptr]; cur_top_grad *= mask; } - const int cur_h = static_cast(cur_inv_h_data); - const int cur_w = static_cast(cur_inv_w_data); - for (int dy = -2; dy <= 2; dy++) { - for (int dx = -2; dx <= 2; dx++) { + const int64_t cur_h = static_cast(cur_inv_h_data); + const int64_t cur_w = static_cast(cur_inv_w_data); + for (int64_t dy = -2; dy <= 2; dy++) { + for (int64_t dx = -2; dx <= 2; dx++) { if (cur_h + dy >= 0 && cur_h + dy < height && cur_w + dx >= 0 && cur_w + dx < width && abs(cur_inv_h_data - (cur_h + dy)) < 1 && abs(cur_inv_w_data - (cur_w + dx)) < 1) { - int cur_bottom_grad_pos = + int64_t cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; T weight = DmcnGetGradientWeight(cur_inv_h_data, cur_inv_w_data, @@ -117,10 +117,10 @@ void ModulatedDeformableCol2im(const Context& dev_ctx, const std::vector& dilation, const int deformable_group, T* grad_im) { - int channel_per_deformable_group = - static_cast(im_shape[0] / deformable_group); - int num_kernels = static_cast(col_shape[0] * col_shape[1] * - col_shape[2] * col_shape[3]); + int64_t channel_per_deformable_group = + static_cast(im_shape[0] / deformable_group); + int64_t num_kernels = + col_shape[0] * col_shape[1] * col_shape[2] * col_shape[3]; ModulatedDeformableCol2imCPUKernel(num_kernels, data_col, @@ -147,40 +147,40 @@ void ModulatedDeformableCol2im(const Context& dev_ctx, template void ModulatedDeformableCol2imCoordCPUKernel( - const int num_kernels, + const int64_t num_kernels, const T* data_col, const T* data_im, const T* data_offset, const T* data_mask, - const int channels, - const int height, - const int width, - const int kernel_h, - const int kernel_w, - const int pad_h, - const int pad_w, - const int stride_h, - const int stride_w, - const int dilation_h, - const int dilation_w, - const int channel_per_deformable_group, - const int batch_size, - const int offset_channels, - const int deformable_group, - const int height_col, - const int width_col, + const int64_t channels, + const int64_t height, + const int64_t width, + const int64_t kernel_h, + const int64_t kernel_w, + const int64_t pad_h, + const int64_t pad_w, + const int64_t stride_h, + const int64_t stride_w, + const int64_t dilation_h, + const int64_t dilation_w, + const int64_t channel_per_deformable_group, + const int64_t batch_size, + const int64_t offset_channels, + const int64_t deformable_group, + const int64_t height_col, + const int64_t width_col, T* grad_offset, T* grad_mask) { - for (int i = 0; i < num_kernels; i++) { + for (int64_t i = 0; i < num_kernels; i++) { T val = 0, mval = 0; - const int w = i % width_col; - const int h = (i / width_col) % height_col; - const int c = (i / width_col / height_col) % offset_channels; - const int b = (i / width_col / height_col) / offset_channels; + const int64_t w = i % width_col; + const int64_t h = (i / width_col) % height_col; + const int64_t c = (i / width_col / height_col) % offset_channels; + const int64_t b = (i / width_col / height_col) / offset_channels; - const int deformable_group_index = c / (2 * kernel_h * kernel_w); - const int col_step = kernel_h * kernel_w; - int cnt = 0; + const int64_t deformable_group_index = c / (2 * kernel_h * kernel_w); + const int64_t col_step = kernel_h * kernel_w; + int64_t cnt = 0; const T* data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col; @@ -197,24 +197,25 @@ void ModulatedDeformableCol2imCoordCPUKernel( kernel_h * kernel_w * height_col * width_col : nullptr; - const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; + const int64_t offset_c = + c - deformable_group_index * 2 * kernel_h * kernel_w; - for (int col_c = offset_c / 2; col_c < channel_per_deformable_group; + for (int64_t col_c = offset_c / 2; col_c < channel_per_deformable_group; col_c += col_step) { - const int col_pos = + const int64_t col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w; - const int bp_dir = offset_c % 2; + const int64_t bp_dir = offset_c % 2; - int j = (col_pos / width_col / height_col / batch_size) % kernel_w; - int i = + int64_t j = (col_pos / width_col / height_col / batch_size) % kernel_w; + int64_t i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; - int w_out = col_pos % width_col; - int h_out = (col_pos / width_col) % height_col; - int w_in = w_out * stride_w - pad_w; - int h_in = h_out * stride_h - pad_h; - const int data_offset_h_ptr = + int64_t w_out = col_pos % width_col; + int64_t h_out = (col_pos / width_col) % height_col; + int64_t w_in = w_out * stride_w - pad_w; + int64_t h_in = h_out * stride_h - pad_h; + const int64_t data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); - const int data_offset_w_ptr = + const int64_t data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out); const T offset_h = data_offset_ptr[data_offset_h_ptr]; @@ -241,7 +242,7 @@ void ModulatedDeformableCol2imCoordCPUKernel( width, bp_dir); if (data_mask_ptr) { - const int data_mask_hw_ptr = + const int64_t data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out); const T mask = data_mask_ptr[data_mask_hw_ptr]; val += weight * data_col_ptr[col_pos] * mask; @@ -312,13 +313,13 @@ void ModulatedDeformableCol2imCoord(const Context& dev_ctx, template void FilterGradAddup(const Context& dev_ctx, - const int nthreads, - const int n, - const int height, - const int width, + const int64_t nthreads, + const int64_t n, + const int64_t height, + const int64_t width, const T* dweight_3d, T* filter_grad) { - for (int i = 0; i < nthreads; i++) { + for (int64_t i = 0; i < nthreads; i++) { filter_grad[i] = filter_grad[i] + dweight_3d[i]; } } diff --git a/paddle/phi/kernels/funcs/deformable_conv_functor.cc b/paddle/phi/kernels/funcs/deformable_conv_functor.cc index e028b51e3ce7b9..49016f148fb62b 100644 --- a/paddle/phi/kernels/funcs/deformable_conv_functor.cc +++ b/paddle/phi/kernels/funcs/deformable_conv_functor.cc @@ -20,38 +20,38 @@ namespace phi::funcs { template inline void ModulatedDeformableIm2colCPUKernel( - const int num_kernels, + const int64_t num_kernels, const T* data_im, const T* data_offset, const T* data_mask, - const int height, - const int width, - const int kernel_h, - const int kernel_w, - const int pad_h, - const int pad_w, - const int stride_h, - const int stride_w, - const int dilation_h, - const int dilation_w, - const int channel_per_deformable_group, - const int batch_size, - const int num_channels, - const int deformable_group, - const int height_col, - const int width_col, + const int64_t height, + const int64_t width, + const int64_t kernel_h, + const int64_t kernel_w, + const int64_t pad_h, + const int64_t pad_w, + const int64_t stride_h, + const int64_t stride_w, + const int64_t dilation_h, + const int64_t dilation_w, + const int64_t channel_per_deformable_group, + const int64_t batch_size, + const int64_t num_channels, + const int64_t deformable_group, + const int64_t height_col, + const int64_t width_col, T* data_col) { - for (int i = 0; i < num_kernels; i++) { - const int w_col = i % width_col; - const int h_col = (i / width_col) % height_col; - const int b_col = (i / width_col) / height_col % batch_size; - const int c_im = (i / width_col / height_col) / batch_size; - const int c_col = c_im * kernel_h * kernel_w; + for (int64_t i = 0; i < num_kernels; i++) { + const int64_t w_col = i % width_col; + const int64_t h_col = (i / width_col) % height_col; + const int64_t b_col = (i / width_col) / height_col % batch_size; + const int64_t c_im = (i / width_col / height_col) / batch_size; + const int64_t c_col = c_im * kernel_h * kernel_w; - const int deformable_group_index = c_im / channel_per_deformable_group; + const int64_t deformable_group_index = c_im / channel_per_deformable_group; - const int h_in = h_col * stride_h - pad_h; - const int w_in = w_col * stride_w - pad_w; + const int64_t h_in = h_col * stride_h - pad_h; + const int64_t w_in = w_col * stride_w - pad_w; T* data_col_ptr = data_col + @@ -67,11 +67,11 @@ inline void ModulatedDeformableIm2colCPUKernel( kernel_h * kernel_w * height_col * width_col : nullptr; - for (int i = 0; i < kernel_h; ++i) { - for (int j = 0; j < kernel_w; ++j) { - const int data_offset_h_ptr = + for (int64_t i = 0; i < kernel_h; ++i) { + for (int64_t j = 0; j < kernel_w; ++j) { + const int64_t data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; - const int data_offset_w_ptr = + const int64_t data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; @@ -86,7 +86,7 @@ inline void ModulatedDeformableIm2colCPUKernel( } *data_col_ptr = val; if (data_mask_ptr) { - const int data_mask_hw_ptr = + const int64_t data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col; const T mask = data_mask_ptr[data_mask_hw_ptr]; *data_col_ptr *= mask; @@ -110,10 +110,9 @@ void ModulatedDeformableIm2col(const Context& dev_ctx UNUSED, const std::vector& dilations, const int deformable_groups, T* data_col) { - int channel_per_deformable_group = - static_cast(im_shape[0] / deformable_groups); - int num_kernels = static_cast(im_shape[0] * col_shape[1] * col_shape[2] * - col_shape[3]); + int64_t channel_per_deformable_group = im_shape[0] / deformable_groups; + int64_t num_kernels = + im_shape[0] * col_shape[1] * col_shape[2] * col_shape[3]; // get outputs of im2col with offset by bilinear interpolation ModulatedDeformableIm2colCPUKernel(num_kernels, diff --git a/paddle/phi/kernels/funcs/deformable_conv_functor.cu b/paddle/phi/kernels/funcs/deformable_conv_functor.cu index 48105d1f517e9b..c1ca11eb33e5cc 100644 --- a/paddle/phi/kernels/funcs/deformable_conv_functor.cu +++ b/paddle/phi/kernels/funcs/deformable_conv_functor.cu @@ -21,47 +21,47 @@ namespace funcs { static constexpr int kNumCUDAThreads = 512; static constexpr int kNumMaximumNumBlocks = 4096; -static inline int NumBlocks(const int N) { - return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads, - kNumMaximumNumBlocks); +static inline int NumBlocks(const int64_t N) { + return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads, + static_cast(kNumMaximumNumBlocks)); } template __global__ void ModulatedDeformableIm2colGpuKernel( - const int nthreads, + const int64_t nthreads, const T* data_im, const T* data_offset, const T* data_mask, - const int height, - const int width, - const int kernel_h, - const int kernel_w, - const int pad_h, - const int pad_w, - const int stride_h, - const int stride_w, - const int dilation_h, - const int dilation_w, - const int channel_per_deformable_group, - const int batch_size, - const int num_channels, - const int deformable_group, - const int height_col, - const int width_col, + const int64_t height, + const int64_t width, + const int64_t kernel_h, + const int64_t kernel_w, + const int64_t pad_h, + const int64_t pad_w, + const int64_t stride_h, + const int64_t stride_w, + const int64_t dilation_h, + const int64_t dilation_w, + const int64_t channel_per_deformable_group, + const int64_t batch_size, + const int64_t num_channels, + const int64_t deformable_group, + const int64_t height_col, + const int64_t width_col, T* data_col) { - int index = blockIdx.x * blockDim.x + threadIdx.x; - int offset = blockDim.x * gridDim.x; + int64_t index = blockIdx.x * blockDim.x + threadIdx.x; + int64_t offset = blockDim.x * gridDim.x; for (size_t i = index; i < nthreads; i += offset) { - const int w_col = i % width_col; - const int h_col = (i / width_col) % height_col; - const int b_col = (i / width_col) / height_col % batch_size; - const int c_im = (i / width_col / height_col) / batch_size; - const int c_col = c_im * kernel_h * kernel_w; + const int64_t w_col = i % width_col; + const int64_t h_col = (i / width_col) % height_col; + const int64_t b_col = (i / width_col) / height_col % batch_size; + const int64_t c_im = (i / width_col / height_col) / batch_size; + const int64_t c_col = c_im * kernel_h * kernel_w; - const int deformable_group_index = c_im / channel_per_deformable_group; + const int64_t deformable_group_index = c_im / channel_per_deformable_group; - const int h_in = h_col * stride_h - pad_h; - const int w_in = w_col * stride_w - pad_w; + const int64_t h_in = h_col * stride_h - pad_h; + const int64_t w_in = w_col * stride_w - pad_w; T* data_col_ptr = data_col + @@ -77,11 +77,11 @@ __global__ void ModulatedDeformableIm2colGpuKernel( kernel_h * kernel_w * height_col * width_col : nullptr; - for (int i = 0; i < kernel_h; ++i) { - for (int j = 0; j < kernel_w; ++j) { - const int data_offset_h_ptr = + for (int64_t i = 0; i < kernel_h; ++i) { + for (int64_t j = 0; j < kernel_w; ++j) { + const int64_t data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; - const int data_offset_w_ptr = + const int64_t data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; @@ -96,7 +96,7 @@ __global__ void ModulatedDeformableIm2colGpuKernel( } *data_col_ptr = val; if (data_mask_ptr) { - const int data_mask_hw_ptr = + const int64_t data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col; const T mask = data_mask_ptr[data_mask_hw_ptr]; *data_col_ptr *= mask; @@ -120,11 +120,12 @@ void ModulatedDeformableIm2col(const Context& dev_ctx, const std::vector& dilations, const int deformable_groups, T* data_col) { - int channel_per_deformable_group = im_shape[0] / deformable_groups; - int num_kernels = im_shape[0] * col_shape[1] * col_shape[2] * col_shape[3]; + int64_t channel_per_deformable_group = im_shape[0] / deformable_groups; + int64_t num_kernels = + im_shape[0] * col_shape[1] * col_shape[2] * col_shape[3]; - int blocks = NumBlocks(num_kernels); - int threads = kNumCUDAThreads; + int64_t blocks = NumBlocks(num_kernels); + int64_t threads = kNumCUDAThreads; ModulatedDeformableIm2colGpuKernel <<>>(num_kernels, diff --git a/paddle/phi/kernels/gpu/deformable_conv_grad_kernel.cu b/paddle/phi/kernels/gpu/deformable_conv_grad_kernel.cu index 55c8a9f96fd818..c6e54b070a6ab5 100644 --- a/paddle/phi/kernels/gpu/deformable_conv_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/deformable_conv_grad_kernel.cu @@ -31,52 +31,52 @@ static inline int NumBlocks(const int N) { template __global__ void ModulatedDeformableCol2imGpuKernel( - const int nthreads, + const int64_t nthreads, const T* data_col, const T* data_offset, const T* data_mask, - const int channels, - const int height, - const int width, - const int kernel_h, - const int kernel_w, - const int pad_h, - const int pad_w, - const int stride_h, - const int stride_w, - const int dilation_h, - const int dilation_w, - const int channel_per_deformable_group, - const int batch_size, - const int deformable_group, - const int height_col, - const int width_col, + const int64_t channels, + const int64_t height, + const int64_t width, + const int64_t kernel_h, + const int64_t kernel_w, + const int64_t pad_h, + const int64_t pad_w, + const int64_t stride_h, + const int64_t stride_w, + const int64_t dilation_h, + const int64_t dilation_w, + const int64_t channel_per_deformable_group, + const int64_t batch_size, + const int64_t deformable_group, + const int64_t height_col, + const int64_t width_col, T* grad_im) { - int index = blockIdx.x * blockDim.x + threadIdx.x; - int offset = blockDim.x * gridDim.x; - for (size_t thread = index; thread < nthreads; thread += offset) { - const int j = (thread / width_col / height_col / batch_size) % kernel_w; - const int i = + int64_t index = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + int64_t offset = static_cast(blockDim.x) * gridDim.x; + for (int64_t thread = index; thread < nthreads; thread += offset) { + const int64_t j = (thread / width_col / height_col / batch_size) % kernel_w; + const int64_t i = (thread / width_col / height_col / batch_size / kernel_w) % kernel_h; - const int c = + const int64_t c = thread / width_col / height_col / batch_size / kernel_w / kernel_h; - const int deformable_group_index = c / channel_per_deformable_group; + const int64_t deformable_group_index = c / channel_per_deformable_group; - int w_out = thread % width_col; - int h_out = (thread / width_col) % height_col; - int b = (thread / width_col / height_col) % batch_size; - int w_in = w_out * stride_w - pad_w; - int h_in = h_out * stride_h - pad_h; + int64_t w_out = thread % width_col; + int64_t h_out = (thread / width_col) % height_col; + int64_t b = (thread / width_col / height_col) % batch_size; + int64_t w_in = w_out * stride_w - pad_w; + int64_t h_in = h_out * stride_h - pad_h; const T* data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; - const int data_offset_h_ptr = + const int64_t data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; - const int data_offset_w_ptr = + const int64_t data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; - const int data_mask_hw_ptr = + const int64_t data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out; const T offset_h = data_offset_ptr[data_offset_h_ptr]; const T offset_w = data_offset_ptr[data_offset_w_ptr]; @@ -91,14 +91,14 @@ __global__ void ModulatedDeformableCol2imGpuKernel( const T mask = data_mask_ptr[data_mask_hw_ptr]; cur_top_grad *= mask; } - const int cur_h = static_cast(cur_inv_h_data); - const int cur_w = static_cast(cur_inv_w_data); - for (int dy = -2; dy <= 2; dy++) { - for (int dx = -2; dx <= 2; dx++) { + const int64_t cur_h = static_cast(cur_inv_h_data); + const int64_t cur_w = static_cast(cur_inv_w_data); + for (int64_t dy = -2; dy <= 2; dy++) { + for (int64_t dx = -2; dx <= 2; dx++) { if (cur_h + dy >= 0 && cur_h + dy < height && cur_w + dx >= 0 && cur_w + dx < width && abs(cur_inv_h_data - (cur_h + dy)) < 1 && abs(cur_inv_w_data - (cur_w + dx)) < 1) { - int cur_bottom_grad_pos = + int64_t cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; T weight = DmcnGetGradientWeight(cur_inv_h_data, cur_inv_w_data, @@ -128,10 +128,11 @@ void ModulatedDeformableCol2im(const Context& dev_ctx, const std::vector& dilation, const int deformable_group, T* grad_im) { - int channel_per_deformable_group = im_shape[0] / deformable_group; - int num_kernels = col_shape[0] * col_shape[1] * col_shape[2] * col_shape[3]; - int blocks = NumBlocks(num_kernels); - int threads = kNumCUDAThreads; + int64_t channel_per_deformable_group = im_shape[0] / deformable_group; + int64_t num_kernels = + col_shape[0] * col_shape[1] * col_shape[2] * col_shape[3]; + int64_t blocks = NumBlocks(num_kernels); + int64_t threads = kNumCUDAThreads; ModulatedDeformableCol2imGpuKernel <<>>(num_kernels, @@ -159,42 +160,42 @@ void ModulatedDeformableCol2im(const Context& dev_ctx, template __global__ void ModulatedDeformableCol2imCoordGpuKernel( - const int nthreads, + const int64_t nthreads, const T* data_col, const T* data_im, const T* data_offset, const T* data_mask, - const int channels, - const int height, - const int width, - const int kernel_h, - const int kernel_w, - const int pad_h, - const int pad_w, - const int stride_h, - const int stride_w, - const int dilation_h, - const int dilation_w, - const int channel_per_deformable_group, - const int batch_size, - const int offset_channels, - const int deformable_group, - const int height_col, - const int width_col, + const int64_t channels, + const int64_t height, + const int64_t width, + const int64_t kernel_h, + const int64_t kernel_w, + const int64_t pad_h, + const int64_t pad_w, + const int64_t stride_h, + const int64_t stride_w, + const int64_t dilation_h, + const int64_t dilation_w, + const int64_t channel_per_deformable_group, + const int64_t batch_size, + const int64_t offset_channels, + const int64_t deformable_group, + const int64_t height_col, + const int64_t width_col, T* grad_offset, T* grad_mask) { - int index = blockIdx.x * blockDim.x + threadIdx.x; - int offset = blockDim.x * gridDim.x; - for (size_t i = index; i < nthreads; i += offset) { + int64_t index = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + int64_t offset = blockDim.x * static_cast(gridDim.x); + for (int64_t i = index; i < nthreads; i += offset) { T val = 0, mval = 0; - const int w = i % width_col; - const int h = (i / width_col) % height_col; - const int c = (i / width_col / height_col) % offset_channels; - const int b = (i / width_col / height_col) / offset_channels; + const int64_t w = i % width_col; + const int64_t h = (i / width_col) % height_col; + const int64_t c = (i / width_col / height_col) % offset_channels; + const int64_t b = (i / width_col / height_col) / offset_channels; - const int deformable_group_index = c / (2 * kernel_h * kernel_w); - const int col_step = kernel_h * kernel_w; - int cnt = 0; + const int64_t deformable_group_index = c / (2 * kernel_h * kernel_w); + const int64_t col_step = kernel_h * kernel_w; + int64_t cnt = 0; const T* data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col; @@ -211,24 +212,25 @@ __global__ void ModulatedDeformableCol2imCoordGpuKernel( kernel_h * kernel_w * height_col * width_col : nullptr; - const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; + const int64_t offset_c = + c - deformable_group_index * 2 * kernel_h * kernel_w; - for (int col_c = offset_c / 2; col_c < channel_per_deformable_group; + for (int64_t col_c = offset_c / 2; col_c < channel_per_deformable_group; col_c += col_step) { - const int col_pos = + const int64_t col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w; - const int bp_dir = offset_c % 2; + const int64_t bp_dir = offset_c % 2; - int j = (col_pos / width_col / height_col / batch_size) % kernel_w; - int i = + int64_t j = (col_pos / width_col / height_col / batch_size) % kernel_w; + int64_t i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; - int w_out = col_pos % width_col; - int h_out = (col_pos / width_col) % height_col; - int w_in = w_out * stride_w - pad_w; - int h_in = h_out * stride_h - pad_h; - const int data_offset_h_ptr = + int64_t w_out = col_pos % width_col; + int64_t h_out = (col_pos / width_col) % height_col; + int64_t w_in = w_out * stride_w - pad_w; + int64_t h_in = h_out * stride_h - pad_h; + const int64_t data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); - const int data_offset_w_ptr = + const int64_t data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out); const T offset_h = data_offset_ptr[data_offset_h_ptr]; @@ -255,7 +257,7 @@ __global__ void ModulatedDeformableCol2imCoordGpuKernel( width, bp_dir); if (data_mask_ptr) { - const int data_mask_hw_ptr = + const int64_t data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out); const T mask = data_mask_ptr[data_mask_hw_ptr]; val += weight * data_col_ptr[col_pos] * mask; @@ -291,11 +293,11 @@ void ModulatedDeformableCol2imCoord(const Context& dev_ctx, const int deformable_groups, T* grad_offset, T* grad_mask) { - int num_kernels = 2 * kernel_shape[2] * kernel_shape[3] * col_shape[1] * - col_shape[2] * col_shape[3] * deformable_groups; - int channel_per_deformable_group = col_shape[0] / deformable_groups; - int blocks = NumBlocks(num_kernels); - int threads = kNumCUDAThreads; + int64_t num_kernels = 2 * kernel_shape[2] * kernel_shape[3] * col_shape[1] * + col_shape[2] * col_shape[3] * deformable_groups; + int64_t channel_per_deformable_group = col_shape[0] / deformable_groups; + int64_t blocks = NumBlocks(num_kernels); + int64_t threads = kNumCUDAThreads; ModulatedDeformableCol2imCoordGpuKernel <<>>( @@ -326,29 +328,32 @@ void ModulatedDeformableCol2imCoord(const Context& dev_ctx, } template -__global__ void FilterGradAddupGpuKernel(const int nthreads, - const int n, - const int height, - const int width, +__global__ void FilterGradAddupGpuKernel(const int64_t nthreads, + const int64_t n, + const int64_t height, + const int64_t width, const T* dweight_3d, T* filter_grad) { - int index = blockIdx.x * blockDim.x + threadIdx.x; - int offset = blockDim.x * gridDim.x; - for (size_t i = index; i < nthreads; i += offset) { + int64_t index = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + int64_t offset = blockDim.x * static_cast(gridDim.x); + for (int64_t i = index; i < nthreads; i += offset) { filter_grad[i] = filter_grad[i] + dweight_3d[i]; } } template void FilterGradAddup(const Context& dev_ctx, - const int nthreads, - const int n, - const int height, - const int width, + const int64_t nthreads, + const int64_t n, + const int64_t height, + const int64_t width, const T* dweight_3d, T* filter_grad) { + const int64_t max_grid_x = dev_ctx.GetCUDAMaxGridDimSize()[0]; + const int64_t grid_size = std::min( + (nthreads + kNumCUDAThreads - 1) / kNumCUDAThreads, max_grid_x); FilterGradAddupGpuKernel - <<>>( + <<>>( nthreads, n, height, width, dweight_3d, filter_grad); } diff --git a/paddle/phi/kernels/impl/deformable_conv_grad_kernel_impl.h b/paddle/phi/kernels/impl/deformable_conv_grad_kernel_impl.h index fe2107e52af7f6..4b0f7e5e8a180c 100644 --- a/paddle/phi/kernels/impl/deformable_conv_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/deformable_conv_grad_kernel_impl.h @@ -151,10 +151,10 @@ void ModulatedDeformableCol2im(const Context& dev_ctx, template void FilterGradAddup(const Context& dev_ctx, - const int nthreads, - const int n, - const int height, - const int width, + const int64_t nthreads, + const int64_t n, + const int64_t height, + const int64_t width, const T* dweight_3d, T* filter_grad); @@ -241,9 +241,9 @@ void DeformableConvGradKernel(const Context& dev_ctx, phi::funcs::SetConstant set_zero; auto blas = phi::funcs::GetBlas(dev_ctx); - int input_dim = x.numel() / x.dims()[0]; - int input_offset_dim = offset.numel() / offset.dims()[0]; - int input_mask_dim = mask ? mask->numel() / mask->dims()[0] : 0; + int64_t input_dim = x.numel() / x.dims()[0]; + int64_t input_offset_dim = offset.numel() / offset.dims()[0]; + int64_t input_mask_dim = mask ? mask->numel() / mask->dims()[0] : 0; if (filter_grad) { Full(dev_ctx, diff --git a/paddle/phi/kernels/impl/deformable_conv_kernel_impl.h b/paddle/phi/kernels/impl/deformable_conv_kernel_impl.h index 8fcf4bf0f38700..43d68a85380f3f 100644 --- a/paddle/phi/kernels/impl/deformable_conv_kernel_impl.h +++ b/paddle/phi/kernels/impl/deformable_conv_kernel_impl.h @@ -38,14 +38,15 @@ void DeformableConvKernel(const Context& dev_ctx, int groups, int im2col_step, DenseTensor* out) { + const int64_t batch_size = static_cast(x.dims()[0]); + if (x.numel() == 0 || filter.numel() == 0) { phi::Full( dev_ctx, phi::IntArray(common::vectorize(out->dims())), 0, out); return; } - const int batch_size = static_cast(x.dims()[0]); - int temp_step = std::min(64, batch_size); + int64_t temp_step = std::min(64, batch_size); if (batch_size % temp_step == 0) { im2col_step = temp_step; } @@ -86,9 +87,9 @@ void DeformableConvKernel(const Context& dev_ctx, DDim input_shape = common::slice_ddim(x.dims(), 1, x.dims().size()); std::vector input_shape_vec = common::vectorize(input_shape); - int input_dim = x.numel() / x.dims()[0]; - int input_offset_dim = offset.numel() / offset.dims()[0]; - int input_mask_dim = mask ? mask->numel() / mask->dims()[0] : 0; + int64_t input_dim = x.numel() / x.dims()[0]; + int64_t input_offset_dim = offset.numel() / offset.dims()[0]; + int64_t input_mask_dim = mask ? mask->numel() / mask->dims()[0] : 0; const T* input_ptr = x.data(); const T* offset_ptr = offset.data(); @@ -97,7 +98,7 @@ void DeformableConvKernel(const Context& dev_ctx, auto blas = phi::funcs::GetBlas(dev_ctx); - for (int i = 0; i < batch_size / im2col_step; ++i) { + for (int64_t i = 0; i < batch_size / im2col_step; ++i) { const T* temp_mask_ptr = mask_ptr ? mask_ptr + i * im2col_step * input_mask_dim : nullptr; funcs::ModulatedDeformableIm2col( @@ -139,7 +140,6 @@ void DeformableConvKernel(const Context& dev_ctx, T(0.0)); } } - // swap axis to get the right result when im2col_step is greater than 1 if (im2col_step > 1) { std::vector axis(4);