@@ -130,7 +130,6 @@ class IndexSelectKernel : public framework::OpKernel<T> {
130130 }
131131};
132132
133- #if ((!defined __NVCC__) && (!defined __HIPCC__))
134133template <typename DeviceContext, typename T, class Enable = void >
135134struct IndexSelectAdd {
136135 void operator ()(const framework::ExecutionContext& ctx, int slice_size,
@@ -150,17 +149,16 @@ struct IndexSelectAdd<
150149 blas.VADD (slice_size, src_pointer, p_pointer, dist_pointer);
151150 }
152151};
153- #endif
154152
155153template <typename DeviceContext, typename T, typename IndexT = int >
156154void IndexSelectGradInner (const framework::ExecutionContext& context,
157- const LoDTensor& out_grad, const LoDTensor& index,
155+ const LoDTensor* out_grad, const LoDTensor* index,
158156 LoDTensor* x_grad, int dim) {
159- const T* input_data = out_grad. data <T>();
160- const IndexT* index_data = index. data <IndexT>();
157+ const T* input_data = out_grad-> data <T>();
158+ const IndexT* index_data = index-> data <IndexT>();
161159 const T* p_output = x_grad->mutable_data <T>(context.GetPlace ());
162160 T* out_data = x_grad->mutable_data <T>(context.GetPlace ());
163- auto input_dim = out_grad. dims ();
161+ auto input_dim = out_grad-> dims ();
164162 auto input_dim_size = input_dim.size ();
165163 auto output_dim = x_grad->dims ();
166164
@@ -181,7 +179,7 @@ void IndexSelectGradInner(const framework::ExecutionContext& context,
181179 outer_nums *= input_dim[i];
182180 }
183181
184- auto index_size = index. dims ()[0 ];
182+ auto index_size = index-> dims ()[0 ];
185183 VLOG (3 ) << " Index_Select_Grad_Debug; outer_nums: " << outer_nums
186184 << " ; slice_size: " << slice_size << " ; input_width: " << input_width
187185 << " ; output_width: " << output_width
@@ -196,10 +194,8 @@ void IndexSelectGradInner(const framework::ExecutionContext& context,
196194 auto src = input_data + input_start_offset + j * slice_size;
197195 auto p_out = p_output + output_start_offset + index_value * slice_size;
198196 auto dst = out_data + output_start_offset + index_value * slice_size;
199- #if ((!defined __NVCC__) && (!defined __HIPCC__))
200197 IndexSelectAdd<DeviceContext, T> index_select_add;
201198 index_select_add (context, slice_size, src, p_out, dst);
202- #endif
203199 }
204200 }
205201 x_grad->Resize (output_dim);
@@ -209,18 +205,17 @@ template <typename DeviceContext, typename T>
209205class IndexSelectGradKernel : public framework ::OpKernel<T> {
210206 public:
211207 void Compute (const framework::ExecutionContext& context) const override {
212- auto * index_var = context.InputVar (" Index" );
213- auto * x_grad_var = context.OutputVar (framework::GradVarName (" X" ));
214- auto * out_grad_var = context.InputVar (framework::GradVarName (" Out" ));
208+ auto * x_grad =
209+ context.Output <framework::LoDTensor>(framework::GradVarName (" X" ));
210+ auto * index = context.Input <framework::LoDTensor>(" Index" );
211+ auto * out_grad =
212+ context.Input <framework::LoDTensor>(framework::GradVarName (" Out" ));
215213
216- auto & index = index_var->Get <LoDTensor>();
217- auto & out_grad = out_grad_var->Get <LoDTensor>();
218- auto * x_grad = x_grad_var->GetMutable <framework::LoDTensor>();
219214 int dim = context.Attr <int >(" dim" );
220215 if (dim < 0 ) {
221- dim += out_grad. dims ().size ();
216+ dim += out_grad-> dims ().size ();
222217 }
223- const auto & index_type = index. type ();
218+ const auto & index_type = index-> type ();
224219
225220 bool index_type_match = index_type == framework::proto::VarType::INT32 ||
226221 index_type == framework::proto::VarType::INT64;
0 commit comments