Skip to content

Commit 62c570f

Browse files
committed
optimization index_select backward
1 parent 04afbfd commit 62c570f

File tree

1 file changed

+12
-17
lines changed

1 file changed

+12
-17
lines changed

paddle/fluid/operators/index_select_op.h

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,6 @@ class IndexSelectKernel : public framework::OpKernel<T> {
130130
}
131131
};
132132

133-
#if ((!defined __NVCC__) && (!defined __HIPCC__))
134133
template <typename DeviceContext, typename T, class Enable = void>
135134
struct IndexSelectAdd {
136135
void operator()(const framework::ExecutionContext& ctx, int slice_size,
@@ -150,17 +149,16 @@ struct IndexSelectAdd<
150149
blas.VADD(slice_size, src_pointer, p_pointer, dist_pointer);
151150
}
152151
};
153-
#endif
154152

155153
template <typename DeviceContext, typename T, typename IndexT = int>
156154
void IndexSelectGradInner(const framework::ExecutionContext& context,
157-
const LoDTensor& out_grad, const LoDTensor& index,
155+
const LoDTensor* out_grad, const LoDTensor* index,
158156
LoDTensor* x_grad, int dim) {
159-
const T* input_data = out_grad.data<T>();
160-
const IndexT* index_data = index.data<IndexT>();
157+
const T* input_data = out_grad->data<T>();
158+
const IndexT* index_data = index->data<IndexT>();
161159
const T* p_output = x_grad->mutable_data<T>(context.GetPlace());
162160
T* out_data = x_grad->mutable_data<T>(context.GetPlace());
163-
auto input_dim = out_grad.dims();
161+
auto input_dim = out_grad->dims();
164162
auto input_dim_size = input_dim.size();
165163
auto output_dim = x_grad->dims();
166164

@@ -181,7 +179,7 @@ void IndexSelectGradInner(const framework::ExecutionContext& context,
181179
outer_nums *= input_dim[i];
182180
}
183181

184-
auto index_size = index.dims()[0];
182+
auto index_size = index->dims()[0];
185183
VLOG(3) << "Index_Select_Grad_Debug; outer_nums: " << outer_nums
186184
<< "; slice_size: " << slice_size << "; input_width: " << input_width
187185
<< "; output_width: " << output_width
@@ -196,10 +194,8 @@ void IndexSelectGradInner(const framework::ExecutionContext& context,
196194
auto src = input_data + input_start_offset + j * slice_size;
197195
auto p_out = p_output + output_start_offset + index_value * slice_size;
198196
auto dst = out_data + output_start_offset + index_value * slice_size;
199-
#if ((!defined __NVCC__) && (!defined __HIPCC__))
200197
IndexSelectAdd<DeviceContext, T> index_select_add;
201198
index_select_add(context, slice_size, src, p_out, dst);
202-
#endif
203199
}
204200
}
205201
x_grad->Resize(output_dim);
@@ -209,18 +205,17 @@ template <typename DeviceContext, typename T>
209205
class IndexSelectGradKernel : public framework::OpKernel<T> {
210206
public:
211207
void Compute(const framework::ExecutionContext& context) const override {
212-
auto* index_var = context.InputVar("Index");
213-
auto* x_grad_var = context.OutputVar(framework::GradVarName("X"));
214-
auto* out_grad_var = context.InputVar(framework::GradVarName("Out"));
208+
auto* x_grad =
209+
context.Output<framework::LoDTensor>(framework::GradVarName("X"));
210+
auto* index = context.Input<framework::LoDTensor>("Index");
211+
auto* out_grad =
212+
context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
215213

216-
auto& index = index_var->Get<LoDTensor>();
217-
auto& out_grad = out_grad_var->Get<LoDTensor>();
218-
auto* x_grad = x_grad_var->GetMutable<framework::LoDTensor>();
219214
int dim = context.Attr<int>("dim");
220215
if (dim < 0) {
221-
dim += out_grad.dims().size();
216+
dim += out_grad->dims().size();
222217
}
223-
const auto& index_type = index.type();
218+
const auto& index_type = index->type();
224219

225220
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
226221
index_type == framework::proto::VarType::INT64;

0 commit comments

Comments
 (0)