Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 90 additions & 87 deletions paddle/fluid/inference/tensorrt/plugin/deformable_conv_op_plugin.cu
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ namespace plugin {
static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaximumNumBlocks = 4096;

// Returns the 1-D grid size (number of thread blocks) used to launch a kernel
// that processes N work items with kNumCUDAThreads threads per block.
//
// The result is ceil(N / kNumCUDAThreads), capped at kNumMaximumNumBlocks and
// never smaller than 1:
//   * N is 64-bit so element counts above INT_MAX are not truncated before
//     the block count is computed.
//   * A non-positive N yields 1 instead of 0 (or a negative value), because a
//     launch with a zero-sized grid is rejected as an invalid configuration.
//     The kernels in this file iterate with a grid-stride loop bounded by the
//     element count, so a single block with nothing to do is a no-op.
static inline int NumBlocks(const int64_t N) {
  if (N <= 0) {
    return 1;
  }
  // Ceil-div written without `N + kNumCUDAThreads - 1` so that it cannot
  // overflow for N close to INT64_MAX.
  const int64_t blocks =
      N / kNumCUDAThreads + (N % kNumCUDAThreads != 0 ? 1 : 0);
  // The cap keeps the value well inside the range of int, so the narrowing
  // cast is lossless.
  return static_cast<int>(std::min<int64_t>(
      blocks, static_cast<int64_t>(kNumMaximumNumBlocks)));
}

static inline int ConvOutputSize(
Expand Down Expand Up @@ -367,66 +367,66 @@ __device__ half DmcnIm2colBilinear<half>(const half* bottom_data,

// Primary template of the modulated deformable im2col kernel. Only the
// declaration lives here; the kernel is explicitly specialized for `float`
// and `half` further down in this file.
//
// Launch layout: the specializations compute a 1-D global index from
// blockIdx.x / blockDim.x / threadIdx.x and advance it by
// blockDim.x * gridDim.x until it reaches `nthreads` (a grid-stride loop), so
// any 1-D grid size is valid.
//
// All extents are int64_t so that offsets such as
//   (b_col * deformable_group + group) * kernel_h * kernel_w *
//       height_col * width_col
// are evaluated in 64-bit arithmetic and do not wrap for large tensors.
//
// Parameters:
//   nthreads      Upper bound of the per-thread loop. Each loop index is
//                 decoded, innermost first, as
//                 (w_col, h_col, b_col, c_im) using width_col, height_col and
//                 batch_size.
//   data_im       Source image data (read-only).
//   data_offset   Sampling offsets (read-only). For kernel tap k = i *
//                 kernel_w + j the h-offset is read from plane 2k and the
//                 w-offset from plane 2k + 1, each plane being
//                 height_col x width_col.
//   data_mask     Modulation mask (read-only), one height_col x width_col
//                 plane per kernel tap.
//   height/width  Input extents; a sampling position is only used when it is
//                 greater than -1 and below these bounds.
//   kernel_h/w    Number of kernel taps along each axis.
//   pad_h/w, stride_h/w
//                 The unshifted input origin of an output position is
//                 h_col * stride_h - pad_h (and likewise for w).
//   dilation_h/w  Spacing between kernel taps, added as i * dilation_h
//                 (j * dilation_w) to that origin.
//   channel_per_deformable_group
//                 c_im / channel_per_deformable_group selects the deformable
//                 group whose offsets and mask are used.
//   batch_size    Batch extent used when decoding the loop index.
//   num_channels  Channel count; not referenced in the visible portions of
//                 the specializations -- TODO confirm whether it is used.
//   deformable_group
//                 Number of deformable groups per batch item; used to locate
//                 the offset/mask planes of (b_col, group).
//   height_col/width_col
//                 Output (column buffer) spatial extents.
//   data_col      Output column buffer (written).
template <typename T>
__global__ void ModulatedDeformableIm2colGpuKernel(
    const int64_t nthreads,
    const T* data_im,
    const T* data_offset,
    const T* data_mask,
    const int64_t height,
    const int64_t width,
    const int64_t kernel_h,
    const int64_t kernel_w,
    const int64_t pad_h,
    const int64_t pad_w,
    const int64_t stride_h,
    const int64_t stride_w,
    const int64_t dilation_h,
    const int64_t dilation_w,
    const int64_t channel_per_deformable_group,
    const int64_t batch_size,
    const int64_t num_channels,
    const int64_t deformable_group,
    const int64_t height_col,
    const int64_t width_col,
    T* data_col);

template <>
__global__ void ModulatedDeformableIm2colGpuKernel<float>(
const int nthreads,
const int64_t nthreads,
const float* data_im,
const float* data_offset,
const float* data_mask,
const int height,
const int width,
const int kernel_h,
const int kernel_w,
const int pad_h,
const int pad_w,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int channel_per_deformable_group,
const int batch_size,
const int num_channels,
const int deformable_group,
const int height_col,
const int width_col,
const int64_t height,
const int64_t width,
const int64_t kernel_h,
const int64_t kernel_w,
const int64_t pad_h,
const int64_t pad_w,
const int64_t stride_h,
const int64_t stride_w,
const int64_t dilation_h,
const int64_t dilation_w,
const int64_t channel_per_deformable_group,
const int64_t batch_size,
const int64_t num_channels,
const int64_t deformable_group,
const int64_t height_col,
const int64_t width_col,
float* data_col) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
int64_t index = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
int64_t offset = blockDim.x * static_cast<int64_t>(gridDim.x);

float minus_one = -1.0f, height_t = height, width_t = width;
for (size_t i = index; i < nthreads; i += offset) {
const int w_col = i % width_col;
const int h_col = (i / width_col) % height_col;
const int b_col = (i / width_col) / height_col % batch_size;
const int c_im = (i / width_col / height_col) / batch_size;
const int c_col = c_im * kernel_h * kernel_w;
for (int64_t i = index; i < nthreads; i += offset) {
const int64_t w_col = i % width_col;
const int64_t h_col = (i / width_col) % height_col;
const int64_t b_col = (i / width_col) / height_col % batch_size;
const int64_t c_im = (i / width_col / height_col) / batch_size;
const int64_t c_col = c_im * kernel_h * kernel_w;

const int deformable_group_index = c_im / channel_per_deformable_group;
const int64_t deformable_group_index = c_im / channel_per_deformable_group;

const int h_in = h_col * stride_h - pad_h;
const int w_in = w_col * stride_w - pad_w;
const int64_t h_in = h_col * stride_h - pad_h;
const int64_t w_in = w_col * stride_w - pad_w;

float* data_col_ptr =
data_col +
Expand All @@ -440,14 +440,14 @@ __global__ void ModulatedDeformableIm2colGpuKernel<float>(
data_mask + (b_col * deformable_group + deformable_group_index) *
kernel_h * kernel_w * height_col * width_col;

for (int i = 0; i < kernel_h; ++i) {
for (int j = 0; j < kernel_w; ++j) {
const int data_offset_h_ptr =
for (int64_t i = 0; i < kernel_h; ++i) {
for (int64_t j = 0; j < kernel_w; ++j) {
const int64_t data_offset_h_ptr =
((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
const int data_offset_w_ptr =
const int64_t data_offset_w_ptr =
((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col +
w_col;
const int data_mask_hw_ptr =
const int64_t data_mask_hw_ptr =
((i * kernel_w + j) * height_col + h_col) * width_col + w_col;

const float offset_h = data_offset_ptr[data_offset_h_ptr];
Expand All @@ -471,43 +471,45 @@ __global__ void ModulatedDeformableIm2colGpuKernel<float>(

template <>
__global__ void ModulatedDeformableIm2colGpuKernel<half>(
const int nthreads,
const int64_t nthreads,
const half* data_im,
const half* data_offset,
const half* data_mask,
const int height,
const int width,
const int kernel_h,
const int kernel_w,
const int pad_h,
const int pad_w,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int channel_per_deformable_group,
const int batch_size,
const int num_channels,
const int deformable_group,
const int height_col,
const int width_col,
const int64_t height,
const int64_t width,
const int64_t kernel_h,
const int64_t kernel_w,
const int64_t pad_h,
const int64_t pad_w,
const int64_t stride_h,
const int64_t stride_w,
const int64_t dilation_h,
const int64_t dilation_w,
const int64_t channel_per_deformable_group,
const int64_t batch_size,
const int64_t num_channels,
const int64_t deformable_group,
const int64_t height_col,
const int64_t width_col,
half* data_col) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
int64_t index = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
int64_t offset = blockDim.x * static_cast<int64_t>(gridDim.x);

half minus_one = -1.0f, height_t = height, width_t = width;
half minus_one = -1.0f,
height_t = static_cast<half>(static_cast<float>(height)),
width_t = static_cast<half>(static_cast<float>(width));
for (size_t i = index; i < nthreads; i += offset) {
const int w_col = i % width_col;
const int h_col = (i / width_col) % height_col;
const int b_col = (i / width_col) / height_col % batch_size;
const int c_im = (i / width_col / height_col) / batch_size;
const int c_col = c_im * kernel_h * kernel_w;
const int64_t w_col = i % width_col;
const int64_t h_col = (i / width_col) % height_col;
const int64_t b_col = (i / width_col) / height_col % batch_size;
const int64_t c_im = (i / width_col / height_col) / batch_size;
const int64_t c_col = c_im * kernel_h * kernel_w;

const int deformable_group_index = c_im / channel_per_deformable_group;
const int64_t deformable_group_index = c_im / channel_per_deformable_group;

const int h_in = h_col * stride_h - pad_h;
const int w_in = w_col * stride_w - pad_w;
const int64_t h_in = h_col * stride_h - pad_h;
const int64_t w_in = w_col * stride_w - pad_w;

half* data_col_ptr =
data_col +
Expand All @@ -521,21 +523,22 @@ __global__ void ModulatedDeformableIm2colGpuKernel<half>(
data_mask + (b_col * deformable_group + deformable_group_index) *
kernel_h * kernel_w * height_col * width_col;

for (int i = 0; i < kernel_h; ++i) {
for (int j = 0; j < kernel_w; ++j) {
const int data_offset_h_ptr =
for (int64_t i = 0; i < kernel_h; ++i) {
for (int64_t j = 0; j < kernel_w; ++j) {
const int64_t data_offset_h_ptr =
((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
const int data_offset_w_ptr =
const int64_t data_offset_w_ptr =
((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col +
w_col;
const int data_mask_hw_ptr =
const int64_t data_mask_hw_ptr =
((i * kernel_w + j) * height_col + h_col) * width_col + w_col;

const half offset_h = data_offset_ptr[data_offset_h_ptr];
const half offset_w = data_offset_ptr[data_offset_w_ptr];
const half mask = data_mask_ptr[data_mask_hw_ptr];
half val = 0;
half h_im_t = h_in + i * dilation_h, w_im_t = w_in + j * dilation_w;
half h_im_t = static_cast<float>(h_in) + i * dilation_h,
w_im_t = static_cast<float>(w_in) + j * dilation_w;
const half h_im = h_im_t + offset_h;
const half w_im = w_im_t + offset_w;
if (h_im > minus_one && w_im > minus_one && h_im < height_t &&
Expand Down
Loading