@@ -30,13 +30,15 @@ namespace phi {
 template <typename T, typename IndexT = int>
 void GPUIndexElementwisePutGradKernel(
     const phi::GPUContext& dev_ctx,
+    const DenseTensor& out_grad,
     const std::vector<const DenseTensor*>& index,
     const std::vector<int64_t>& input_dims,
     const std::vector<int64_t>& input_strides,
     const std::vector<int64_t>& index_dims,
     const std::vector<int64_t>& index_strides,
     const int64_t slice_offset,
-    DenseTensor* output) {
+    DenseTensor* x_grad,
+    DenseTensor* value_grad) {
   int64_t numel = 0;
 
   auto num_indices = index_dims.size();
@@ -52,12 +54,18 @@ void GPUIndexElementwisePutGradKernel(
   std::array<int64_t*, 3> strides_array;
   std::vector<int64_t> desired_shape;
   std::array<std::vector<int64_t>, 3> strides_vec;
+  std::vector<int64_t> value_dims;
+  std::vector<int64_t> value_strides;
+  if (value_grad) {
+    value_dims = common::vectorize<int64_t>(value_grad->dims());
+    value_strides = common::vectorize<int64_t>(value_grad->strides());
+  }
 
   funcs::IndexPutStride<3>(input_dims,
                            input_strides,
-                           phi::SizeOf(output->dtype()),
-                           std::vector<int64_t>(),
-                           std::vector<int64_t>(),
+                           phi::SizeOf(out_grad.dtype()),
+                           value_dims,
+                           value_strides,
                            4,
                            common::vectorize<int64_t>(index[0]->dims()),
                            common::vectorize<int64_t>(index[0]->strides()),
@@ -78,158 +86,121 @@ void GPUIndexElementwisePutGradKernel(
   auto stream = dev_ctx.stream();
 
   using dtype = funcs::OpaqueType<sizeof(T)>;
-
-  char* out_ptr = reinterpret_cast<char*>(output->data<T>());
-  funcs::index_elementwise_kernel<nt, vt>
-      <<<grid, block, 0, stream>>>(N, [=] __device__(int idx) {
-        const auto offsets = offset_calc.get(idx);
-        char* const out_data = out_ptr + offsets[0] + slice_offset;
-
-        int64_t offset = 0;
+  if (!value_grad) {
+    char* out_ptr = reinterpret_cast<char*>(x_grad->data<T>());
+    funcs::index_elementwise_kernel<nt, vt>
+        <<<grid, block, 0, stream>>>(N, [=] __device__(int idx) {
+          const auto offsets = offset_calc.get(idx);
+          char* const out_data = out_ptr + offsets[0] + slice_offset;
+
+          int64_t offset = 0;
 #pragma unroll
-        for (int i = 0; i < num_indices; i++) {
-          int64_t index =
-              *reinterpret_cast<int64_t*>(index_ptrs[i] + offsets[2]);
-          if (index < 0) {
-            index += sizes[i];
+          for (int i = 0; i < num_indices; i++) {
+            int64_t index =
+                *reinterpret_cast<int64_t*>(index_ptrs[i] + offsets[2]);
+            if (index < 0) {
+              index += sizes[i];
+            }
+            offset += index * strides[i];
           }
-          offset += index * strides[i];
-        }
-        T num = T(0);
-
-        *reinterpret_cast<dtype*>(out_data + offset) =
-            *reinterpret_cast<dtype*>(&num);
-      });
-}
-
-template <typename T>
-__global__ void SetZeroElementwiseCudaKernel(
-    int64_t** indices,
-    Array<int64_t, DDim::kMaxRank> stride,
-    Array<int64_t, DDim::kMaxRank> shape,
-    const int rank,
-    const int64_t numel,
-    T* out) {
-  int64_t idx =
-      static_cast<int64_t>(threadIdx.x) +
-      static_cast<int64_t>(blockDim.x) * static_cast<int64_t>(blockIdx.x);
-  if (idx >= numel) {
-    return;
-  }
-
-  int64_t cur_ix = 0;
-  int64_t offset = 0;
+          T num = T(0);
+          *reinterpret_cast<dtype*>(out_data + offset) =
+              *reinterpret_cast<dtype*>(&num);
+        });
+  } else if (!x_grad) {
+    const char* out_ptr = reinterpret_cast<const char*>(out_grad.data<T>());
+    char* value_ptr = reinterpret_cast<char*>(value_grad->data<T>());
+    funcs::index_elementwise_kernel<nt, vt>
+        <<<grid, block, 0, stream>>>(N, [=] __device__(int idx) {
+          const auto offsets = offset_calc.get(idx);
+          const char* const out_data = out_ptr + offsets[0] + slice_offset;
+          char* const value_data = value_ptr + offsets[1];
+
+          int64_t offset = 0;
 #pragma unroll
-  for (int i = 0; i < DDim::kMaxRank; ++i) {
-    if (i >= rank) {
-      break;
-    }
-    cur_ix = (static_cast<int64_t>(*(indices[i] + idx)));
-    if (cur_ix < 0) {
-      cur_ix += shape[i];
-    }
-    offset += stride[i] * cur_ix;
-  }
-
-  *(out + offset) = 0;
-}
-
-template <typename T>
-__global__ void IndexElementwisePutGradCudaKernel(
-    const T* out_grad,
-    int64_t** indices,
-    Array<int64_t, DDim::kMaxRank> stride,
-    Array<int64_t, DDim::kMaxRank> shape,
-    const int rank,
-    const int64_t numel,
-    T* value_grad) {
-  int64_t idx =
-      static_cast<int64_t>(threadIdx.x) +
-      static_cast<int64_t>(blockDim.x) * static_cast<int64_t>(blockIdx.x);
-  if (idx >= numel) {
-    return;
-  }
-
-  int64_t cur_ix = 0;
-  int64_t offset = 0;
+          for (int i = 0; i < num_indices; i++) {
+            int64_t index =
+                *reinterpret_cast<int64_t*>(index_ptrs[i] + offsets[2]);
+            if (index < 0) {
+              index += sizes[i];
+            }
+            offset += index * strides[i];
+          }
+          *reinterpret_cast<dtype*>(value_data) =
+              *reinterpret_cast<const dtype*>(out_data + offset);
+        });
+  } else {
+    char* out_ptr = reinterpret_cast<char*>(x_grad->data<T>());
+    char* value_ptr = reinterpret_cast<char*>(value_grad->data<T>());
+    funcs::index_elementwise_kernel<nt, vt>
+        <<<grid, block, 0, stream>>>(N, [=] __device__(int idx) {
+          const auto offsets = offset_calc.get(idx);
+          char* const out_data = out_ptr + offsets[0] + slice_offset;
+          char* const value_data = value_ptr + offsets[1];
+
+          int64_t offset = 0;
 #pragma unroll
-  for (int i = 0; i < DDim::kMaxRank; ++i) {
-    if (i >= rank) {
-      break;
-    }
-    cur_ix = (static_cast<int64_t>(*(indices[i] + idx)));
-    if (cur_ix < 0) {
-      cur_ix += shape[i];
-    }
-    offset += stride[i] * cur_ix;
+          for (int i = 0; i < num_indices; i++) {
+            int64_t index =
+                *reinterpret_cast<int64_t*>(index_ptrs[i] + offsets[2]);
+            if (index < 0) {
+              index += sizes[i];
+            }
+            offset += index * strides[i];
+          }
+          T num = T(0);
+          *reinterpret_cast<dtype*>(value_data) =
+              *reinterpret_cast<dtype*>(out_data + offset);
+          *reinterpret_cast<dtype*>(out_data + offset) =
+              *reinterpret_cast<dtype*>(&num);
+        });
   }
-
-  *(value_grad + idx) = *(out_grad + offset);
 }
 
 template <typename T, typename Context>
 void LaunchIndexElementwisePutGradCudaKernel(
     const Context& dev_ctx,
-    const std::vector<const DenseTensor*>& x_indices,
     const std::vector<const DenseTensor*>& indices,
     const DenseTensor& out_grad,
-    const int rank,
     const std::vector<int64_t>& input_dims,
     const std::vector<int64_t>& input_strides,
     const std::vector<int64_t>& index_dims,
     const std::vector<int64_t>& index_strides,
     const int64_t slice_offset,
     DenseTensor* value_grad,
     DenseTensor* x_grad) {
-  phi::Allocator::AllocationPtr indices_holder_1, indices_holder_2;
-  const auto& index_type = indices[0]->dtype();
-  if (x_grad) {
+  if (x_grad && !value_grad) {
     phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
 
     GPUIndexElementwisePutGradKernel<T, int64_t>(dev_ctx,
-                                                 x_indices,
+                                                 out_grad,
+                                                 indices,
                                                  input_dims,
                                                  input_strides,
                                                  index_dims,
                                                  index_strides,
                                                  slice_offset,
-                                                 x_grad);
-  }
-
-  auto out_grad_dims = out_grad.dims();
-  auto out_grad_stride = common::stride(out_grad_dims);
-
-  Array<int64_t, DDim::kMaxRank> stride_array;
-  Array<int64_t, DDim::kMaxRank> shape_array;
-  for (int i = 0; i < rank; ++i) {
-    stride_array[i] = out_grad_stride[i];
-    shape_array[i] = out_grad_dims[i];
-  }
-
-  const int64_t numel = indices[0]->numel();
-  auto pd_indices = funcs::GetDevicePointerArray<int64_t, Context>(
-      dev_ctx, indices, &indices_holder_2);
-  auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel);
-
-  if (value_grad) {
+                                                 x_grad,
+                                                 value_grad);
+  } else if (value_grad) {
+    if (x_grad) {
+      phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
+    }
     if (value_grad->numel() == 1) {
       DenseTensor tmp_value_grad(value_grad->dtype());
-      tmp_value_grad.Resize(indices[0]->dims());
-
-      T* tmp_value_grad_data = dev_ctx.template Alloc<T>(&tmp_value_grad);
-      auto out_grad_data = out_grad.data<T>();
-
-      IndexElementwisePutGradCudaKernel<T>
-          <<<config.block_per_grid,
-             config.thread_per_block,
-             0,
-             dev_ctx.stream()>>>(out_grad_data,
-                                 pd_indices,
-                                 stride_array,
-                                 shape_array,
-                                 rank,
-                                 numel,
-                                 tmp_value_grad_data);
+      tmp_value_grad.Resize(common::make_ddim(input_dims));
+      dev_ctx.template Alloc<T>(&tmp_value_grad);
+
+      GPUIndexElementwisePutGradKernel<T, int64_t>(dev_ctx,
+                                                   out_grad,
+                                                   indices,
+                                                   input_dims,
+                                                   input_strides,
+                                                   index_dims,
+                                                   index_strides,
+                                                   slice_offset,
+                                                   x_grad,
+                                                   &tmp_value_grad);
 
       std::vector<int> v_dims(tmp_value_grad.dims().size());
       std::iota(v_dims.begin(), v_dims.end(), 0);
@@ -240,39 +211,33 @@ void LaunchIndexElementwisePutGradCudaKernel(
                                  value_grad->dtype(),
                                  false,
                                  value_grad);
-    } else if (value_grad->numel() == indices[0]->numel()) {
-      T* value_grad_data = dev_ctx.template Alloc<T>(value_grad);
-      auto out_grad_data = out_grad.data<T>();
-
-      IndexElementwisePutGradCudaKernel<T>
-          <<<config.block_per_grid,
-             config.thread_per_block,
-             0,
-             dev_ctx.stream()>>>(out_grad_data,
-                                 pd_indices,
-                                 stride_array,
-                                 shape_array,
-                                 rank,
-                                 numel,
-                                 value_grad_data);
+    } else if (value_grad->dims() == common::make_ddim(input_dims)) {
+      dev_ctx.template Alloc<T>(value_grad);
+      GPUIndexElementwisePutGradKernel<T, int64_t>(dev_ctx,
+                                                   out_grad,
+                                                   indices,
+                                                   input_dims,
+                                                   input_strides,
+                                                   index_dims,
+                                                   index_strides,
+                                                   slice_offset,
+                                                   x_grad,
+                                                   value_grad);
     } else {
       DenseTensor tmp_value_grad(value_grad->dtype());
       tmp_value_grad.Resize(common::make_ddim(input_dims));
-
-      T* tmp_value_grad_data = dev_ctx.template Alloc<T>(&tmp_value_grad);
-      auto out_grad_data = out_grad.data<T>();
-
-      IndexElementwisePutGradCudaKernel<T>
-          <<<config.block_per_grid,
-             config.thread_per_block,
-             0,
-             dev_ctx.stream()>>>(out_grad_data,
-                                 pd_indices,
-                                 stride_array,
-                                 shape_array,
-                                 rank,
-                                 numel,
-                                 tmp_value_grad_data);
+      dev_ctx.template Alloc<T>(&tmp_value_grad);
+
+      GPUIndexElementwisePutGradKernel<T, int64_t>(dev_ctx,
+                                                   out_grad,
+                                                   indices,
+                                                   input_dims,
+                                                   input_strides,
+                                                   index_dims,
+                                                   index_strides,
+                                                   slice_offset,
+                                                   x_grad,
+                                                   &tmp_value_grad);
 
       std::vector<int64_t> after_dims =
           common::vectorize(tmp_value_grad.dims());
@@ -335,33 +300,9 @@ void IndexElementwisePutGradKernel(
     return;
   }
 
-  auto bd_dim = funcs::BroadCastTensorsDims(indices);
-
-  std::vector<int64_t> res_dim_v(common::vectorize(bd_dim));
-  std::vector<const phi::DenseTensor*> res_indices_v(x.dims().size(), nullptr);
-  std::vector<DenseTensor> tmp_res_indices_v;
-  std::vector<DenseTensor> range_tensor_v;
-
-  for (int i = indices.size(); i < x.dims().size(); ++i) {
-    range_tensor_v.emplace_back(funcs::GetRangeCudaTensor<int64_t, Context>(
-        dev_ctx, x.dims()[i], phi::DataType::INT64));
-  }
-
-  funcs::DealWithIndices<T, Context>(dev_ctx,
-                                     x,
-                                     indices,
-                                     &res_indices_v,
-                                     &tmp_res_indices_v,
-                                     range_tensor_v,
-                                     bd_dim,
-                                     &res_dim_v);
-
-  const int rank = x.dims().size();
   LaunchIndexElementwisePutGradCudaKernel<T, Context>(dev_ctx,
                                                        indices,
-                                                       res_indices_v,
                                                        out_grad,
-                                                       rank,
                                                        input_dims,
                                                        input_strides,
                                                        index_dims,
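
For orientation, a hedged sketch of the dispatch that the reworked kernel above implements: `GPUIndexElementwisePutGradKernel` now receives both `x_grad` and `value_grad` (either may be null) and selects one of three per-element behaviors. The enum and `SelectGradCase` helper below are hypothetical names used only for illustration; they are not part of the patch or of Paddle's API.

```cpp
// Illustrative sketch only (hypothetical names); mirrors the
// if (!value_grad) / else if (!x_grad) / else chain in the kernel above.
enum class GradCase {
  kXGradOnly,      // zero the indexed slots of x_grad
                   // (x_grad is first initialized as a copy of out_grad)
  kValueGradOnly,  // gather out_grad at the indexed slots into value_grad
  kBoth            // gather from x_grad into value_grad, then zero those slots
};

inline GradCase SelectGradCase(const void* x_grad, const void* value_grad) {
  if (!value_grad) return GradCase::kXGradOnly;
  if (!x_grad) return GradCase::kValueGradOnly;
  return GradCase::kBoth;
}
```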