Skip to content
Merged
Changes from 18 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
3a1c466
optimization of index_select op backward
Zjq9409 May 18, 2021
7ba619f
optimization of index_select op backward
Zjq9409 May 18, 2021
090e046
add compile parameter
Zjq9409 May 18, 2021
a2b54bb
optimization of index_select op backward
Zjq9409 May 19, 2021
cdd297b
optimization of index_select op backward
Zjq9409 May 19, 2021
3a6faf1
optimization of index_select op backward
Zjq9409 May 19, 2021
252eefd
optimization of index_select op backward
Zjq9409 May 19, 2021
aa8057d
optimization of index_select op backward
Zjq9409 May 19, 2021
92ff317
optimization of index_select op backward
Zjq9409 May 19, 2021
9beb43b
optimization of index_select op backward
Zjq9409 May 19, 2021
d2f9aa8
optimization of index_select op backward
Zjq9409 May 25, 2021
d121f02
optimization of index_select op backward
Zjq9409 May 25, 2021
de92b58
optimization of index_select backward
Zjq9409 Jun 10, 2021
0efa130
optimization of index_select op backward
Zjq9409 Jun 15, 2021
3b62a64
optimization of index_select op backward
Zjq9409 Jun 15, 2021
f3c1fb0
optimization of index_select op backward
Zjq9409 Jun 15, 2021
c76a2da
optimization of index_select op backward
Zjq9409 Jun 15, 2021
72b1c71
Merge branch 'develop' into Optimizaition_of_index_select_backward_cp…
Zjq9409 Jun 15, 2021
7c003b1
optimization of index_select op backward
Zjq9409 Jun 16, 2021
8a47e37
optimization of index_select op
Zjq9409 Jul 5, 2021
3555551
optimization of index_select backward
Zjq9409 Jul 6, 2021
1cba37c
Optimizaition of index select backward by blas
Zjq9409 Jul 6, 2021
3bdc32c
Optimizaition of index select backward by blas
Zjq9409 Jul 6, 2021
c035b19
modify add operator to blas
Zjq9409 Jul 6, 2021
cfb18c7
Merge branch 'develop' into Optimizaition_of_index_select_backward_cp…
Zjq9409 Jul 6, 2021
b341849
optimization index select backward
Zjq9409 Jul 7, 2021
f7ee626
modify template
Zjq9409 Jul 7, 2021
e6536f0
modify template
Zjq9409 Jul 7, 2021
04afbfd
optimization of index select
Zjq9409 Jul 7, 2021
62c570f
optimization index_select backward
Zjq9409 Jul 14, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 85 additions & 38 deletions paddle/fluid/operators/index_select_op.h
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -11,47 +12,40 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

在Conversation的Comment区域要描述本次PR的目的,PR修改前后性能变化情况等信息。


#pragma once
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/platform/cpu_info.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using DDim = framework::DDim;

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

必要的空行有助于阅读代码,不要删除。

