-
Notifications
You must be signed in to change notification settings - Fork 5.9k
optimization of index_select op backward #32955
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 10 commits
3a1c466
7ba619f
090e046
a2b54bb
cdd297b
3a6faf1
252eefd
aa8057d
92ff317
9beb43b
d2f9aa8
d121f02
de92b58
0efa130
3b62a64
f3c1fb0
c76a2da
72b1c71
7c003b1
8a47e37
3555551
1cba37c
3bdc32c
c035b19
cfb18c7
b341849
f7ee626
e6536f0
04afbfd
62c570f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,47 +11,50 @@ | |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| // See the License for the specific language governing permissions and | ||
| // limitations under the License. | ||
|
|
||
| #pragma once | ||
| #ifdef _WIN32 | ||
| #if defined(__AVX2__) | ||
| #include <immintrin.h> // avx2 | ||
| #elif defined(__AVX__) | ||
| #include <intrin.h> // avx | ||
| #endif // AVX | ||
| #else // WIN32 | ||
| #ifdef __AVX__ | ||
| #include <immintrin.h> | ||
| #endif | ||
| #endif // WIN32 | ||
Zjq9409 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| #include <vector> | ||
| #include "paddle/fluid/framework/op_registry.h" | ||
|
|
||
| #include "paddle/fluid/operators/jit/kernels.h" | ||
Zjq9409 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| #include "paddle/fluid/platform/cpu_info.h" | ||
| namespace paddle { | ||
| namespace operators { | ||
|
|
||
| using Tensor = framework::Tensor; | ||
| using LoDTensor = framework::LoDTensor; | ||
| using DDim = framework::DDim; | ||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 必要的空行有助于阅读代码,不要删除。 |
||
| template <typename T, typename IndexT = int> | ||
| void IndexSelectInner(const framework::ExecutionContext& context, | ||
| const LoDTensor& input, const LoDTensor& index, | ||
| LoDTensor* output, int dim) { | ||
| auto input_dim = input.dims(); | ||
| auto input_dim_size = input_dim.size(); | ||
| auto output_dim = output->dims(); | ||
|
|
||
| auto slice_size = 1; | ||
| for (auto i = dim + 1; i < input_dim_size; i++) { | ||
| slice_size *= input_dim[i]; | ||
| } | ||
|
|
||
| auto input_width = slice_size * input_dim[dim]; | ||
| auto output_width = slice_size * output_dim[dim]; | ||
|
|
||
| auto outer_nums = 1; | ||
| for (auto i = 0; i < dim; i++) { | ||
| outer_nums *= input_dim[i]; | ||
| } | ||
|
|
||
| auto index_size = index.dims()[0]; | ||
|
|
||
| std::vector<T> input_vec; | ||
| std::vector<IndexT> index_vec; | ||
| TensorToVector(input, context.device_context(), &input_vec); | ||
| TensorToVector(index, context.device_context(), &index_vec); | ||
| std::vector<T> out_vec(output->numel()); | ||
|
|
||
| for (int i = 0; i < index_size; i++) { | ||
| PADDLE_ENFORCE_GE( | ||
| index_vec[i], 0, | ||
|
|
@@ -68,16 +71,13 @@ void IndexSelectInner(const framework::ExecutionContext& context, | |
| "value.", | ||
| input_dim[dim], index_vec[i])); | ||
| } | ||
|
|
||
| VLOG(3) << "Index_Select_Debug; outer_nums: " << outer_nums | ||
| << "; slice_size: " << slice_size << "; input_width: " << input_width | ||
| << "; output_width: " << output_width | ||
| << "; index_size: " << index_size; | ||
|
|
||
| for (auto i = 0; i < outer_nums; i++) { | ||
| auto input_start_offset = i * input_width; | ||
| auto output_start_offset = i * output_width; | ||
|
|
||
| for (auto j = 0; j < index_size; j++) { | ||
| IndexT index_value = index_vec[j]; | ||
| for (auto k = 0; k < slice_size; k++) { | ||
|
|
@@ -98,16 +98,13 @@ class IndexSelectKernel : public framework::OpKernel<T> { | |
| auto* inputs_var = context.InputVar("X"); | ||
| auto* index_var = context.InputVar("Index"); | ||
| auto* output_var = context.OutputVar("Out"); | ||
|
|
||
| auto& inputs = inputs_var->Get<LoDTensor>(); | ||
| auto& index = index_var->Get<LoDTensor>(); | ||
| auto* output = output_var->GetMutable<framework::LoDTensor>(); | ||
|
|
||
| int dim = context.Attr<int>("dim"); | ||
| if (dim < 0) { | ||
| dim += inputs.dims().size(); | ||
| } | ||
|
|
||
| const auto& index_type = index.type(); | ||
| bool index_type_match = index_type == framework::proto::VarType::INT32 || | ||
| index_type == framework::proto::VarType::INT64; | ||
|
|
@@ -120,7 +117,6 @@ class IndexSelectKernel : public framework::OpKernel<T> { | |
| framework::proto::VarType::INT32), | ||
| paddle::framework::DataTypeToString( | ||
| framework::proto::VarType::INT64))); | ||
|
|
||
| if (index_type == framework::proto::VarType::INT32) { | ||
| IndexSelectInner<T, int>(context, inputs, index, output, dim); | ||
| } else if (index_type == framework::proto::VarType::INT64) { | ||
|
|
@@ -129,53 +125,97 @@ class IndexSelectKernel : public framework::OpKernel<T> { | |
| } | ||
| }; | ||
|
|
||
| #if ((!defined __NVCC__) && (!defined __HIPCC__)) | ||
|
||
// Element-wise accumulation: dst[i] += src[i] for i in [0, n).
//
// The generic template is a plain scalar loop and is correct for every
// element type the surrounding kernels instantiate (int, int64_t, ...).
// SIMD acceleration lives in the explicit float specialization below
// (and the double specialization that follows in this file).
//
// NOTE(fix): the previous generic version ran the AVX path for *every* T,
// reinterpret_cast-ing T* to float* and float-adding the raw bit patterns
// — silently wrong for any T other than float. The AVX code is now
// restricted to the float specialization.
template <typename T>
inline void index_sum(const size_t n, const T* src, T* dst) {
  for (size_t k = 0; k < n; k++) {
    dst[k] += src[k];
  }
}

// float specialization: accumulates YMM_FLOAT_BLOCK (8) floats per
// 256-bit AVX iteration, then finishes the remainder with a scalar loop.
// Falls back to the scalar loop when AVX is not available.
// Marked inline: a full specialization defined in a header is not
// implicitly inline, and would otherwise violate the ODR when this
// header is included from multiple translation units.
template <>
inline void index_sum(const size_t n, const float* src, float* dst) {
#ifdef __AVX__
  constexpr size_t block = YMM_FLOAT_BLOCK;  // 8 floats per ymm register
  // Largest multiple of `block` not exceeding n (block is a power of two).
  const size_t end = n & ~(block - 1);
  size_t i = 0;
  for (; i < end; i += block) {
    _mm256_storeu_ps(dst + i, _mm256_add_ps(_mm256_loadu_ps(dst + i),
                                            _mm256_loadu_ps(src + i)));
  }
  for (; i < n; i++) {  // scalar tail
    dst[i] += src[i];
  }
#else
  for (size_t k = 0; k < n; k++) {
    dst[k] += src[k];
  }
#endif
}
|
|
||
| template <> | ||
| void index_sum(const size_t n, const double* src, double* dst) { | ||
Zjq9409 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| #ifdef __AVX__ | ||
| constexpr int block = XMM_FLOAT_BLOCK; | ||
| unsigned int i, end; | ||
| i = end = 0; | ||
| end = n & ~(block - 1); | ||
| for (i = 0; i < end; i += block) { | ||
| _mm256_storeu_pd(reinterpret_cast<double*>(dst) + i, | ||
| _mm256_add_pd(_mm256_loadu_pd((const double*)dst + i), | ||
| _mm256_loadu_pd((const double*)src + i))); | ||
| } | ||
| for (; i < n; i++) { | ||
| dst[i] += src[i]; | ||
| } | ||
| #else | ||
| for (size_t k = 0; k < n; k++) { | ||
| dst[k] += src[k]; | ||
| } | ||
| #endif | ||
| } | ||
| #endif | ||
|
|
||
Zjq9409 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| template <typename T, typename IndexT = int> | ||
| void IndexSelectGradInner(const framework::ExecutionContext& context, | ||
| const LoDTensor& out_grad, const LoDTensor& index, | ||
| LoDTensor* x_grad, int dim) { | ||
| std::vector<T> input_vec; | ||
| std::vector<IndexT> index_vec; | ||
| TensorToVector(out_grad, context.device_context(), &input_vec); | ||
| TensorToVector(index, context.device_context(), &index_vec); | ||
|
|
||
| const T* input_data = out_grad.data<T>(); | ||
| const IndexT* index_data = index.data<IndexT>(); | ||
| T* out_data = x_grad->mutable_data<T>(context.GetPlace()); | ||
| auto input_dim = out_grad.dims(); | ||
| auto input_dim_size = input_dim.size(); | ||
| auto output_dim = x_grad->dims(); | ||
| std::vector<T> out_vec(x_grad->numel(), 0); | ||
|
|
||
| std::memset(out_data, 0.0, x_grad->numel() * sizeof(T)); | ||
|
||
| auto slice_size = 1; | ||
| for (auto i = dim + 1; i < input_dim_size; i++) { | ||
| slice_size *= input_dim[i]; | ||
| } | ||
|
|
||
| auto input_width = slice_size * input_dim[dim]; | ||
| auto output_width = slice_size * output_dim[dim]; | ||
|
|
||
| auto outer_nums = 1; | ||
| for (auto i = 0; i < dim; i++) { | ||
| outer_nums *= input_dim[i]; | ||
| } | ||
|
|
||
| auto index_size = index.dims()[0]; | ||
| VLOG(3) << "Index_Select_Grad_Debug; outer_nums: " << outer_nums | ||
| << "; slice_size: " << slice_size << "; input_width: " << input_width | ||
| << "; output_width: " << output_width | ||
| << "; index_size: " << index_size; | ||
|
|
||
| for (auto i = 0; i < outer_nums; i++) { | ||
| auto input_start_offset = i * input_width; | ||
| auto output_start_offset = i * output_width; | ||
|
|
||
| for (auto j = 0; j < index_size; j++) { | ||
| IndexT index_value = index_vec[j]; | ||
| IndexT index_value = index_data[j]; | ||
| #ifdef __AVX__ | ||
| auto src = input_data + input_start_offset + j * slice_size; | ||
Zjq9409 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| auto dst = out_data + output_start_offset + index_value * slice_size; | ||
| #if ((!defined __NVCC__) && (!defined __HIPCC__)) | ||
| index_sum(slice_size, src, dst); | ||
| #endif | ||
| #else | ||
| for (auto k = 0; k < slice_size; k++) { | ||
| out_vec[output_start_offset + index_value * slice_size + k] += | ||
| input_vec[input_start_offset + j * slice_size + k]; | ||
| out_data[output_start_offset + index_value * slice_size + k] += | ||
| input_data[input_start_offset + j * slice_size + k]; | ||
| } | ||
| #endif | ||
| } | ||
| } | ||
| x_grad->mutable_data<T>(context.GetPlace()); | ||
| framework::TensorFromVector(out_vec, context.device_context(), x_grad); | ||
| x_grad->Resize(output_dim); | ||
| } | ||
|
|
||
|
|
@@ -186,15 +226,13 @@ class IndexSelectGradKernel : public framework::OpKernel<T> { | |
| auto* index_var = context.InputVar("Index"); | ||
| auto* x_grad_var = context.OutputVar(framework::GradVarName("X")); | ||
| auto* out_grad_var = context.InputVar(framework::GradVarName("Out")); | ||
|
|
||
| auto& index = index_var->Get<LoDTensor>(); | ||
| auto& out_grad = out_grad_var->Get<LoDTensor>(); | ||
| auto* x_grad = x_grad_var->GetMutable<framework::LoDTensor>(); | ||
| int dim = context.Attr<int>("dim"); | ||
| if (dim < 0) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Line212 - Line219 可以改成: |
||
| dim += out_grad.dims().size(); | ||
| } | ||
|
|
||
| const auto& index_type = index.type(); | ||
| bool index_type_match = index_type == framework::proto::VarType::INT32 || | ||
| index_type == framework::proto::VarType::INT64; | ||
|
|
@@ -207,14 +245,12 @@ class IndexSelectGradKernel : public framework::OpKernel<T> { | |
| framework::proto::VarType::INT32), | ||
| paddle::framework::DataTypeToString( | ||
| framework::proto::VarType::INT64))); | ||
|
|
||
| if (index_type == framework::proto::VarType::INT32) { | ||
| IndexSelectGradInner<T, int>(context, out_grad, index, x_grad, dim); | ||
| } else if (index_type == framework::proto::VarType::INT64) { | ||
| IndexSelectGradInner<T, int64_t>(context, out_grad, index, x_grad, dim); | ||
| } | ||
| } | ||
| }; | ||
|
|
||
| } // namespace operators | ||
| } // namespace paddle | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
在Conversation的Comment区域要描述本次PR的目的,PR修改前后性能变化情况等信息。