-
Notifications
You must be signed in to change notification settings - Fork 5.9k
optimization of index_select op backward #32955
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Xreki
merged 30 commits into
PaddlePaddle:develop
from
Zjq9409:Optimizaition_of_index_select_backward_cpu_op
Jul 20, 2021
Merged
Changes from all commits
Commits
Show all changes
30 commits
Select commit
Hold shift + click to select a range
3a1c466
optimization of index_select op backward
Zjq9409 7ba619f
optimization of index_select op backward
Zjq9409 090e046
add compile parameter
Zjq9409 a2b54bb
optimization of index_select op backward
Zjq9409 cdd297b
optimization of index_select op backward
Zjq9409 3a6faf1
optimization of index_select op backward
Zjq9409 252eefd
optimization of index_select op backward
Zjq9409 aa8057d
optimization of index_select op backward
Zjq9409 92ff317
optimization of index_select op backward
Zjq9409 9beb43b
optimization of index_select op backward
Zjq9409 d2f9aa8
optimization of index_select op backward
Zjq9409 d121f02
optimization of index_select op backward
Zjq9409 de92b58
optimization of index_select backward
Zjq9409 0efa130
optimization of index_select op backward
Zjq9409 3b62a64
optimization of index_select op backward
Zjq9409 f3c1fb0
optimization of index_select op backward
Zjq9409 c76a2da
optimization of index_select op backward
Zjq9409 72b1c71
Merge branch 'develop' into Optimizaition_of_index_select_backward_cp…
Zjq9409 7c003b1
optimization of index_select op backward
Zjq9409 8a47e37
optimization of index_select op
Zjq9409 3555551
optimization of index_select backward
Zjq9409 1cba37c
Optimizaition of index select backward by blas
Zjq9409 3bdc32c
Optimizaition of index select backward by blas
Zjq9409 c035b19
modify add operator to blas
Zjq9409 cfb18c7
Merge branch 'develop' into Optimizaition_of_index_select_backward_cp…
Zjq9409 b341849
optimization index select backward
Zjq9409 f7ee626
modify template
Zjq9409 e6536f0
modify template
Zjq9409 04afbfd
optimization of index select
Zjq9409 62c570f
optimization index_select backward
Zjq9409 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,6 +15,10 @@ | |
| #pragma once | ||
| #include <vector> | ||
| #include "paddle/fluid/framework/op_registry.h" | ||
| #include "paddle/fluid/operators/jit/macro.h" | ||
| #include "paddle/fluid/operators/math/blas.h" | ||
| #include "paddle/fluid/operators/math/math_function.h" | ||
| #include "paddle/fluid/platform/cpu_info.h" | ||
|
|
||
| namespace paddle { | ||
| namespace operators { | ||
|
|
@@ -38,7 +42,6 @@ void IndexSelectInner(const framework::ExecutionContext& context, | |
|
|
||
| auto input_width = slice_size * input_dim[dim]; | ||
| auto output_width = slice_size * output_dim[dim]; | ||
|
|
||
| auto outer_nums = 1; | ||
| for (auto i = 0; i < dim; i++) { | ||
| outer_nums *= input_dim[i]; | ||
|
|
@@ -77,7 +80,6 @@ void IndexSelectInner(const framework::ExecutionContext& context, | |
| for (auto i = 0; i < outer_nums; i++) { | ||
| auto input_start_offset = i * input_width; | ||
| auto output_start_offset = i * output_width; | ||
|
|
||
| for (auto j = 0; j < index_size; j++) { | ||
| IndexT index_value = index_vec[j]; | ||
| for (auto k = 0; k < slice_size; k++) { | ||
|
|
@@ -98,7 +100,6 @@ class IndexSelectKernel : public framework::OpKernel<T> { | |
| auto* inputs_var = context.InputVar("X"); | ||
| auto* index_var = context.InputVar("Index"); | ||
| auto* output_var = context.OutputVar("Out"); | ||
|
|
||
| auto& inputs = inputs_var->Get<LoDTensor>(); | ||
| auto& index = index_var->Get<LoDTensor>(); | ||
| auto* output = output_var->GetMutable<framework::LoDTensor>(); | ||
|
|
@@ -107,8 +108,8 @@ class IndexSelectKernel : public framework::OpKernel<T> { | |
| if (dim < 0) { | ||
| dim += inputs.dims().size(); | ||
| } | ||
|
|
||
| const auto& index_type = index.type(); | ||
|
|
||
| bool index_type_match = index_type == framework::proto::VarType::INT32 || | ||
| index_type == framework::proto::VarType::INT64; | ||
| PADDLE_ENFORCE_EQ(index_type_match, true, | ||
|
|
@@ -129,19 +130,41 @@ class IndexSelectKernel : public framework::OpKernel<T> { | |
| } | ||
| }; | ||
|
|
||
| template <typename T, typename IndexT = int> | ||
| template <typename DeviceContext, typename T, class Enable = void> | ||
| struct IndexSelectAdd { | ||
| void operator()(const framework::ExecutionContext& ctx, int slice_size, | ||
| const T* src_pointer, const T* p_pointer, T* dist_pointer) { | ||
| for (int i = 0; i < slice_size; i++) { | ||
| dist_pointer[i] = src_pointer[i] + p_pointer[i]; | ||
| } | ||
| } | ||
| }; | ||
| template <typename DeviceContext, typename T> | ||
| struct IndexSelectAdd< | ||
| DeviceContext, T, | ||
| typename std::enable_if<std::is_floating_point<T>::value>::type> { | ||
| void operator()(const framework::ExecutionContext& ctx, int slice_size, | ||
| const T* src_pointer, const T* p_pointer, T* dist_pointer) { | ||
| auto blas = math::GetBlas<DeviceContext, T>(ctx); | ||
| blas.VADD(slice_size, src_pointer, p_pointer, dist_pointer); | ||
| } | ||
| }; | ||
|
|
||
Zjq9409 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| template <typename DeviceContext, typename T, typename IndexT = int> | ||
| void IndexSelectGradInner(const framework::ExecutionContext& context, | ||
| const LoDTensor& out_grad, const LoDTensor& index, | ||
| const LoDTensor* out_grad, const LoDTensor* index, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 不要修改参数的类型，不用修改的输入用 |
||
| LoDTensor* x_grad, int dim) { | ||
| std::vector<T> input_vec; | ||
| std::vector<IndexT> index_vec; | ||
| TensorToVector(out_grad, context.device_context(), &input_vec); | ||
| TensorToVector(index, context.device_context(), &index_vec); | ||
|
|
||
| auto input_dim = out_grad.dims(); | ||
| const T* input_data = out_grad->data<T>(); | ||
| const IndexT* index_data = index->data<IndexT>(); | ||
| const T* p_output = x_grad->mutable_data<T>(context.GetPlace()); | ||
| T* out_data = x_grad->mutable_data<T>(context.GetPlace()); | ||
| auto input_dim = out_grad->dims(); | ||
| auto input_dim_size = input_dim.size(); | ||
| auto output_dim = x_grad->dims(); | ||
| std::vector<T> out_vec(x_grad->numel(), 0); | ||
|
|
||
| auto& dev_ctx = context.template device_context<DeviceContext>(); | ||
| math::SetConstant<DeviceContext, T> set_constant; | ||
| set_constant(dev_ctx, x_grad, static_cast<T>(0.0)); | ||
|
|
||
| auto slice_size = 1; | ||
| for (auto i = dim + 1; i < input_dim_size; i++) { | ||
|
|
@@ -156,7 +179,7 @@ void IndexSelectGradInner(const framework::ExecutionContext& context, | |
| outer_nums *= input_dim[i]; | ||
| } | ||
|
|
||
| auto index_size = index.dims()[0]; | ||
| auto index_size = index->dims()[0]; | ||
| VLOG(3) << "Index_Select_Grad_Debug; outer_nums: " << outer_nums | ||
| << "; slice_size: " << slice_size << "; input_width: " << input_width | ||
| << "; output_width: " << output_width | ||
|
|
@@ -167,35 +190,33 @@ void IndexSelectGradInner(const framework::ExecutionContext& context, | |
| auto output_start_offset = i * output_width; | ||
|
|
||
| for (auto j = 0; j < index_size; j++) { | ||
| IndexT index_value = index_vec[j]; | ||
| for (auto k = 0; k < slice_size; k++) { | ||
| out_vec[output_start_offset + index_value * slice_size + k] += | ||
| input_vec[input_start_offset + j * slice_size + k]; | ||
| } | ||
| IndexT index_value = index_data[j]; | ||
| auto src = input_data + input_start_offset + j * slice_size; | ||
Zjq9409 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| auto p_out = p_output + output_start_offset + index_value * slice_size; | ||
| auto dst = out_data + output_start_offset + index_value * slice_size; | ||
| IndexSelectAdd<DeviceContext, T> index_select_add; | ||
| index_select_add(context, slice_size, src, p_out, dst); | ||
| } | ||
| } | ||
| x_grad->mutable_data<T>(context.GetPlace()); | ||
| framework::TensorFromVector(out_vec, context.device_context(), x_grad); | ||
| x_grad->Resize(output_dim); | ||
| } | ||
|
|
||
| template <typename DeviceContext, typename T> | ||
| class IndexSelectGradKernel : public framework::OpKernel<T> { | ||
| public: | ||
| void Compute(const framework::ExecutionContext& context) const override { | ||
| auto* index_var = context.InputVar("Index"); | ||
| auto* x_grad_var = context.OutputVar(framework::GradVarName("X")); | ||
| auto* out_grad_var = context.InputVar(framework::GradVarName("Out")); | ||
| auto* x_grad = | ||
| context.Output<framework::LoDTensor>(framework::GradVarName("X")); | ||
| auto* index = context.Input<framework::LoDTensor>("Index"); | ||
| auto* out_grad = | ||
| context.Input<framework::LoDTensor>(framework::GradVarName("Out")); | ||
|
|
||
| auto& index = index_var->Get<LoDTensor>(); | ||
| auto& out_grad = out_grad_var->Get<LoDTensor>(); | ||
| auto* x_grad = x_grad_var->GetMutable<framework::LoDTensor>(); | ||
| int dim = context.Attr<int>("dim"); | ||
| if (dim < 0) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Line212 - Line219 可以改成: |
||
| dim += out_grad.dims().size(); | ||
| dim += out_grad->dims().size(); | ||
| } | ||
| const auto& index_type = index->type(); | ||
|
|
||
| const auto& index_type = index.type(); | ||
| bool index_type_match = index_type == framework::proto::VarType::INT32 || | ||
| index_type == framework::proto::VarType::INT64; | ||
| PADDLE_ENFORCE_EQ(index_type_match, true, | ||
|
|
@@ -209,9 +230,11 @@ class IndexSelectGradKernel : public framework::OpKernel<T> { | |
| framework::proto::VarType::INT64))); | ||
|
|
||
| if (index_type == framework::proto::VarType::INT32) { | ||
| IndexSelectGradInner<T, int>(context, out_grad, index, x_grad, dim); | ||
| IndexSelectGradInner<DeviceContext, T, int>(context, out_grad, index, | ||
| x_grad, dim); | ||
| } else if (index_type == framework::proto::VarType::INT64) { | ||
| IndexSelectGradInner<T, int64_t>(context, out_grad, index, x_grad, dim); | ||
| IndexSelectGradInner<DeviceContext, T, int64_t>(context, out_grad, index, | ||
| x_grad, dim); | ||
| } | ||
| } | ||
| }; | ||
|
|
||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
使用的 `blas` 的时候，可以测一下不同OMP设置情况下的加速比。