Skip to content
Merged
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
3a1c466
optimization of index_select op backward
Zjq9409 May 18, 2021
7ba619f
optimization of index_select op backward
Zjq9409 May 18, 2021
090e046
add compile parameter
Zjq9409 May 18, 2021
a2b54bb
optimization of index_select op backward
Zjq9409 May 19, 2021
cdd297b
optimization of index_select op backward
Zjq9409 May 19, 2021
3a6faf1
optimization of index_select op backward
Zjq9409 May 19, 2021
252eefd
optimization of index_select op backward
Zjq9409 May 19, 2021
aa8057d
optimization of index_select op backward
Zjq9409 May 19, 2021
92ff317
optimization of index_select op backward
Zjq9409 May 19, 2021
9beb43b
optimization of index_select op backward
Zjq9409 May 19, 2021
d2f9aa8
optimization of index_select op backward
Zjq9409 May 25, 2021
d121f02
optimization of index_select op backward
Zjq9409 May 25, 2021
de92b58
optimization of index_select backward
Zjq9409 Jun 10, 2021
0efa130
optimization of index_select op backward
Zjq9409 Jun 15, 2021
3b62a64
optimization of index_select op backward
Zjq9409 Jun 15, 2021
f3c1fb0
optimization of index_select op backward
Zjq9409 Jun 15, 2021
c76a2da
optimization of index_select op backward
Zjq9409 Jun 15, 2021
72b1c71
Merge branch 'develop' into Optimizaition_of_index_select_backward_cp…
Zjq9409 Jun 15, 2021
7c003b1
optimization of index_select op backward
Zjq9409 Jun 16, 2021
8a47e37
optimization of index_select op
Zjq9409 Jul 5, 2021
3555551
optimization of index_select backward
Zjq9409 Jul 6, 2021
1cba37c
Optimizaition of index select backward by blas
Zjq9409 Jul 6, 2021
3bdc32c
Optimizaition of index select backward by blas
Zjq9409 Jul 6, 2021
c035b19
modify add operator to blas
Zjq9409 Jul 6, 2021
cfb18c7
Merge branch 'develop' into Optimizaition_of_index_select_backward_cp…
Zjq9409 Jul 6, 2021
b341849
optimization index select backward
Zjq9409 Jul 7, 2021
f7ee626
modify template
Zjq9409 Jul 7, 2021
e6536f0
modify template
Zjq9409 Jul 7, 2021
04afbfd
optimization of index select
Zjq9409 Jul 7, 2021
62c570f
optimization index_select backward
Zjq9409 Jul 14, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 54 additions & 31 deletions paddle/fluid/operators/index_select_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
#pragma once
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/jit/macro.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/cpu_info.h"

