@@ -92,6 +92,12 @@ class ReduceMin {
 };
 static ReduceMin reduce_min;
 
+__global__ void CudaMemsetAsync(int* dest, int value, size_t size) {
+  int tid = threadIdx.x + blockIdx.x * blockDim.x;
+  if (tid * sizeof(int) >= size) return;
+  dest[tid] = value;
+}
+
 template <typename tensor_t,
           typename index_t,
           typename func_t,
@@ -112,13 +118,6 @@ __global__ void ScatterAssignGPUKernel(tensor_t* self_data,
                                        int* thread_ids) {
   int tid = threadIdx.x + blockIdx.x * blockDim.x;
   if (tid >= numel) return;
-
-  if (tid == 0) {
-    for (int i = 0; i < numel_data; i++) {
-      thread_ids[i] = 0;
-    }
-  }
-  __syncthreads();
   int64_t i, j, k;  // The i, j, k here is the index of the 3 layers loop
                     // squeezed from the N layers loop.
   /* tid = i * select_dim_size * outer_dim_size + j * outer_dim_size + k */
@@ -267,16 +266,6 @@ __global__ void ScatterMeanGPUKernel(tensor_t* self_data,
   int tid = threadIdx.x + blockIdx.x * blockDim.x;
   if (tid >= numel) return;
 
-  if (tid == 0) {
-    for (int i = 0; i < numel_data; i++) {
-      shared_mem[i] = 0;  // thread_id
-      if (include_self)
-        shared_mem[numel_data + i] = 1;  // reduce size
-      else
-        shared_mem[numel_data + i] = 0;
-    }
-  }
-  __syncthreads();
   int64_t i, j, k;  // The i, j, k here is the index of the 3 layers loop
                     // squeezed from the N layers loop.
   /* tid = i * select_dim_size * outer_dim_size + j * outer_dim_size + k */
@@ -384,6 +373,7 @@ struct gpu_gather_scatter_functor {
       int* shared_mem;
       cudaMallocAsync(
           reinterpret_cast<void**>(&shared_mem), shared_mem_size, stream);
+      cudaMemsetAsync(shared_mem, 0, shared_mem_size, stream);
       ScatterAssignGPUKernel<tensor_t, index_t, func_t, is_scatter_like>
           <<<grid, block, 0, stream>>>(self_data,
                                        dim,
@@ -405,6 +395,14 @@ struct gpu_gather_scatter_functor {
       int* shared_mem;
       cudaMallocAsync(
           reinterpret_cast<void**>(&shared_mem), shared_mem_size, stream);
+      cudaMemsetAsync(shared_mem, 0, sizeof(int) * self_size, stream);
+      if (include_self) {
+        int64_t grid_memset = (self_size * 2 + block - 1) / block;
+        CudaMemsetAsync<<<grid_memset, block, 0, stream>>>(
+            shared_mem, 1, shared_mem_size);
+      } else {
+        cudaMemsetAsync(shared_mem, 0, shared_mem_size, stream);
+      }
       ScatterMeanGPUKernel<tensor_t, index_t, func_t, is_scatter_like>
           <<<grid, block, 0, stream>>>(self_data,
                                        dim,
@@ -429,6 +427,9 @@ struct gpu_gather_scatter_functor {
         shared_mem_size = sizeof(int) * self_size;
         cudaMallocAsync(
             reinterpret_cast<void**>(&shared_mem), shared_mem_size, stream);
+        int64_t grid_memset = (self_size + block - 1) / block;
+        CudaMemsetAsync<<<grid_memset, block, 0, stream>>>(
+            shared_mem, index_size + 1, shared_mem_size);
       }
       GatherScatterGPUKernel<tensor_t, index_t, func_t, is_scatter_like>
           <<<grid, block, shared_mem_size, stream>>>(self_data,
@@ -640,12 +641,6 @@ __global__ void ScatterMulInputGradGPUKernel(tensor_t* grad_data,
                                              int* thread_ids) {
   int tid = threadIdx.x + blockIdx.x * blockDim.x;
   if (tid >= numel) return;
-  if (tid == 0) {
-    for (int i = 0; i < numel_grad; i++) {
-      thread_ids[i] = 0;
-    }
-  }
-  __syncthreads();
   int64_t i, j, k;
   i = tid / (select_dim_size * outer_dim_size);
   int64_t remind = tid % (select_dim_size * outer_dim_size);
@@ -682,13 +677,6 @@ __global__ void ScatterMinMaxInputGradGPUKernel(tensor_t* grad_data,
                                                 int* shared_mem) {
   int tid = threadIdx.x + blockIdx.x * blockDim.x;
   if (tid >= numel) return;
-
-  if (tid == 0) {
-    for (int i = 0; i < numel_grad; i++) {
-      shared_mem[i] = 1;  // number of elements
-    }
-  }
-  __syncthreads();
   int64_t i, j, k;
   i = tid / (select_dim_size * outer_dim_size);
   int64_t remind = tid % (select_dim_size * outer_dim_size);
@@ -762,6 +750,7 @@ void gpu_scatter_mul_min_max_input_grad_kernel(phi::DenseTensor self,
     int* shared_mem;
     cudaMallocAsync(
         reinterpret_cast<void**>(&shared_mem), shared_mem_size, stream);
+    cudaMemsetAsync(shared_mem, 0, shared_mem_size, stream);
     ScatterMulInputGradGPUKernel<tensor_t, index_t>
         <<<grid, block, 0, stream>>>(grad_data,
                                      dim,
@@ -781,6 +770,9 @@ void gpu_scatter_mul_min_max_input_grad_kernel(phi::DenseTensor self,
     int* shared_mem;
     cudaMallocAsync(
         reinterpret_cast<void**>(&shared_mem), shared_mem_size, stream);
+    int64_t grid_memset = (grad_size + block - 1) / block;
+    CudaMemsetAsync<<<grid_memset, block, 0, stream>>>(
+        shared_mem, 1, shared_mem_size);
     ScatterMinMaxInputGradGPUKernel<tensor_t, index_t>
         <<<grid, block, 0, stream>>>(grad_data,
                                      dim,
@@ -816,13 +808,6 @@ __global__ void ScatterMeanInputGradGPUKernel(tensor_t* grad_data,
                                               int* shared_mem) {
   int tid = threadIdx.x + blockIdx.x * blockDim.x;
   if (tid >= numel) return;
-  if (tid == 0) {
-    for (int i = 0; i < numel_grad; i++) {
-      shared_mem[i] = 0;               // thread_ids
-      shared_mem[numel_grad + i] = 1;  // number of elements
-    }
-  }
-  __syncthreads();
   int64_t i, j, k;
   i = tid / (select_dim_size * outer_dim_size);
   int64_t remind = tid % (select_dim_size * outer_dim_size);
@@ -879,6 +864,10 @@ void gpu_scatter_mean_input_grad_kernel(phi::DenseTensor self,
   int* shared_mem;
   cudaMallocAsync(
       reinterpret_cast<void**>(&shared_mem), shared_mem_size, stream);
+  cudaMemsetAsync(shared_mem, 0, sizeof(int) * grad_size, stream);
+  int64_t grid_memset = (grad_size + block - 1) / block;
+  CudaMemsetAsync<<<grid_memset, block, 0, stream>>>(
+      shared_mem + grad_size, 1, sizeof(int) * grad_size);
   ScatterMeanInputGradGPUKernel<tensor_t, index_t>
       <<<grid, block, 0, stream>>>(grad_data,
                                    dim,
@@ -910,12 +899,6 @@ __global__ void ScatterValueGradGPUKernel(tensor_t* grad_data,
   int tid = threadIdx.x + blockIdx.x * blockDim.x;
   if (tid >= numel) return;
 
-  if (tid == 0) {
-    for (int i = 0; i < numel_data; i++) {
-      thread_ids[i] = 0;
-    }
-  }
-  __syncthreads();
   int64_t i, j, k;
   i = tid / (select_dim_size * outer_dim_size);
   int64_t remind = tid % (select_dim_size * outer_dim_size);
@@ -975,6 +958,7 @@ void gpu_scatter_value_grad_kernel(phi::DenseTensor self,
   int* shared_mem;
   cudaMallocAsync(
       reinterpret_cast<void**>(&shared_mem), shared_mem_size, stream);
+  cudaMemsetAsync(shared_mem, 0, shared_mem_size, stream);
   ScatterValueGradGPUKernel<tensor_t, index_t>
       <<<grid, block, 0, stream>>>(grad_data,
                                    dim,
@@ -1005,20 +989,10 @@ __global__ void ScatterMeanValueGradGPUKernel(tensor_t* grad_data,
                                               int64_t outer_dim_size_grad,
                                               int64_t numel,
                                               int64_t numel_self,
-                                              bool include_self,
                                               int* shared_mem) {
   int tid = threadIdx.x + blockIdx.x * blockDim.x;
   if (tid >= numel) return;
 
-  if (tid == 0) {
-    for (int i = 0; i < numel_self; i++) {
-      if (include_self)
-        shared_mem[i] = 1;  // number of elements
-      else
-        shared_mem[i] = 0;
-    }
-  }
-  __syncthreads();
   int64_t i, j, k;
   i = tid / (select_dim_size * outer_dim_size);
   int64_t remind = tid % (select_dim_size * outer_dim_size);
@@ -1114,6 +1088,13 @@ void gpu_scatter_add_mean_value_grad_kernel(
     int* shared_mem;
     cudaMallocAsync(
         reinterpret_cast<void**>(&shared_mem), shared_mem_size, stream);
+    if (include_self) {
+      int64_t grid_memset = (self_size + block - 1) / block;
+      CudaMemsetAsync<<<grid_memset, block, 0, stream>>>(
+          shared_mem, 1, shared_mem_size);
+    } else {
+      cudaMemsetAsync(shared_mem, 0, shared_mem_size, stream);
+    }
     ScatterMeanValueGradGPUKernel<tensor_t, index_t>
         <<<grid, block, 0, stream>>>(grad_data,
                                      dim,
@@ -1127,7 +1108,6 @@ void gpu_scatter_add_mean_value_grad_kernel(
                                      outer_dim_size_grad,
                                      index_size,
                                      self_size,
-                                     include_self,
                                      shared_mem);
     cudaFreeAsync(shared_mem, stream);
   } else if (reduce == "add") {