template <typename T, typename IndexT = int>
void IndexSelectInner(const framework::ExecutionContext& context,
const LoDTensor& input, const LoDTensor& index,
LoDTensor* output, int dim) {
auto input_dim = input.dims();
auto input_dim_size = input_dim.size();
auto output_dim = output->dims();

auto slice_size = 1;
for (auto i = dim + 1; i < input_dim_size; i++) {
slice_size *= input_dim[i];
}

auto input_width = slice_size * input_dim[dim];
auto output_width = slice_size * output_dim[dim];

auto outer_nums = 1;
for (auto i = 0; i < dim; i++) {
outer_nums *= input_dim[i];
}

auto index_size = index.dims()[0];

std::vector<T> input_vec;
std::vector<IndexT> index_vec;
TensorToVector(input, context.device_context(), &input_vec);
TensorToVector(index, context.device_context(), &index_vec);
std::vector<T> out_vec(output->numel());

for (int i = 0; i < index_size; i++) {
PADDLE_ENFORCE_GE(
index_vec[i], 0,
Expand All @@ -68,16 +62,13 @@ void IndexSelectInner(const framework::ExecutionContext& context,
"value.",
input_dim[dim], index_vec[i]));
}

VLOG(3) << "Index_Select_Debug; outer_nums: " << outer_nums
<< "; slice_size: " << slice_size << "; input_width: " << input_width
<< "; output_width: " << output_width
<< "; index_size: " << index_size;

for (auto i = 0; i < outer_nums; i++) {
auto input_start_offset = i * input_width;
auto output_start_offset = i * output_width;

for (auto j = 0; j < index_size; j++) {
IndexT index_value = index_vec[j];
for (auto k = 0; k < slice_size; k++) {
Expand All @@ -98,16 +89,13 @@ class IndexSelectKernel : public framework::OpKernel<T> {
auto* inputs_var = context.InputVar("X");
auto* index_var = context.InputVar("Index");
auto* output_var = context.OutputVar("Out");

auto& inputs = inputs_var->Get<LoDTensor>();
auto& index = index_var->Get<LoDTensor>();
auto* output = output_var->GetMutable<framework::LoDTensor>();

int dim = context.Attr<int>("dim");
if (dim < 0) {
dim += inputs.dims().size();
}

const auto& index_type = index.type();
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
index_type == framework::proto::VarType::INT64;
Expand All @@ -120,7 +108,6 @@ class IndexSelectKernel : public framework::OpKernel<T> {
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));

if (index_type == framework::proto::VarType::INT32) {
IndexSelectInner<T, int>(context, inputs, index, output, dim);
} else if (index_type == framework::proto::VarType::INT64) {
Expand All @@ -129,53 +116,117 @@ class IndexSelectKernel : public framework::OpKernel<T> {
}
};

#if ((!defined __NVCC__) && (!defined __HIPCC__))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里的宏是否还有必要


// Generic fallback functor for element-wise accumulation: dst[i] += src[i].
// Primary template used for every (isa, T) combination that has no dedicated
// specialization; specializations (e.g. for platform::avx below) may supply
// vectorized implementations with identical observable behavior.
template <typename platform::cpu_isa_t isa, typename T, class Enable = void>
struct IndexSelectAdd {
  void operator()(int n, const T* src, T* dst) {
    // Plain scalar loop; the compiler may auto-vectorize it.
    int i = 0;
    while (i < n) {
      dst[i] += src[i];
      ++i;
    }
  }
};

template <typename T>
struct IndexSelectAdd<
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这一段仿函数的意义感觉不大,感觉除了浮点之外采用的是下述通用形式。

template <typename platform::cpu_isa_t isa, typename T, class Enable = void>
 struct IndexSelectAdd {
   void operator()(int n, const T* src, T* dst) {
     for (int i = 0; i < n; i++) {
       dst[i] += src[i];
     }
   }
 };

platform::avx, T,
typename std::enable_if<!std::is_floating_point<T>::value>::type> {
void operator()(const int n, const T* src, T* dst) {
for (int i = 0; i < n; i++) {
dst[i] += src[i];
}
}
};

// description: Index addition uses intel intrinsic instruction set to read and
// write data in parallel
//
// Floating-point specialization for the AVX ISA: adds the bulk of the range
// with 256-bit vector instructions and finishes the remainder with a scalar
// tail loop. When the translation unit is compiled without AVX support it
// delegates to the generic scalar functor instead.
template <typename T>
struct IndexSelectAdd<
    platform::avx, T,
    typename std::enable_if<std::is_floating_point<T>::value>::type> {
  void operator()(const int n, const T* src, T* dst) {
#ifdef __AVX__
    // Elements processed per 256-bit vector op: 8 floats / 4 doubles.
    // NOTE(review): XMM_FLOAT_BLOCK (= 4) coincidentally equals the number of
    // doubles per YMM register; a dedicated constant would be clearer.
    int block = 0;
    if (std::is_same<T, float>::value) {
      block = YMM_FLOAT_BLOCK;
    } else if (std::is_same<T, double>::value) {
      block = XMM_FLOAT_BLOCK;
    }
    int i = 0;
    // Round n down to a multiple of `block` (block is a power of two). If T
    // is neither float nor double, block stays 0, so end == 0 and the scalar
    // tail loop below handles the whole range.
    int end = n & ~(block - 1);
    if (std::is_same<T, float>::value) {
      for (i = 0; i < end; i += block) {
        _mm256_storeu_ps(
            reinterpret_cast<float*>(dst) + i,
            _mm256_add_ps(
                _mm256_loadu_ps(reinterpret_cast<const float*>(dst) + i),
                _mm256_loadu_ps(reinterpret_cast<const float*>(src) + i)));
      }
    } else if (std::is_same<T, double>::value) {
      for (i = 0; i < end; i += block) {
        _mm256_storeu_pd(
            reinterpret_cast<double*>(dst) + i,
            _mm256_add_pd(
                _mm256_loadu_pd(reinterpret_cast<const double*>(dst) + i),
                _mm256_loadu_pd(reinterpret_cast<const double*>(src) + i)));
      }
    }
    // Scalar tail for the remaining (n - end) elements.
    for (; i < n; i++) {
      dst[i] += src[i];
    }
#else
    // No AVX at compile time: fall back to the generic scalar implementation.
    IndexSelectAdd<platform::isa_any, T> index_select_add_any;
    index_select_add_any(n, src, dst);
#endif
  }
};
#endif

template <typename T, typename IndexT = int>
void IndexSelectGradInner(const framework::ExecutionContext& context,
const LoDTensor& out_grad, const LoDTensor& index,
LoDTensor* x_grad, int dim) {
std::vector<T> input_vec;
std::vector<IndexT> index_vec;
TensorToVector(out_grad, context.device_context(), &input_vec);
TensorToVector(index, context.device_context(), &index_vec);

const T* input_data = out_grad.data<T>();
const IndexT* index_data = index.data<IndexT>();
T* out_data = x_grad->mutable_data<T>(context.GetPlace());
auto input_dim = out_grad.dims();
auto input_dim_size = input_dim.size();
auto output_dim = x_grad->dims();
std::vector<T> out_vec(x_grad->numel(), 0);
std::memset(out_data, 0.0, x_grad->numel() * sizeof(T));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

可以用SetConstant,另外初始化为0的部分,放到L196后的for循环里面,每次初始化一部分,对cache是不是友好些?


auto slice_size = 1;
for (auto i = dim + 1; i < input_dim_size; i++) {
slice_size *= input_dim[i];
}

auto input_width = slice_size * input_dim[dim];
auto output_width = slice_size * output_dim[dim];

auto outer_nums = 1;
for (auto i = 0; i < dim; i++) {
outer_nums *= input_dim[i];
}

auto index_size = index.dims()[0];
VLOG(3) << "Index_Select_Grad_Debug; outer_nums: " << outer_nums
<< "; slice_size: " << slice_size << "; input_width: " << input_width
<< "; output_width: " << output_width
<< "; index_size: " << index_size;

for (auto i = 0; i < outer_nums; i++) {
auto input_start_offset = i * input_width;
auto output_start_offset = i * output_width;

for (auto j = 0; j < index_size; j++) {
IndexT index_value = index_vec[j];
for (auto k = 0; k < slice_size; k++) {
out_vec[output_start_offset + index_value * slice_size + k] +=
input_vec[input_start_offset + j * slice_size + k];
}
IndexT index_value = index_data[j];
auto src = input_data + input_start_offset + j * slice_size;
auto dst = out_data + output_start_offset + index_value * slice_size;

#if ((!defined __NVCC__) && (!defined __HIPCC__))

#ifdef __AVX__
IndexSelectAdd<platform::avx, T> index_select_add_avx;
index_select_add_avx(slice_size, src, dst);
#else
IndexSelectAdd<platform::isa_any, T> index_select_add_any;
index_select_add_any(slice_size, src, dst);
#endif

#endif
}
}
x_grad->mutable_data<T>(context.GetPlace());
framework::TensorFromVector(out_vec, context.device_context(), x_grad);
x_grad->Resize(output_dim);
}

Expand All @@ -186,15 +237,13 @@ class IndexSelectGradKernel : public framework::OpKernel<T> {
auto* index_var = context.InputVar("Index");
auto* x_grad_var = context.OutputVar(framework::GradVarName("X"));
auto* out_grad_var = context.InputVar(framework::GradVarName("Out"));

auto& index = index_var->Get<LoDTensor>();
auto& out_grad = out_grad_var->Get<LoDTensor>();
auto* x_grad = x_grad_var->GetMutable<framework::LoDTensor>();
int dim = context.Attr<int>("dim");
if (dim < 0) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Line212 - Line219 可以改成:

    auto *x_grad = ctx.Input<framework::LoDTensor>("X");
    auto *index = ctx.Input<framework::LoDTensor>("Index");
    auto *out_grad = ctx.Output<framework::LoDTensor>("Out");

dim += out_grad.dims().size();
}

const auto& index_type = index.type();
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
index_type == framework::proto::VarType::INT64;
Expand All @@ -207,14 +256,12 @@ class IndexSelectGradKernel : public framework::OpKernel<T> {
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));

if (index_type == framework::proto::VarType::INT32) {
IndexSelectGradInner<T, int>(context, out_grad, index, x_grad, dim);
} else if (index_type == framework::proto::VarType::INT64) {
IndexSelectGradInner<T, int64_t>(context, out_grad, index, x_grad, dim);
}
}
};

} // namespace operators
} // namespace paddle