namespace paddle {
namespace operators {
Expand All @@ -38,7 +42,6 @@ void IndexSelectInner(const framework::ExecutionContext& context,

auto input_width = slice_size * input_dim[dim];
auto output_width = slice_size * output_dim[dim];

auto outer_nums = 1;
for (auto i = 0; i < dim; i++) {
outer_nums *= input_dim[i];
Expand Down Expand Up @@ -77,7 +80,6 @@ void IndexSelectInner(const framework::ExecutionContext& context,
for (auto i = 0; i < outer_nums; i++) {
auto input_start_offset = i * input_width;
auto output_start_offset = i * output_width;

for (auto j = 0; j < index_size; j++) {
IndexT index_value = index_vec[j];
for (auto k = 0; k < slice_size; k++) {
Expand All @@ -98,7 +100,6 @@ class IndexSelectKernel : public framework::OpKernel<T> {
auto* inputs_var = context.InputVar("X");
auto* index_var = context.InputVar("Index");
auto* output_var = context.OutputVar("Out");

auto& inputs = inputs_var->Get<LoDTensor>();
auto& index = index_var->Get<LoDTensor>();
auto* output = output_var->GetMutable<framework::LoDTensor>();
Expand All @@ -107,8 +108,8 @@ class IndexSelectKernel : public framework::OpKernel<T> {
if (dim < 0) {
dim += inputs.dims().size();
}

const auto& index_type = index.type();

bool index_type_match = index_type == framework::proto::VarType::INT32 ||
index_type == framework::proto::VarType::INT64;
PADDLE_ENFORCE_EQ(index_type_match, true,
Expand All @@ -129,19 +130,41 @@ class IndexSelectKernel : public framework::OpKernel<T> {
}
};

template <typename T, typename IndexT = int>
// Fallback element-wise accumulate used by index_select backward:
// dist_pointer[i] = src_pointer[i] + p_pointer[i] for `slice_size` entries.
// Chosen for element types that have no specialized (BLAS) vector add —
// see the floating-point specialization below.
template <typename DeviceContext, typename T, class Enable = void>
struct IndexSelectAdd {
  void operator()(const framework::ExecutionContext& ctx, int slice_size,
                  const T* src_pointer, const T* p_pointer, T* dist_pointer) {
    // Plain scalar loop; `ctx` is unused here but kept so the generic
    // template and the BLAS specialization share one call signature.
    int idx = 0;
    while (idx < slice_size) {
      dist_pointer[idx] = src_pointer[idx] + p_pointer[idx];
      ++idx;
    }
  }
};
// Specialization for floating-point element types: delegates the slice-wise
// accumulation (dist = src + p, element-wise over `slice_size` entries) to
// the BLAS vector add VADD, which is typically vectorized and faster than
// the scalar fallback loop above.
// NOTE(review): when relying on BLAS here, benchmark the speed-up under
// different OMP thread settings — TODO confirm.
template <typename DeviceContext, typename T>
struct IndexSelectAdd<
    DeviceContext, T,
    typename std::enable_if<std::is_floating_point<T>::value>::type> {
  void operator()(const framework::ExecutionContext& ctx, int slice_size,
                  const T* src_pointer, const T* p_pointer, T* dist_pointer) {
    auto blas = math::GetBlas<DeviceContext, T>(ctx);
    blas.VADD(slice_size, src_pointer, p_pointer, dist_pointer);
  }
};

template <typename DeviceContext, typename T, typename IndexT = int>
void IndexSelectGradInner(const framework::ExecutionContext& context,
const LoDTensor& out_grad, const LoDTensor& index,
const LoDTensor* out_grad, const LoDTensor* index,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不要修改参数的类型,不用修改的输入用const Tensor&类型。

LoDTensor* x_grad, int dim) {
std::vector<T> input_vec;
std::vector<IndexT> index_vec;
TensorToVector(out_grad, context.device_context(), &input_vec);
TensorToVector(index, context.device_context(), &index_vec);

auto input_dim = out_grad.dims();
const T* input_data = out_grad->data<T>();
const IndexT* index_data = index->data<IndexT>();
const T* p_output = x_grad->mutable_data<T>(context.GetPlace());
T* out_data = x_grad->mutable_data<T>(context.GetPlace());
auto input_dim = out_grad->dims();
auto input_dim_size = input_dim.size();
auto output_dim = x_grad->dims();
std::vector<T> out_vec(x_grad->numel(), 0);

auto& dev_ctx = context.template device_context<DeviceContext>();
math::SetConstant<DeviceContext, T> set_constant;
set_constant(dev_ctx, x_grad, static_cast<T>(0.0));

auto slice_size = 1;
for (auto i = dim + 1; i < input_dim_size; i++) {
Expand All @@ -156,7 +179,7 @@ void IndexSelectGradInner(const framework::ExecutionContext& context,
outer_nums *= input_dim[i];
}

auto index_size = index.dims()[0];
auto index_size = index->dims()[0];
VLOG(3) << "Index_Select_Grad_Debug; outer_nums: " << outer_nums
<< "; slice_size: " << slice_size << "; input_width: " << input_width
<< "; output_width: " << output_width
Expand All @@ -167,35 +190,33 @@ void IndexSelectGradInner(const framework::ExecutionContext& context,
auto output_start_offset = i * output_width;

for (auto j = 0; j < index_size; j++) {
IndexT index_value = index_vec[j];
for (auto k = 0; k < slice_size; k++) {
out_vec[output_start_offset + index_value * slice_size + k] +=
input_vec[input_start_offset + j * slice_size + k];
}
IndexT index_value = index_data[j];
auto src = input_data + input_start_offset + j * slice_size;
auto p_out = p_output + output_start_offset + index_value * slice_size;
auto dst = out_data + output_start_offset + index_value * slice_size;
IndexSelectAdd<DeviceContext, T> index_select_add;
index_select_add(context, slice_size, src, p_out, dst);
}
}
x_grad->mutable_data<T>(context.GetPlace());
framework::TensorFromVector(out_vec, context.device_context(), x_grad);
x_grad->Resize(output_dim);
}

template <typename DeviceContext, typename T>
class IndexSelectGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* index_var = context.InputVar("Index");
auto* x_grad_var = context.OutputVar(framework::GradVarName("X"));
auto* out_grad_var = context.InputVar(framework::GradVarName("Out"));
auto* x_grad =
context.Output<framework::LoDTensor>(framework::GradVarName("X"));
auto* index = context.Input<framework::LoDTensor>("Index");
auto* out_grad =
context.Input<framework::LoDTensor>(framework::GradVarName("Out"));

auto& index = index_var->Get<LoDTensor>();
auto& out_grad = out_grad_var->Get<LoDTensor>();
auto* x_grad = x_grad_var->GetMutable<framework::LoDTensor>();
int dim = context.Attr<int>("dim");
if (dim < 0) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Line212 - Line219 可以改成:

    auto *x_grad = ctx.Input<framework::LoDTensor>("X");
    auto *index = ctx.Input<framework::LoDTensor>("Index");
    auto *out_grad = ctx.Output<framework::LoDTensor>("Out");

dim += out_grad.dims().size();
dim += out_grad->dims().size();
}
const auto& index_type = index->type();

const auto& index_type = index.type();
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
index_type == framework::proto::VarType::INT64;
PADDLE_ENFORCE_EQ(index_type_match, true,
Expand All @@ -209,9 +230,11 @@ class IndexSelectGradKernel : public framework::OpKernel<T> {
framework::proto::VarType::INT64)));

if (index_type == framework::proto::VarType::INT32) {
IndexSelectGradInner<T, int>(context, out_grad, index, x_grad, dim);
IndexSelectGradInner<DeviceContext, T, int>(context, out_grad, index,
x_grad, dim);
} else if (index_type == framework::proto::VarType::INT64) {
IndexSelectGradInner<T, int64_t>(context, out_grad, index, x_grad, dim);
IndexSelectGradInner<DeviceContext, T, int64_t>(context, out_grad, index,
x_grad, dim);
}
}
};
Expand Down