Speedup roi_perspective_transform op by caching the information of linear interpolation in forward #17090
Changes to the operator's CUDA kernel:

```diff
@@ -14,6 +14,7 @@ limitations under the License. */
 #include <algorithm>
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
 #include "paddle/fluid/platform/float16.h"
```
```diff
@@ -115,8 +116,9 @@ __device__ bool in_quad(T x, T y, T roi_x[], T roi_y[]) {
 template <typename T>
 __device__ void bilinear_interpolate(const T* in_data, const int channels,
                                      const int width, const int height,
-                                     int in_n, int in_c, T in_w, T in_h,
-                                     T* val) {
+                                     int in_n, int in_c, T in_w, T in_h, T* val,
+                                     int out_idx, int* out2in_idx,
+                                     T* out2in_w) {
   // Deal with cases that source coords are out of feature map boundary
   if (GT<T>(-0.5, in_w) || GT<T>(in_w, width - 0.5) || GT<T>(-0.5, in_h) ||
       GT<T>(in_h, height - 0.5)) {
```
```diff
@@ -165,6 +167,16 @@ __device__ void bilinear_interpolate(const T* in_data, const int channels,
   T w3 = w_floor * h_floor;
   T w4 = w_floor * h_ceil;
   val[0] = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4;
+
+  int base_idx = (in_n * channels + in_c) * height * width;
+  out2in_idx[out_idx * 4] = base_idx + in_h_floor * width + in_w_floor;
+  out2in_idx[out_idx * 4 + 1] = base_idx + in_h_ceil * width + in_w_floor;
+  out2in_idx[out_idx * 4 + 2] = base_idx + in_h_ceil * width + in_w_ceil;
+  out2in_idx[out_idx * 4 + 3] = base_idx + in_h_floor * width + in_w_ceil;
+  out2in_w[out_idx * 4] = w1;
+  out2in_w[out_idx * 4 + 1] = w2;
+  out2in_w[out_idx * 4 + 2] = w3;
+  out2in_w[out_idx * 4 + 3] = w4;
 }

 /**
```
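These ten added lines are the core of the change: while computing the interpolated value, the forward pass now also records, for every output element, the flat indices of its four bilinear source pixels and their weights. As a way to read the cache layout, here is a minimal host-side sketch (illustrative only, not part of the PR; the function name is made up) that replays the forward interpolation purely from the cached mapping:

```cpp
// Replay bilinear interpolation from the cache alone: for output element i,
// out[i] = sum over k of in[out2in_idx[i * 4 + k]] * out2in_w[i * 4 + k].
// Slots the kernel never wrote keep the -1 they were initialized with
// (see the Compute() changes below) and are skipped.
template <typename T>
void ReplayForwardFromCache(const T* in, const int* out2in_idx,
                            const T* out2in_w, int out_size, T* out) {
  for (int i = 0; i < out_size; ++i) {
    T acc = 0;
    for (int k = 0; k < 4; ++k) {
      int idx = out2in_idx[i * 4 + k];
      if (idx >= 0) acc += in[idx] * out2in_w[i * 4 + k];
    }
    out[i] = acc;
  }
}
```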
```diff
@@ -262,13 +274,11 @@ __device__ void get_transform_matrix(const int transformed_width,
 }

 template <typename T>
-__global__ void RoiTransformKernel(const float* input_data,
-                                   const float* rois_data,
-                                   const int* roi2image_data, int num_rois,
-                                   int in_height, int in_width, int channels,
-                                   int transformed_height,
-                                   int transformed_width, float spatial_scale,
-                                   T* output_data) {
+__global__ void RoiTransformKernel(
+    const float* input_data, const float* rois_data, const int* roi2image_data,
+    int num_rois, int in_height, int in_width, int channels,
+    int transformed_height, int transformed_width, float spatial_scale,
+    T* output_data, int* out2in_idx, T* out2in_w) {
   int output_size =
       num_rois * transformed_height * transformed_width * channels;
```
```diff
@@ -311,7 +321,8 @@ __global__ void RoiTransformKernel(const float* input_data,
         // Perform bilinear interpolation
         int in_n = roi2image_data[n];
         bilinear_interpolate<T>(input_data, channels, in_width, in_height, in_n,
-                                c, in_w, in_h, output_data + index);
+                                c, in_w, in_h, output_data + index, index,
+                                out2in_idx, out2in_w);
       }

     } else {
```
```diff
@@ -328,6 +339,16 @@ class CUDAROIPerspectiveTransformOpKernel : public framework::OpKernel<T> {
     auto* in = ctx.Input<framework::Tensor>("X");
     auto* rois = ctx.Input<framework::LoDTensor>("ROIs");
     auto* out = ctx.Output<framework::Tensor>("Out");
+    auto* out2in_idx = ctx.Output<framework::Tensor>("Out2InIdx");
+    auto* out2in_w = ctx.Output<framework::Tensor>("Out2InWeights");
+
+    int* out2in_idx_data =
+        out2in_idx->mutable_data<int>({out->numel(), 4}, ctx.GetPlace());
+    T* out2in_w_data =
+        out2in_w->mutable_data<T>({out->numel(), 4}, ctx.GetPlace());
+
+    math::SetConstant<platform::CUDADeviceContext, int> init;
+    init(ctx.cuda_device_context(), out2in_idx, static_cast<int>(-1));

     auto transformed_height = ctx.Attr<int>("transformed_height");
     auto transformed_width = ctx.Attr<int>("transformed_width");
```
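Both cache tensors are allocated with shape {out->numel(), 4}, and only the index tensor is filled with -1 (the closing review thread explains why the weights can stay uninitialized). The speedup therefore trades GPU memory for compute; a back-of-envelope sketch of that cost, with purely assumed sizes:

```cpp
#include <cstdio>

// Hypothetical workload (all numbers assumed for illustration): 256 RoIs,
// 256 channels, 7x7 transformed output, float32 data. Every output element
// caches four (int index, float weight) pairs.
int main() {
  long long out_numel = 256LL * 256 * 7 * 7;  // num_rois * channels * h * w
  long long bytes = out_numel * 4 * (sizeof(int) + sizeof(float));
  std::printf("cache overhead: %.1f MB\n", bytes / (1024.0 * 1024.0));  // ~98 MB
  return 0;
}
```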
```diff
@@ -364,7 +385,7 @@ class CUDAROIPerspectiveTransformOpKernel : public framework::OpKernel<T> {
     RoiTransformKernel<T><<<grid, block, 0, stream>>>(
         input_data, rois_data, roi2image_dev.data<int>(), rois_num, in_height,
         in_width, channels, transformed_height, transformed_width,
-        spatial_scale, output_data);
+        spatial_scale, output_data, out2in_idx_data, out2in_w_data);
   }
 };
```
```diff
@@ -420,100 +441,42 @@ __device__ T get_feature_gradient(T xs, T ys, int w, int h, const int width,
 }

 template <typename T>
-__global__ void RoiTransformGradKernel(
-    const size_t* lod, const T* rois_data, int batch_size, int num_rois,
-    int in_height, int in_width, int channels, int transformed_height,
-    int transformed_width, float spatial_scale, const T* out_grad_data,
-    T* in_grad_data) {
-  int input_size = batch_size * in_height * in_width * channels;
-
-  CUDA_1D_KERNEL_LOOP(index, input_size) {
-    // (n, c, h, w) coords in input
-    int in_w = idx4_4(index, batch_size, channels, in_height, in_width);
-    int in_h = idx4_3(index, batch_size, channels, in_height, in_width);
-    int c = idx4_2(index, batch_size, channels, in_height, in_width);
-    int n = idx4_1(index, batch_size, channels, in_height, in_width);
-
-    T gradient = 0.0;
-    // Accumulate gradient over all RoIs that interpolated this element
-    for (size_t roi_idx = lod[n]; roi_idx < lod[n + 1]; ++roi_idx) {
-      const T* rois = rois_data + roi_idx * 8;
-      T roi_x[4];
-      T roi_y[4];
-      for (int k = 0; k < 4; ++k) {
-        roi_x[k] = rois[2 * k] * spatial_scale;
-        roi_y[k] = rois[2 * k + 1] * spatial_scale;
-      }
-
-      // Get transform matrix
-      T matrix[9];
-      get_transform_matrix<T>(transformed_width, transformed_height, roi_x,
-                              roi_y, matrix);
-
-      const T* out_grad_ptr =
-          out_grad_data +
-          (roi_idx * channels + c) * transformed_height * transformed_width;
-      for (int out_h = 0; out_h < transformed_height; ++out_h) {
-        for (int out_w = 0; out_w < transformed_width; ++out_w) {
-          T src_w;
-          T src_h;
-          get_source_coords<T>(matrix, out_w, out_h, &src_w, &src_h);
-          if (in_quad<T>(src_w, src_h, roi_x, roi_y)) {
-            if (GT<T>(-0.5, src_w) ||
-                GT<T>(src_w, static_cast<T>(in_width - 0.5)) ||
-                GT<T>(-0.5, src_h) ||
-                GT<T>(src_h, static_cast<T>(in_height - 0.5))) {
-              continue;
-            }
-            T weight = get_feature_gradient<T>(src_w, src_h, in_w, in_h,
-                                               in_width, in_height);
-            gradient +=
-                out_grad_ptr[out_h * transformed_width + out_w] * weight;
-          }
-        }
-      }
-    }
-    in_grad_data[index] = gradient;
-  }
-}
+__global__ void RoiTransformGradKernel(int out_size, const int* out2in_idx_data,
+                                       const T* out2in_w_data,
+                                       const T* out_grad_data,
+                                       T* in_grad_data) {
+  CUDA_1D_KERNEL_LOOP(index, out_size * 4) {
+    int in_idx = out2in_idx_data[index];
+    if (in_idx > 0) {
+      int out_idx = index / 4;
+      atomicAdd(in_grad_data + in_idx,
+                out_grad_data[out_idx] * out2in_w_data[index]);
+    }
+  }
+}
```

Inline review on `if (in_idx > 0) {`:

Contributor: should be `>=` here?

Contributor (Author): Thx. Fixed.
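The rewritten backward kernel never touches the RoIs or the transform matrix again: it loops over the out_size * 4 cached entries and scatters each output gradient, scaled by the cached weight, back to its source pixel. Several entries can point at the same input element, so the GPU writes must use atomicAdd. A sequential CPU sketch of the same scatter (illustrative, not the PR's code; it assumes in_grad starts zero-filled and uses the `>= 0` guard from the review fix above):

```cpp
// Sequential equivalent of the new RoiTransformGradKernel: one pass over all
// cache entries, accumulating out_grad[out_idx] * weight into the source.
template <typename T>
void RoiTransformGradCPU(int out_size, const int* out2in_idx,
                         const T* out2in_w, const T* out_grad, T* in_grad) {
  for (int index = 0; index < out_size * 4; ++index) {
    int in_idx = out2in_idx[index];
    if (in_idx >= 0) {  // entries still at the -1 init value are skipped
      in_grad[in_idx] += out_grad[index / 4] * out2in_w[index];
    }
  }
}
```

The sequential form also shows why the reviewer's `>=` catch matters: with `> 0`, any gradient destined for input element 0 would be silently dropped.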
```diff
 template <typename T>
 class CUDAROIPerspectiveTransformGradOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    auto* in = ctx.Input<framework::Tensor>("X");
-    auto* rois = ctx.Input<framework::LoDTensor>("ROIs");
+    auto* out2in_idx = ctx.Input<framework::LoDTensor>("Out2InIdx");
+    auto* out2in_w = ctx.Input<framework::LoDTensor>("Out2InWeights");
     auto* out_grad =
         ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
     auto* in_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));

-    auto transformed_height = ctx.Attr<int>("transformed_height");
-    auto transformed_width = ctx.Attr<int>("transformed_width");
-    auto spatial_scale = ctx.Attr<float>("spatial_scale");
-
-    auto in_dims = in->dims();
-    int batch_size = in_dims[0];
-    int channels = in_dims[1];
-    int in_height = in_dims[2];
-    int in_width = in_dims[3];
-    int rois_num = rois->dims()[0];
-
     T* in_grad_data = in_grad->mutable_data<T>(ctx.GetPlace());
     const T* out_grad_data = out_grad->data<T>();
-    const T* rois_data = rois->data<T>();
-
-    auto lod = rois->lod().back();
-    auto lod_data = lod.CUDAData(ctx.GetPlace());
+    const int* out2in_idx_data = out2in_idx->data<int>();
+    const T* out2in_w_data = out2in_w->data<T>();

-    int in_size = in->numel();
+    int out_size = out_grad->numel();
     auto stream = ctx.cuda_device_context().stream();
     int block = 512;
-    int grid = (in_size + block - 1) / block;
+    int grid = (out_size * 4 + block - 1) / block;

     RoiTransformGradKernel<T><<<grid, block, 0, stream>>>(
-        lod_data, rois_data, batch_size, rois_num, in_height, in_width,
-        channels, transformed_height, transformed_width, spatial_scale,
-        out_grad_data, in_grad_data);
+        out_size, out2in_idx_data, out2in_w_data, out_grad_data, in_grad_data);
   }
 };
```
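The launch configuration changes to match: one thread per cache entry (out_size * 4 of them) instead of one per input element, so the backward pass no longer sweeps the whole input feature map. A sketch of the ceil-division involved (hypothetical helper, not in the PR):

```cpp
// One thread per cached (index, weight) pair; CUDA_1D_KERNEL_LOOP makes the
// surplus threads in the last block fall out of the loop bound.
inline int GridSize(int out_size, int block = 512) {
  return (out_size * 4 + block - 1) / block;
}
// e.g. GridSize(1000) == (4000 + 511) / 512 == 8 blocks of 512 threads
```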
Changes to the operator's Python unit test:

```diff
@@ -299,6 +299,10 @@ def test_check_output(self):
         self.check_output()

     def test_check_grad(self):
+        self.outputs['Out2InIdx'] = np.zeros(
+            [np.product(self.outputs['Out'].shape), 4]).astype("int32")
+        self.outputs['Out2InWeights'] = np.zeros(
+            [np.product(self.outputs['Out'].shape), 4]).astype("float32")
         self.check_grad(['X'], 'Out')
```

Inline review on the added lines:

Contributor: In test_roi_pooling, test_check_output() also checks the intermediate outputs; could the same be done here?

Contributor (Author): roi_pooling's infer shape has an ENFORCE check on the intermediate outputs, so its unit test's test_check_output() also has to provide them. This PR only changes the CUDA kernel; the CPU kernel's computation does not use the intermediate outputs, so the unit test does not check them.
Review thread on the cache initialization in the forward Compute():

Contributor: Does out2in_w need to be initialized?

Contributor (Author): It follows from the condition on that line: if out2in_idx[i] == -1, then out2in_w[i] is never read, so it is enough to initialize the elements of out2in_idx to -1.