-
Notifications
You must be signed in to change notification settings - Fork 5.9k
optimization of index_select op backward #32955
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 18 commits
3a1c466
7ba619f
090e046
a2b54bb
cdd297b
3a6faf1
252eefd
aa8057d
92ff317
9beb43b
d2f9aa8
d121f02
de92b58
0efa130
3b62a64
f3c1fb0
c76a2da
72b1c71
7c003b1
8a47e37
3555551
1cba37c
3bdc32c
c035b19
cfb18c7
b341849
f7ee626
e6536f0
04afbfd
62c570f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,4 @@ | ||
|
|
||
| // Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. | ||
| // | ||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||
|
|
@@ -11,47 +12,40 @@ | |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| // See the License for the specific language governing permissions and | ||
| // limitations under the License. | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 在Conversation的Comment区域要描述本次PR的目的,PR修改前后性能变化情况等信息。 |
||
|
|
||
| #pragma once | ||
| #include <vector> | ||
| #include "paddle/fluid/framework/op_registry.h" | ||
| #include "paddle/fluid/operators/jit/kernels.h" | ||
Zjq9409 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| #include "paddle/fluid/platform/cpu_info.h" | ||
|
|
||
| namespace paddle { | ||
| namespace operators { | ||
|
|
||
| using Tensor = framework::Tensor; | ||
| using LoDTensor = framework::LoDTensor; | ||
| using DDim = framework::DDim; | ||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 必要的空行有助于阅读代码,不要删除。 |
||
| template <typename T, typename IndexT = int> | ||
| void IndexSelectInner(const framework::ExecutionContext& context, | ||
| const LoDTensor& input, const LoDTensor& index, | ||
| LoDTensor* output, int dim) { | ||
| auto input_dim = input.dims(); | ||
| auto input_dim_size = input_dim.size(); | ||
| auto output_dim = output->dims(); | ||
|
|
||
| auto slice_size = 1; | ||
| for (auto i = dim + 1; i < input_dim_size; i++) { | ||
| slice_size *= input_dim[i]; | ||
| } | ||
|
|
||
| auto input_width = slice_size * input_dim[dim]; | ||
| auto output_width = slice_size * output_dim[dim]; | ||
|
|
||
| auto outer_nums = 1; | ||
| for (auto i = 0; i < dim; i++) { | ||
| outer_nums *= input_dim[i]; | ||
| } | ||
|
|
||
| auto index_size = index.dims()[0]; | ||
|
|
||
| std::vector<T> input_vec; | ||
| std::vector<IndexT> index_vec; | ||
| TensorToVector(input, context.device_context(), &input_vec); | ||
| TensorToVector(index, context.device_context(), &index_vec); | ||
| std::vector<T> out_vec(output->numel()); | ||
|
|
||
| for (int i = 0; i < index_size; i++) { | ||
| PADDLE_ENFORCE_GE( | ||
| index_vec[i], 0, | ||
|
|
@@ -68,16 +62,13 @@ void IndexSelectInner(const framework::ExecutionContext& context, | |
| "value.", | ||
| input_dim[dim], index_vec[i])); | ||
| } | ||
|
|
||
| VLOG(3) << "Index_Select_Debug; outer_nums: " << outer_nums | ||
| << "; slice_size: " << slice_size << "; input_width: " << input_width | ||
| << "; output_width: " << output_width | ||
| << "; index_size: " << index_size; | ||
|
|
||
| for (auto i = 0; i < outer_nums; i++) { | ||
| auto input_start_offset = i * input_width; | ||
| auto output_start_offset = i * output_width; | ||
|
|
||
| for (auto j = 0; j < index_size; j++) { | ||
| IndexT index_value = index_vec[j]; | ||
| for (auto k = 0; k < slice_size; k++) { | ||
|
|
@@ -98,16 +89,13 @@ class IndexSelectKernel : public framework::OpKernel<T> { | |
| auto* inputs_var = context.InputVar("X"); | ||
| auto* index_var = context.InputVar("Index"); | ||
| auto* output_var = context.OutputVar("Out"); | ||
|
|
||
| auto& inputs = inputs_var->Get<LoDTensor>(); | ||
| auto& index = index_var->Get<LoDTensor>(); | ||
| auto* output = output_var->GetMutable<framework::LoDTensor>(); | ||
|
|
||
| int dim = context.Attr<int>("dim"); | ||
| if (dim < 0) { | ||
| dim += inputs.dims().size(); | ||
| } | ||
|
|
||
| const auto& index_type = index.type(); | ||
| bool index_type_match = index_type == framework::proto::VarType::INT32 || | ||
| index_type == framework::proto::VarType::INT64; | ||
|
|
@@ -120,7 +108,6 @@ class IndexSelectKernel : public framework::OpKernel<T> { | |
| framework::proto::VarType::INT32), | ||
| paddle::framework::DataTypeToString( | ||
| framework::proto::VarType::INT64))); | ||
|
|
||
| if (index_type == framework::proto::VarType::INT32) { | ||
| IndexSelectInner<T, int>(context, inputs, index, output, dim); | ||
| } else if (index_type == framework::proto::VarType::INT64) { | ||
|
|
@@ -129,53 +116,117 @@ class IndexSelectKernel : public framework::OpKernel<T> { | |
| } | ||
| }; | ||
|
|
||
| #if ((!defined __NVCC__) && (!defined __HIPCC__)) | ||
|
||
|
|
||
| template <typename platform::cpu_isa_t isa, typename T, class Enable = void> | ||
| struct IndexSelectAdd { | ||
| void operator()(int n, const T* src, T* dst) { | ||
| for (int i = 0; i < n; i++) { | ||
| dst[i] += src[i]; | ||
| } | ||
| } | ||
| }; | ||
|
|
||
| template <typename T> | ||
| struct IndexSelectAdd< | ||
|
||
| platform::avx, T, | ||
| typename std::enable_if<!std::is_floating_point<T>::value>::type> { | ||
| void operator()(const int n, const T* src, T* dst) { | ||
| for (int i = 0; i < n; i++) { | ||
| dst[i] += src[i]; | ||
| } | ||
| } | ||
| }; | ||
|
|
||
| // description: Index addition uses intel intrinsic instruction set to read and | ||
| // write data in parallel | ||
| template <typename T> | ||
| struct IndexSelectAdd< | ||
| platform::avx, T, | ||
| typename std::enable_if<std::is_floating_point<T>::value>::type> { | ||
| void operator()(const int n, const T* src, T* dst) { | ||
| #ifdef __AVX__ | ||
| int block = 0; | ||
| if (std::is_same<T, float>::value) { | ||
| block = YMM_FLOAT_BLOCK; | ||
| } else if (std::is_same<T, double>::value) { | ||
| block = XMM_FLOAT_BLOCK; | ||
| } | ||
| int i = 0; | ||
| int end = n & ~(block - 1); | ||
| if (std::is_same<T, float>::value) { | ||
| for (i = 0; i < end; i += block) { | ||
| _mm256_storeu_ps(reinterpret_cast<float*>(dst) + i, | ||
| _mm256_add_ps(_mm256_loadu_ps((const float*)dst + i), | ||
| _mm256_loadu_ps((const float*)src + i))); | ||
| } | ||
| } else if (std::is_same<T, double>::value) { | ||
| for (i = 0; i < end; i += block) { | ||
| _mm256_storeu_pd( | ||
| reinterpret_cast<double*>(dst) + i, | ||
| _mm256_add_pd(_mm256_loadu_pd((const double*)dst + i), | ||
| _mm256_loadu_pd((const double*)src + i))); | ||
| } | ||
| } | ||
| for (; i < n; i++) { | ||
| dst[i] += src[i]; | ||
| } | ||
| #else | ||
| IndexSelectAdd<platform::isa_any, T> index_select_add_any; | ||
| index_select_add_any(n, src, dst); | ||
| #endif | ||
| } | ||
| }; | ||
| #endif | ||
|
|
||
Zjq9409 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| template <typename T, typename IndexT = int> | ||
| void IndexSelectGradInner(const framework::ExecutionContext& context, | ||
| const LoDTensor& out_grad, const LoDTensor& index, | ||
| LoDTensor* x_grad, int dim) { | ||
| std::vector<T> input_vec; | ||
| std::vector<IndexT> index_vec; | ||
| TensorToVector(out_grad, context.device_context(), &input_vec); | ||
| TensorToVector(index, context.device_context(), &index_vec); | ||
|
|
||
| const T* input_data = out_grad.data<T>(); | ||
| const IndexT* index_data = index.data<IndexT>(); | ||
| T* out_data = x_grad->mutable_data<T>(context.GetPlace()); | ||
| auto input_dim = out_grad.dims(); | ||
| auto input_dim_size = input_dim.size(); | ||
| auto output_dim = x_grad->dims(); | ||
| std::vector<T> out_vec(x_grad->numel(), 0); | ||
| std::memset(out_data, 0.0, x_grad->numel() * sizeof(T)); | ||
|
||
|
|
||
| auto slice_size = 1; | ||
| for (auto i = dim + 1; i < input_dim_size; i++) { | ||
| slice_size *= input_dim[i]; | ||
| } | ||
|
|
||
| auto input_width = slice_size * input_dim[dim]; | ||
| auto output_width = slice_size * output_dim[dim]; | ||
|
|
||
| auto outer_nums = 1; | ||
| for (auto i = 0; i < dim; i++) { | ||
| outer_nums *= input_dim[i]; | ||
| } | ||
|
|
||
| auto index_size = index.dims()[0]; | ||
| VLOG(3) << "Index_Select_Grad_Debug; outer_nums: " << outer_nums | ||
| << "; slice_size: " << slice_size << "; input_width: " << input_width | ||
| << "; output_width: " << output_width | ||
| << "; index_size: " << index_size; | ||
|
|
||
| for (auto i = 0; i < outer_nums; i++) { | ||
| auto input_start_offset = i * input_width; | ||
| auto output_start_offset = i * output_width; | ||
|
|
||
| for (auto j = 0; j < index_size; j++) { | ||
| IndexT index_value = index_vec[j]; | ||
| for (auto k = 0; k < slice_size; k++) { | ||
| out_vec[output_start_offset + index_value * slice_size + k] += | ||
| input_vec[input_start_offset + j * slice_size + k]; | ||
| } | ||
| IndexT index_value = index_data[j]; | ||
| auto src = input_data + input_start_offset + j * slice_size; | ||
Zjq9409 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| auto dst = out_data + output_start_offset + index_value * slice_size; | ||
|
|
||
| #if ((!defined __NVCC__) && (!defined __HIPCC__)) | ||
|
|
||
| #ifdef __AVX__ | ||
| IndexSelectAdd<platform::avx, T> index_select_add_avx; | ||
| index_select_add_avx(slice_size, src, dst); | ||
| #else | ||
| IndexSelectAdd<platform::isa_any, T> index_select_add_any; | ||
| index_select_add_any(slice_size, src, dst); | ||
| #endif | ||
|
|
||
| #endif | ||
| } | ||
| } | ||
| x_grad->mutable_data<T>(context.GetPlace()); | ||
| framework::TensorFromVector(out_vec, context.device_context(), x_grad); | ||
| x_grad->Resize(output_dim); | ||
| } | ||
|
|
||
|
|
@@ -186,15 +237,13 @@ class IndexSelectGradKernel : public framework::OpKernel<T> { | |
| auto* index_var = context.InputVar("Index"); | ||
| auto* x_grad_var = context.OutputVar(framework::GradVarName("X")); | ||
| auto* out_grad_var = context.InputVar(framework::GradVarName("Out")); | ||
|
|
||
| auto& index = index_var->Get<LoDTensor>(); | ||
| auto& out_grad = out_grad_var->Get<LoDTensor>(); | ||
| auto* x_grad = x_grad_var->GetMutable<framework::LoDTensor>(); | ||
| int dim = context.Attr<int>("dim"); | ||
| if (dim < 0) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Line212 - Line219 可以改成: |
||
| dim += out_grad.dims().size(); | ||
| } | ||
|
|
||
| const auto& index_type = index.type(); | ||
| bool index_type_match = index_type == framework::proto::VarType::INT32 || | ||
| index_type == framework::proto::VarType::INT64; | ||
|
|
@@ -207,14 +256,12 @@ class IndexSelectGradKernel : public framework::OpKernel<T> { | |
| framework::proto::VarType::INT32), | ||
| paddle::framework::DataTypeToString( | ||
| framework::proto::VarType::INT64))); | ||
|
|
||
| if (index_type == framework::proto::VarType::INT32) { | ||
| IndexSelectGradInner<T, int>(context, out_grad, index, x_grad, dim); | ||
| } else if (index_type == framework::proto::VarType::INT64) { | ||
| IndexSelectGradInner<T, int64_t>(context, out_grad, index, x_grad, dim); | ||
| } | ||
| } | ||
| }; | ||
|
|
||
| } // namespace operators | ||
| } // namespace paddle | ||
Uh oh!
There was an error while loading. Please reload this page.