Commit c260064

Log_softmax backward case: axis!=-1

1 parent d521199 · commit c260064

File tree

1 file changed (+198, -5 lines)

paddle/fluid/operators/log_softmax_op.cu

Lines changed: 198 additions & 5 deletions
@@ -19,6 +19,9 @@
 
 namespace paddle {
 namespace operators {
+/////////////////////////////////
+// case inner_size == 1 BEGIN
+/////////////////////////////////
 
 #define LAUNCH_WARP_FORWAR_COMPUTE(near_greater_power_of_two) \
   case near_greater_power_of_two:                             \
@@ -141,6 +144,185 @@ void LaunchSoftmaxForwardForLastAxis(T *dst, const T *src, int dim_size,
     break;
   }
 }
+/////////////////////////////////
+// case inner_size == 1 END
+/////////////////////////////////
+
+/////////////////////////////////
+// case inner_size != 1 BEGIN
+/////////////////////////////////
+
+template <typename T>
+struct Add {
+  __device__ __forceinline__ T operator()(T a, T b) const { return a + b; }
+};
+template <typename T>
+struct Max {
+  __device__ __forceinline__ T operator()(T a, T b) const {
+    return a < b ? b : a;
+  }
+};
+
+template <typename T>
+__forceinline__ __device__ T BlockReduceMax(T *shared, T val) {
+  // Each threadIdx.y row owns its own blockDim.x-wide slice of shared memory.
+  shared += threadIdx.y * blockDim.x;
+  __syncthreads();
+  // Every thread writes its local max into shared memory.
+  shared[threadIdx.x] = val;
+
+  // Tree reduction in shared memory.
+  int offset = blockDim.x / 2;
+  Max<T> max;
+  while (offset > 0) {
+    __syncthreads();
+    if (threadIdx.x < offset) {
+      shared[threadIdx.x] =
+          max(shared[threadIdx.x], shared[threadIdx.x + offset]);
+    }
+    offset /= 2;
+  }
+  __syncthreads();
+  return shared[0];
+}
+
+template <typename T>
+__forceinline__ __device__ T BlockReduceAdd(T *shared, T val) {
+  shared += threadIdx.y * blockDim.x;
+  __syncthreads();
+  shared[threadIdx.x] = val;
+  int offset = blockDim.x / 2;
+  Add<T> add;
+  // Tree reduction in shared memory.
+  while (offset > 0) {
+    __syncthreads();
+    if (threadIdx.x < offset) {
+      shared[threadIdx.x] =
+          add(shared[threadIdx.x], shared[threadIdx.x + offset]);
+    }
+    offset /= 2;
+  }
+  __syncthreads();
+  return shared[0];
+}
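
For reference, a minimal usage sketch of these block-wide reductions (hypothetical, not part of this commit; the names BlockSumDemo, partial_sums, and demo_smem are invented for illustration). It assumes a 1-D block whose size is a power of two, which the halving loop above requires, and a shared-memory allocation of blockDim.x * sizeof(T) bytes at launch:

// Hypothetical sketch: sum an array using the BlockReduceAdd above.
template <typename T>
__global__ void BlockSumDemo(const T *in, T *partial_sums, int n) {
  extern __shared__ unsigned char demo_smem[];
  auto sdata = reinterpret_cast<T *>(demo_smem);

  // Grid-stride loop: each thread accumulates a private partial sum.
  T partial = 0;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += blockDim.x * gridDim.x) {
    partial += in[i];
  }

  // After this call, every thread in the block sees the block-wide sum.
  T block_sum = BlockReduceAdd<T>(sdata, partial);
  if (threadIdx.x == 0) partial_sums[blockIdx.x] = block_sum;
}
// Example launch: BlockSumDemo<float><<<64, 256, 256 * sizeof(float)>>>(...);
// the 64 per-block results still need one final reduction afterwards.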
+
+template <typename T, typename AccT>
+__global__ void LaunchLogSoftmaxForwardNotLastAxis(T *output, const T *input,
+                                                   int outer_size, int dim_size,
+                                                   int inner_size) {
+  extern __shared__ unsigned char smem[];
+  auto sdata = reinterpret_cast<AccT *>(smem);
+
+  const uint32_t outer_stride = inner_size * dim_size;
+  const uint32_t dim_stride = inner_size;
+
+  for (uint32_t x_id = blockIdx.x; x_id < outer_size; x_id += gridDim.x) {
+    for (uint32_t y_id = blockIdx.y * blockDim.y + threadIdx.y;
+         y_id < inner_size; y_id += blockDim.y * gridDim.y) {
+      const uint32_t data_offset = x_id * outer_stride + y_id;
+      if (blockDim.x > 1) {
+        // 1. reduce max across the block's x-lanes
+        AccT max_value = -std::numeric_limits<AccT>::infinity();
+        for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x) {
+          const AccT value =
+              static_cast<AccT>(input[data_offset + d * dim_stride]);
+          max_value = Max<AccT>()(max_value, value);
+        }
+        max_value = BlockReduceMax<AccT>(sdata, max_value);
+
+        // 2. reduce sum of exp(input - max)
+        AccT sum = 0;
+        for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x) {
+          sum += std::exp(
+              static_cast<AccT>(input[data_offset + d * dim_stride]) -
+              max_value);
+        }
+        sum = BlockReduceAdd<AccT>(sdata, sum);
+
+        // 3. store input - max - log(sum)
+        for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x) {
+          output[data_offset + d * dim_stride] = static_cast<T>(
+              static_cast<AccT>(input[data_offset + d * dim_stride]) -
+              max_value - std::log(sum));
+        }
+      } else {
+        // Serial path: blockDim.x == 1, one thread handles the whole axis.
+        // 1. reduce max
+        AccT max_value = -std::numeric_limits<AccT>::infinity();
+        for (uint32_t d = 0; d < dim_size; ++d) {
+          const AccT value =
+              static_cast<AccT>(input[data_offset + d * dim_stride]);
+          max_value = Max<AccT>()(max_value, value);
+        }
+        // 2. reduce sum of exp(input - max)
+        AccT sum = 0;
+        for (uint32_t d = 0; d < dim_size; ++d) {
+          sum += std::exp(
+              static_cast<AccT>(input[data_offset + d * dim_stride]) -
+              max_value);
+        }
+        // 3. store input - max - log(sum)
+        for (uint32_t d = 0; d < dim_size; ++d) {
+          output[data_offset + d * dim_stride] = static_cast<T>(
+              static_cast<AccT>(input[data_offset + d * dim_stride]) -
+              max_value - std::log(sum));
+        }
+      }
+    }
+  }
+}
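
In equation form, steps 1-3 of this kernel compute the numerically stable log-softmax along the reduced axis; shifting by the max before exponentiating avoids overflow:

\[
m = \max_{j} x_j, \qquad
\operatorname{log\_softmax}(x)_i = x_i - m - \log\sum_{j} e^{\,x_j - m}
\]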
+
+inline dim3 GetGridSize(dim3 block, uint32_t max_active_blocks,
+                        uint64_t outer_size, uint64_t dim_size,
+                        uint64_t inner_size) {
+  // First, tile as many blocks as we can over the y axis
+  uint32_t inner_blocks = (inner_size + block.y - 1) / block.y;
+  if (inner_blocks > max_active_blocks) inner_blocks = max_active_blocks;
+  // Fill the x axis with as many blocks as we can fit (a little more is ok too)
+  uint32_t outer_blocks = (max_active_blocks + inner_blocks - 1) / inner_blocks;
+  if (outer_blocks > outer_size) outer_blocks = outer_size;
+  return dim3(outer_blocks, inner_blocks);
+}
+
+const int max_threads = 1024;
+
+inline dim3 GetBlockSize(uint64_t outer_size, uint64_t dim_size,
+                         uint64_t inner_size) {
+  uint32_t inner_threads = inner_size;
+  inner_threads = std::min(inner_threads, static_cast<uint32_t>(max_threads));
+  uint32_t dim_threads = 1;
+  if (inner_threads <= 64 && dim_size >= 64) {
+    while (inner_threads * dim_threads <= max_threads &&
+           dim_threads <= dim_size) {
+      dim_threads *= 2;
+    }
+    dim_threads /= 2;
+  }
+  return dim3(dim_threads, inner_threads);
+}
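
A worked example of this heuristic (hypothetical shape, not from the commit): with inner_size = 32 and dim_size = 1024, inner_threads = 32, so the loop doubles dim_threads while 32 * dim_threads <= 1024, overshoots to 64, and the final halving leaves dim_threads = 32. The block is dim3(32, 32), i.e. 1024 threads: 32 x-lanes reducing over the softmax axis for each of 32 inner positions.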
+
+template <typename T, typename Kernel>
+void ComputeLaunchConfigure(Kernel k, uint64_t outer_size, uint64_t dim_size,
+                            uint64_t inner_size, dim3 &grid, dim3 &block,
+                            uint32_t &shared_mem, uint32_t num_sm) {
+  // Get the block configuration.
+  block = GetBlockSize(outer_size, dim_size, inner_size);
+  // Number of threads in a block.
+  uint32_t block_threads = block.x * block.y;
+  // Shared memory is only needed when the block reduces across x.
+  shared_mem = block.x == 1 ? 0 : block_threads * sizeof(T);
+  int max_active_blocks;
+#ifdef PADDLE_WITH_HIP
+  PADDLE_ENFORCE_CUDA_SUCCESS(hipOccupancyMaxActiveBlocksPerMultiprocessor(
+      &max_active_blocks, k, block_threads, shared_mem));
+#else
+  PADDLE_ENFORCE_CUDA_SUCCESS(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+      &max_active_blocks, k, block_threads, shared_mem));
+#endif
+  max_active_blocks *= num_sm;
+  grid =
+      GetGridSize(block, max_active_blocks, outer_size, dim_size, inner_size);
+}
+
+/////////////////////////////////
+// case inner_size != 1 END
+/////////////////////////////////
 
 template <typename T>
 class LogSoftmaxKernel<platform::CUDADeviceContext, T>
@@ -164,14 +346,25 @@ class LogSoftmaxKernel<platform::CUDADeviceContext, T>
     }
     int outer_size = SizeToAxis(axis, x->dims());
     gpuStream_t stream = context.cuda_device_context().stream();
+    uint32_t num_sm = context.cuda_device_context().GetSMCount();
 
     if (inner_size == 1 && dim_size <= 1024 && dim_size * sizeof(T) <= 4096) {
       LaunchSoftmaxForwardForLastAxis<T, MPDType>(output_data, input_data,
                                                   dim_size, outer_size, stream);
     } else {
-      LogSoftmaxFunctor<platform::CUDADeviceContext, T>()(
-          context.template device_context<platform::CUDADeviceContext>(), x,
-          out, axis);
+      // inner_size != 1
+      uint32_t shared_mem;
+      dim3 grid;
+      dim3 block;
+
+      ComputeLaunchConfigure<MPDType>(
+          &LaunchLogSoftmaxForwardNotLastAxis<T, MPDType>, outer_size,
+          dim_size, inner_size, grid, block, shared_mem, num_sm);
+
+      LaunchLogSoftmaxForwardNotLastAxis<T, MPDType>
+          <<<grid, block, shared_mem, stream>>>(output_data, input_data,
+                                                outer_size, dim_size,
+                                                inner_size);
     }
   }
 };
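
Note on this launch path: ComputeLaunchConfigure receives the kernel's function pointer so that cudaOccupancyMaxActiveBlocksPerMultiprocessor (or its HIP counterpart) can account for that specific kernel's register and shared-memory footprint; the per-SM result, scaled by GetSMCount(), then bounds the grid chosen by GetGridSize.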
@@ -209,8 +402,8 @@ __global__ void ComputeLogSoftmaxBackwardInWarp(const T *output,
       grad_output_register[iter] = static_cast<AccT>(
           grad_output[batch_id * element_count + element_index]);
     } else {
-      output_register[iter] = AccT(0);
-      grad_output_register[iter] = AccT(0);
+      output_register[iter] = static_cast<AccT>(0);
+      grad_output_register[iter] = static_cast<AccT>(0);
     }
   }
 
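For context on the backward kernel touched in the hunk above: ComputeLogSoftmaxBackwardInWarp implements the standard log-softmax gradient (a known identity; the full kernel body is outside this diff),

\[
\frac{\partial L}{\partial x_i}
  = \frac{\partial L}{\partial y_i}
  - e^{y_i} \sum_{j} \frac{\partial L}{\partial y_j},
\qquad y = \operatorname{log\_softmax}(x),
\]

which is why the kernel loads both the saved output y and the incoming gradient into registers before its reductions.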