 
 namespace paddle {
 namespace operators {
+///////////////////////////////
+// case inner_size == 1 BEGIN
+///////////////////////////////
 
 #define LAUNCH_WARP_FORWAR_COMPUTE(near_greater_power_of_two) \
   case near_greater_power_of_two:                             \
@@ -141,6 +144,185 @@ void LaunchSoftmaxForwardForLastAxis(T *dst, const T *src, int dim_size,
       break;
   }
 }
+///////////////////////////////
+// case inner_size == 1 END
+///////////////////////////////
+
+///////////////////////////////
+// case inner_size != 1 BEGIN
+///////////////////////////////
+
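+// Binary functors used by the block-wide reductions below.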
+template <typename T>
+struct Add {
+  __device__ __forceinline__ T operator()(T a, T b) const { return a + b; }
+};
+template <typename T>
+struct Max {
+  __device__ __forceinline__ T operator()(T a, T b) const {
+    return a < b ? b : a;
+  }
+};
+
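+// Block-wide max reduction. Each row of threads (fixed threadIdx.y) owns
+// blockDim.x slots in shared memory and performs a tree reduction over them.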
+template <typename T>
+__forceinline__ __device__ T BlockReduceMax(T *shared, T val) {
+  // shared memory is partitioned per inner position: offset to this thread's
+  // row (threadIdx.y), which owns blockDim.x slots
+  shared += threadIdx.y * blockDim.x;
+  __syncthreads();
+  // each thread writes its partial max into shared memory
+  shared[threadIdx.x] = val;
+
+  // tree reduction in shared memory
+  int offset = blockDim.x / 2;
+  Max<T> max;
+  while (offset > 0) {
+    __syncthreads();
+    if (threadIdx.x < offset) {
+      shared[threadIdx.x] =
+          max(shared[threadIdx.x], shared[threadIdx.x + offset]);
+    }
+    offset /= 2;
+  }
+  __syncthreads();
+  return shared[0];
+}
+
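+// Block-wide sum reduction, same shared-memory layout as BlockReduceMax.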
+template <typename T>
+__forceinline__ __device__ T BlockReduceAdd(T *shared, T val) {
+  shared += threadIdx.y * blockDim.x;
+  __syncthreads();
+  shared[threadIdx.x] = val;
+  int offset = blockDim.x / 2;
+  Add<T> add;
+  // reduction in shared memory
+  while (offset > 0) {
+    __syncthreads();
+    if (threadIdx.x < offset) {
+      shared[threadIdx.x] =
+          add(shared[threadIdx.x], shared[threadIdx.x + offset]);
+    }
+    offset /= 2;
+  }
+  __syncthreads();
+  return shared[0];
+}
+
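+// Forward kernel for log_softmax along an axis that is not the last one.
+// blockIdx.x strides over the outer dimension, (blockIdx.y, threadIdx.y)
+// stride over the inner dimension, and threadIdx.x cooperates along the
+// softmax dimension using the block reductions above.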
+template <typename T, typename AccT>
+__global__ void LaunchLogSoftmaxForwardNotLastAxis(T *output, const T *input,
+                                                   int outer_size, int dim_size,
+                                                   int inner_size) {
+  extern __shared__ unsigned char smem[];
+  auto sdata = reinterpret_cast<AccT *>(smem);
+
+  const uint32_t outer_stride = inner_size * dim_size;
+  const uint32_t dim_stride = inner_size;
+
+  for (uint32_t x_id = blockIdx.x; x_id < outer_size; x_id += gridDim.x) {
+    for (uint32_t y_id = blockIdx.y * blockDim.y + threadIdx.y;
+         y_id < inner_size; y_id += blockDim.y * gridDim.y) {
+      const uint32_t data_offset = x_id * outer_stride + y_id;
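+      // With more than one thread along the softmax dimension, cooperate via
+      // shared-memory block reductions; otherwise reduce serially below.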
+      if (blockDim.x > 1) {
+        // 1. reduce max
+        AccT max_value = -std::numeric_limits<AccT>::infinity();
+        for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x) {
+          const AccT value =
+              static_cast<AccT>(input[data_offset + d * dim_stride]);
+          max_value = Max<AccT>()(max_value, value);
+        }
+        max_value = BlockReduceMax<AccT>(sdata, max_value);
+
+        // 2. reduce sum of exp(input - max)
+        AccT sum = 0;
+        for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x)
+          sum +=
+              std::exp(static_cast<AccT>(input[data_offset + d * dim_stride]) -
+                       max_value);
+        sum = BlockReduceAdd<AccT>(sdata, sum);
+
+        // 3. compute input - max - log(sum) and store
+        for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x) {
+          output[data_offset + d * dim_stride] = static_cast<T>(
+              static_cast<AccT>(input[data_offset + d * dim_stride]) -
+              max_value - std::log(sum));
+        }
+      } else {
+        // single thread along the softmax dimension: reduce serially
+        // 1. reduce max
+        AccT max_value = -std::numeric_limits<AccT>::infinity();
+        for (uint32_t d = threadIdx.x; d < dim_size; ++d) {
+          const AccT value =
+              static_cast<AccT>(input[data_offset + d * dim_stride]);
+          max_value = Max<AccT>()(max_value, value);
+        }
+        // 2. reduce sum of exp(input - max)
+        AccT sum = 0;
+        for (uint32_t d = threadIdx.x; d < dim_size; ++d)
+          sum +=
+              std::exp(static_cast<AccT>(input[data_offset + d * dim_stride]) -
+                       max_value);
+        // 3. compute input - max - log(sum) and store
+        for (uint32_t d = threadIdx.x; d < dim_size; ++d) {
+          output[data_offset + d * dim_stride] = static_cast<T>(
+              static_cast<AccT>(input[data_offset + d * dim_stride]) -
+              max_value - std::log(sum));
+        }
+      }
+    }
+  }
+}
+
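+// Choose a grid: cover the inner dimension with blockIdx.y first, then spend
+// the remaining active blocks on the outer dimension via blockIdx.x.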
+inline dim3 GetGridSize(dim3 block, uint32_t max_active_blocks,
+                        uint64_t outer_size, uint64_t dim_size,
+                        uint64_t inner_size) {
+  // First, tile as many blocks as we can over the y axis (inner dimension)
+  uint32_t inner_blocks = (inner_size + block.y - 1) / block.y;
+  if (inner_blocks > max_active_blocks) inner_blocks = max_active_blocks;
+  // Fill the x axis with as many blocks as we can fit (a little more is ok too)
+  uint32_t outer_blocks = (max_active_blocks + inner_blocks - 1) / inner_blocks;
+  if (outer_blocks > outer_size) outer_blocks = outer_size;
+  return dim3(outer_blocks, inner_blocks);
+}
+
+const int max_threads = 1024;
+
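+// Choose a block: threads.y covers the inner dimension; threads.x only gets
+// multiple threads (a power of two) when the inner dimension is small and the
+// softmax dimension is large enough to benefit from cooperative reduction.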
+inline dim3 GetBlockSize(uint64_t outer_size, uint64_t dim_size,
+                         uint64_t inner_size) {
+  uint32_t inner_threads = inner_size;
+  inner_threads = std::min(inner_threads, static_cast<uint32_t>(max_threads));
+  uint32_t dim_threads = 1;
+  if (inner_threads <= 64 && dim_size >= 64) {
+    while (inner_threads * dim_threads <= max_threads &&
+           dim_threads <= dim_size)
+      dim_threads *= 2;
+    dim_threads /= 2;
+  }
+  return dim3(dim_threads, inner_threads);
+}
+
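+// Compute block/grid/shared-memory sizes for the kernel above, scaling the
+// grid by the occupancy-derived number of active blocks per SM.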
+template <typename T, typename Kernel>
+void ComputeLaunchConfigure(Kernel k, uint64_t outer_size, uint64_t dim_size,
+                            uint64_t inner_size, dim3 &grid, dim3 &block,
+                            uint32_t &shared_mem, uint32_t num_sm) {
+  // get block config
+  block = GetBlockSize(outer_size, dim_size, inner_size);
+  // number of threads in a block
+  uint32_t block_threads = block.x * block.y;
+  // dynamic shared memory is only needed when threads cooperate along dim
+  shared_mem = block.x == 1 ? 0 : block_threads * sizeof(T);
+  int max_active_blocks;
+#ifdef PADDLE_WITH_HIP
+  PADDLE_ENFORCE_CUDA_SUCCESS(hipOccupancyMaxActiveBlocksPerMultiprocessor(
+      &max_active_blocks, k, block_threads, shared_mem));
+#else
+  PADDLE_ENFORCE_CUDA_SUCCESS(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+      &max_active_blocks, k, block_threads, shared_mem));
+#endif
+  max_active_blocks *= num_sm;
+  grid =
+      GetGridSize(block, max_active_blocks, outer_size, dim_size, inner_size);
+}
+
+///////////////////////////////
+// case inner_size != 1 END
+///////////////////////////////
 
 template <typename T>
 class LogSoftmaxKernel<platform::CUDADeviceContext, T>
@@ -164,14 +346,25 @@ class LogSoftmaxKernel<platform::CUDADeviceContext, T>
     }
     int outer_size = SizeToAxis(axis, x->dims());
     gpuStream_t stream = context.cuda_device_context().stream();
+    uint32_t num_sm = context.cuda_device_context().GetSMCount();
 
     if (inner_size == 1 && dim_size <= 1024 && dim_size * sizeof(T) <= 4096) {
       LaunchSoftmaxForwardForLastAxis<T, MPDType>(output_data, input_data,
                                                   dim_size, outer_size, stream);
     } else {
-      LogSoftmaxFunctor<platform::CUDADeviceContext, T>()(
-          context.template device_context<platform::CUDADeviceContext>(), x,
-          out, axis);
+      // inner_size != 1: use the generic kernel with an occupancy-based
+      // launch configuration
+      uint32_t shared_mem;
+      dim3 grid;
+      dim3 block;
+
+      ComputeLaunchConfigure<MPDType>(
+          &LaunchLogSoftmaxForwardNotLastAxis<T, MPDType>, outer_size, dim_size,
+          inner_size, grid, block, shared_mem, num_sm);
+
+      LaunchLogSoftmaxForwardNotLastAxis<T, MPDType>
+          <<<grid, block, shared_mem, stream>>>(output_data, input_data,
+                                                outer_size, dim_size,
+                                                inner_size);
     }
   }
 };
@@ -209,8 +402,8 @@ __global__ void ComputeLogSoftmaxBackwardInWarp(const T *output,
       grad_output_register[iter] = static_cast<AccT>(
           grad_output[batch_id * element_count + element_index]);
     } else {
-      output_register[iter] = AccT(0);
-      grad_output_register[iter] = AccT(0);
+      output_register[iter] = static_cast<AccT>(0);
+      grad_output_register[iter] = static_cast<AccT>(0);
     }
   }
 