@@ -83,6 +83,7 @@ struct LoaderTypeClassifier {
 };
 
 #ifndef PADDLE_WITH_XPU_KP
+// Common broadcast/elementwise Loader.
 template <typename T, int VecSize, int Arity, bool IsBoundary, int LoadType>
 struct BroadcastDataLoader {
   __device__ __forceinline__ void operator()(
@@ -107,6 +108,7 @@ struct BroadcastDataLoader {
   }
 };
 
+// Scalar elementwise Loader with consideration of IsBoundary.
 template <typename T, int VecSize, int Arity>
 struct BroadcastDataLoader<T, VecSize, Arity, true, kElementwise> {
   __device__ __forceinline__ void operator()(
@@ -117,17 +119,12 @@ struct BroadcastDataLoader<T, VecSize, Arity, true, kElementwise> {
       const int block_offset,
       const int num,
       const uint32_t numel) {
-#pragma unroll
-    for (int i = 0; i < Arity; ++i) {
-#pragma unroll
-      kps::Init<T, VecSize>(args[i], static_cast<T>(1));
-    }
-
     int thread_offset = threadIdx.x * VecSize + block_offset;
 #pragma unroll
     for (int i = 0; i < Arity; ++i) {
 #pragma unroll
       for (int idx = 0; idx < VecSize; ++idx) {
+        args[i][idx] = static_cast<T>(1);
         int index = thread_offset + idx;
         if (index < numel) {
           args[i][idx] = ins[i][index];
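
Review note: the deleted kps::Init pass and the new in-loop assignment are equivalent in effect — every lane starts from the default value 1 and in-bounds lanes are then overwritten from global memory — but the default is now written inside the load loop instead of in a separate initialization pass. A minimal host-side sketch of the resulting pattern (the LoadWithDefault helper and the fixed kArity/kVecSize constants are illustrative stand-ins, not Paddle APIs):

```cpp
// Host-side sketch of the fused default-init / load pattern used above.
#include <cstdio>

constexpr int kArity = 2;
constexpr int kVecSize = 4;

template <typename T>
void LoadWithDefault(T args[kArity][kVecSize],
                     const T* const ins[kArity],
                     int thread_offset,
                     int numel) {
  for (int i = 0; i < kArity; ++i) {
    for (int idx = 0; idx < kVecSize; ++idx) {
      // Write the default first, then overwrite it when in bounds; this
      // replaces the separate kps::Init pass that the old code performed.
      args[i][idx] = static_cast<T>(1);
      int index = thread_offset + idx;
      if (index < numel) {
        args[i][idx] = ins[i][index];
      }
    }
  }
}

int main() {
  float a[] = {10.f, 20.f, 30.f};
  float b[] = {1.f, 2.f, 3.f};
  const float* ins[kArity] = {a, b};
  float args[kArity][kVecSize];
  LoadWithDefault(args, ins, /*thread_offset=*/0, /*numel=*/3);
  printf("%g %g %g %g\n", args[0][0], args[0][1], args[0][2], args[0][3]);
  // Prints: 10 20 30 1  (the out-of-range lane keeps the default value).
  return 0;
}
```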
@@ -137,6 +134,7 @@ struct BroadcastDataLoader<T, VecSize, Arity, true, kElementwise> {
   }
 };
 
+// Vectorized elementwise Loader without consideration of IsBoundary.
 template <typename T, int VecSize, int Arity>
 struct BroadcastDataLoader<T, VecSize, Arity, false, kElementwise> {
   __device__ __forceinline__ void operator()(
@@ -164,6 +162,7 @@ struct BroadcastDataLoader<T, VecSize, Arity, false, kElementwise> {
   }
 };
 
+// Common broadcast data loader.
 template <typename T, int VecSize, int Arity, bool IsBoundary>
 struct BroadcastDataLoader<T, VecSize, Arity, IsBoundary, kBroadcast> {
   __device__ __forceinline__ void operator()(
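
The comments added in the hunks above describe a dispatch structure: the primary BroadcastDataLoader template is the common fallback, and the specializations on the (IsBoundary, LoadType) parameters select the scalar boundary-checked, vectorized, and broadcast paths. A simplified sketch of that selection mechanism (the Loader struct, its strings, and the constants below are stand-ins, not the Paddle implementation):

```cpp
// Sketch of compile-time dispatch on (IsBoundary, LoadType).
#include <cstdio>

constexpr int kElementwise = 0;
constexpr int kBroadcast = 1;

// Primary template: the common fallback.
template <bool IsBoundary, int LoadType>
struct Loader {
  static const char* name() { return "generic loader"; }
};

// Scalar elementwise path that honours boundary checks.
template <>
struct Loader<true, kElementwise> {
  static const char* name() { return "elementwise, boundary-checked"; }
};

// Vectorized elementwise path without boundary checks.
template <>
struct Loader<false, kElementwise> {
  static const char* name() { return "elementwise, vectorized"; }
};

// Broadcast loader, regardless of IsBoundary.
template <bool IsBoundary>
struct Loader<IsBoundary, kBroadcast> {
  static const char* name() { return "broadcast"; }
};

int main() {
  printf("%s\n", Loader<true, kElementwise>::name());
  printf("%s\n", Loader<false, kElementwise>::name());
  printf("%s\n", Loader<true, kBroadcast>::name());
  return 0;
}
```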
@@ -405,11 +404,10 @@ void LaunchBroadcastKernel(
   auto gpu_config =
       phi::backends::gpu::GetGpuLaunchConfig1D(ctx, numel, VecSize);
   auto stream = ctx.stream();
-  auto threads = gpu_config.thread_per_block;
+  auto threads = gpu_config.GetBlockSize();
   auto blocks = gpu_config.block_per_grid;
-  int main_offset = (numel / (VecSize * gpu_config.GetBlockSize())) * VecSize *
-                    gpu_config.GetBlockSize();
-  int tail_tid = numel % (VecSize * gpu_config.GetBlockSize());
+  int main_offset = (numel / (VecSize * threads)) * VecSize * threads;
+  int tail_tid = numel % (VecSize * threads);
 
   if (loader_classifier.all_elementwise) {
     VectorizedBroadcastKernel<Func,
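
The launch-configuration change only hoists GetBlockSize() into the threads local; main_offset and tail_tid are computed exactly as before. A standalone sketch of the arithmetic, with made-up example values for numel, VecSize, and threads:

```cpp
// Standalone check of the main_offset / tail_tid arithmetic in the hunk above.
#include <cstdio>

int main() {
  int numel = 5000;   // total element count (example value)
  int VecSize = 4;    // elements handled per thread per vectorized load
  int threads = 256;  // block size, i.e. what GetBlockSize() returns
  // Elements covered by fully vectorized, boundary-check-free iterations:
  int main_offset = (numel / (VecSize * threads)) * VecSize * threads;
  // Remaining elements that the boundary-checked tail must handle:
  int tail_tid = numel % (VecSize * threads);
  printf("main_offset = %d, tail_tid = %d\n", main_offset, tail_tid);
  // With these numbers: VecSize * threads = 1024, so
  // main_offset = 4 * 1024 = 4096 and tail_tid = 5000 - 4096 = 904.
  return 0;
}
```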