@@ -22,27 +22,27 @@ using Tensor = framework::Tensor;
 namespace {
 template <typename T>
 __global__ void CrossEntropyGrad(T* logit_grad, const int64_t* labels,
-                                 const int n, const int d, const int remain,
-                                 const int ignore_index) {
-  CUDA_KERNEL_LOOP(index, n * remain) {
-    int idx_n = index / remain;
-    int idx_remain = index % remain;
-    int tmp = labels[index];
+                                 const int64_t n, const int64_t d,
+                                 const int64_t remain, const int ignore_index) {
+  CUDA_KERNEL_LOOP_TYPE(index, n * remain, int64_t) {
+    int64_t idx_n = index / remain;
+    int64_t idx_remain = index % remain;
+    int64_t tmp = labels[index];
     if (ignore_index != tmp) {
-      int idx = idx_n * d + tmp * remain + idx_remain;
+      int64_t idx = idx_n * d + tmp * remain + idx_remain;
       logit_grad[idx] -= static_cast<T>(1.);
     }
   }
 }

 template <typename T>
-__global__ void Scale(T* logit_grad, const T* loss_grad, const int num,
-                      const int d, const int remain, const int64_t* labels,
-                      const int ignore_index) {
-  CUDA_KERNEL_LOOP(index, num) {
-    int idx_n = index / d;
-    int idx_remain = index % remain;
-    int idx_lbl = idx_n * remain + idx_remain;
+__global__ void Scale(T* logit_grad, const T* loss_grad, const int64_t num,
+                      const int64_t d, const int64_t remain,
+                      const int64_t* labels, const int ignore_index) {
+  CUDA_KERNEL_LOOP_TYPE(index, num, int64_t) {
+    int64_t idx_n = index / d;
+    int64_t idx_remain = index % remain;
+    int64_t idx_lbl = idx_n * remain + idx_remain;
     if (labels[idx_lbl] == ignore_index) {
       logit_grad[index] = static_cast<T>(0.);
     } else {
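The old `CUDA_KERNEL_LOOP` iterates with a 32-bit `int`; `CUDA_KERNEL_LOOP_TYPE` takes the index type as a parameter so that `n * remain` and `num` can exceed `INT_MAX`. The macro itself is defined elsewhere in Paddle; a minimal sketch of such a grid-stride loop, which may differ from the real definition, is:

```cuda
// Sketch only: a grid-stride loop whose induction-variable type is a macro
// parameter, so a kernel can address more than 2^31 - 1 elements.
#define CUDA_KERNEL_LOOP_TYPE(i, num, index_type)                        \
  for (index_type i = static_cast<index_type>(blockIdx.x) * blockDim.x + \
                      threadIdx.x;                                       \
       i < (num); i += static_cast<index_type>(blockDim.x) * gridDim.x)

// The 32-bit variant would then simply be:
// #define CUDA_KERNEL_LOOP(i, num) CUDA_KERNEL_LOOP_TYPE(i, num, int)
```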
@@ -54,13 +54,14 @@ __global__ void Scale(T* logit_grad, const T* loss_grad, const int num,
 template <typename T>
 __global__ void SoftCrossEntropyGradientKernel(T* logit_grad,
                                                const T* loss_grad,
-                                               const T* labels, const int n,
-                                               const int d, const int remain) {
-  int ids = blockIdx.x * blockDim.x + threadIdx.x;
+                                               const T* labels, const int64_t n,
+                                               const int64_t d,
+                                               const int64_t remain) {
+  int64_t ids = blockIdx.x * blockDim.x + threadIdx.x;
   if (ids < n * d) {
-    int idx_n = ids / d;
-    int idx_remain = ids % remain;
-    int idx_loss = idx_n * remain + idx_remain;
+    int64_t idx_n = ids / d;
+    int64_t idx_remain = ids % remain;
+    int64_t idx_loss = idx_n * remain + idx_remain;
     logit_grad[ids] = loss_grad[idx_loss] * (logit_grad[ids] - labels[ids]);
   }
 }
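Every kernel in this file uses the same flattening of a tensor viewed as [n, axis_dim, remain] (with d = axis_dim * remain) into one linear index. A tiny standalone illustration of that decomposition, with made-up sizes:

```cuda
#include <cstdint>
#include <cstdio>

// Illustration only: split a flat index over a [n, axis_dim, remain] view
// into the coordinates used by the kernels above (d = axis_dim * remain).
__host__ __device__ inline void Decompose(int64_t idx, int64_t d,
                                          int64_t remain, int64_t* idx_n,
                                          int64_t* idx_axis,
                                          int64_t* idx_remain) {
  *idx_n = idx / d;                // which row
  *idx_axis = (idx % d) / remain;  // which class along the softmax axis
  *idx_remain = idx % remain;      // position in the trailing dimensions
}

int main() {
  // Example: axis_dim = 3 classes, remain = 4, so d = 12.
  int64_t n_i, a_i, r_i;
  Decompose(/*idx=*/17, /*d=*/12, /*remain=*/4, &n_i, &a_i, &r_i);
  // Flat index 17 -> row 1, class 1, remain slot 1; the matching entry of the
  // [n, remain]-shaped loss/labels view is 1 * 4 + 1 = 5.
  std::printf("%lld %lld %lld\n", (long long)n_i, (long long)a_i,
              (long long)r_i);
  return 0;
}
```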
@@ -132,19 +133,19 @@ using BlockReduceTempStorage = typename BlockReduce<T, BlockDim>::TempStorage;
 // This kernel is used to calculate the max element of each row
 template <typename T, int BlockDim>
 static __global__ void RowReductionForMax(const T* logits_data, T* max_data,
-                                          int d, int axis_dim) {
+                                          int64_t d, int axis_dim) {
   __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;

   // logits_data view as [n, axis_dim, remain]
   // max_data view as [n, 1, remain]
   // blockDim = n * remain, split blockIdx to idx_n and idx_remain
-  int remain = d / axis_dim;
-  int idx_n = blockIdx.x / remain;
-  int idx_remain = blockIdx.x % remain;
-  int beg_idx = idx_n * d + threadIdx.x * remain + idx_remain;
-  int end_idx = (idx_n + 1) * d;
+  int64_t remain = d / axis_dim;
+  int64_t idx_n = blockIdx.x / remain;
+  int64_t idx_remain = blockIdx.x % remain;
+  int64_t beg_idx = idx_n * d + threadIdx.x * remain + idx_remain;
+  int64_t end_idx = (idx_n + 1) * d;

-  int step = BlockDim * remain;
+  int64_t step = BlockDim * remain;
   T cur_max = logits_data[beg_idx];
   beg_idx += step;
   while (beg_idx < end_idx) {
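`BlockReduceTempStorage` is the shared-memory scratch type of the `BlockReduce` alias declared just above this hunk (cub's block-wide reduction in the actual file). For readers unfamiliar with the pattern, a minimal per-block max reduction with cub looks roughly like this; a sketch, not the operator's code:

```cuda
#include <cub/cub.cuh>

// Sketch: combine one running value per thread into a block-wide maximum,
// the same pattern RowReductionForMax relies on. Assumes len >= BLOCK_DIM,
// mirroring the "BlockDim <= axis_dim" requirement noted in the next hunks.
template <typename T, int BLOCK_DIM>
__global__ void BlockMaxDemo(const T* in, T* out, int64_t len) {
  using BlockReduce = cub::BlockReduce<T, BLOCK_DIM>;
  __shared__ typename BlockReduce::TempStorage temp_storage;

  // Each thread strides over its share of the data and keeps a running max.
  T cur_max = in[threadIdx.x];
  for (int64_t i = threadIdx.x + BLOCK_DIM; i < len; i += BLOCK_DIM) {
    if (in[i] > cur_max) cur_max = in[i];
  }
  // Reduce the per-thread maxima; only thread 0 receives the valid result.
  T block_max = BlockReduce(temp_storage).Reduce(cur_max, cub::Max());
  if (threadIdx.x == 0) *out = block_max;
}
```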
@@ -162,21 +163,21 @@ static __global__ void RowReductionForMax(const T* logits_data, T* max_data,
 // Make sure that BlockDim <= axis_dim
 template <typename T, int BlockDim, bool CalculateLogSoftmax = false>
 static __global__ void RowReductionForDiffMaxSum(const T* logits_data,
-                                                 T* max_data, T* softmax, int d,
-                                                 int axis_dim) {
+                                                 T* max_data, T* softmax,
+                                                 int64_t d, int axis_dim) {
   __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;

   // logits, softmax data view as [n, axis_dim, remain]
   // max_data view as [n, 1, remain]
   // blockDim = n * remain, split blockIdx to idx_n and idx_remain
-  int remain = d / axis_dim;
-  int idx_n = blockIdx.x / remain;
-  int idx_remain = blockIdx.x % remain;
-  int beg_idx = idx_n * d + threadIdx.x * remain + idx_remain;
-  int end_idx = (idx_n + 1) * d;
+  int64_t remain = d / axis_dim;
+  int64_t idx_n = blockIdx.x / remain;
+  int64_t idx_remain = blockIdx.x % remain;
+  int64_t beg_idx = idx_n * d + threadIdx.x * remain + idx_remain;
+  int64_t end_idx = (idx_n + 1) * d;

   auto block_max = max_data[blockIdx.x];
-  int step = BlockDim * remain;
+  int64_t step = BlockDim * remain;

   // In numeric stable mode softmax_with_loss, we calc loss with
   // tmp_i_j = x_i_j - max_i - logDiffMaxSum_i, instead of
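The comment cut off by this hunk refers to the standard log-sum-exp trick: subtracting the row maximum before exponentiating keeps every exponent non-positive, so the sum cannot overflow. In the comment's own notation (max_i is the row maximum, logDiffMaxSum_i the log of the shifted sum):

```latex
\log \operatorname{softmax}(x)_{ij}
  = x_{ij} - \log\sum_{k} e^{x_{ik}}
  = \bigl(x_{ij} - \max\nolimits_i\bigr)
    - \underbrace{\log\sum_{k} e^{x_{ik} - \max_i}}_{\text{logDiffMaxSum}_i},
\qquad \max\nolimits_i = \max_k x_{ik}.
```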
@@ -216,25 +217,25 @@ static __global__ void RowReductionForDiffMaxSum(const T* logits_data,
 // Make sure that BlockDim <= axis_dim
 template <typename T, int BlockDim>
 static __global__ void RowReductionForSoftmaxAndCrossEntropy(
-    const T* logits_data, const T* labels_data, T* loss_data, T* softmax, int d,
-    int axis_dim) {
+    const T* logits_data, const T* labels_data, T* loss_data, T* softmax,
+    int64_t d, int axis_dim) {
   __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;

   // logits, softmax, labels data view as [n, axis_dim, remain]
   // loss_data view as [n, 1, remain]
   // blockDim = n * remain, split blockIdx to idx_n and idx_remain
-  int remain = d / axis_dim;
-  int idx_n = blockIdx.x / remain;
-  int idx_remain = blockIdx.x % remain;
-  int beg_idx = idx_n * d + threadIdx.x * remain + idx_remain;
-  int end_idx = (idx_n + 1) * d;
+  int64_t remain = d / axis_dim;
+  int64_t idx_n = blockIdx.x / remain;
+  int64_t idx_remain = blockIdx.x % remain;
+  int64_t beg_idx = idx_n * d + threadIdx.x * remain + idx_remain;
+  int64_t end_idx = (idx_n + 1) * d;

   // log_diff_max_sum shares memory with loss
   auto block_log_diff_max_sum = loss_data[blockIdx.x];
   auto tmp = softmax[beg_idx] - block_log_diff_max_sum;
   softmax[beg_idx] = exp_on_device(tmp);
   auto loss = -labels_data[beg_idx] * tmp;
-  int step = BlockDim * remain;
+  int64_t step = BlockDim * remain;
   beg_idx += step;
   while (beg_idx < end_idx) {
     tmp = softmax[beg_idx] - block_log_diff_max_sum;
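Assuming, as the surrounding code arranges, that `softmax` holds x_ij - max_i on entry and that `loss_data` temporarily holds logDiffMaxSum_i (the "log_diff_max_sum shares memory with loss" comment), the loop above accumulates the soft-label cross entropy per (row, remain) slot:

```latex
\text{softmax}_{ij} = e^{\,x_{ij} - \max_i - \text{logDiffMaxSum}_i},
\qquad
\text{loss}_{i} = -\sum_{j=1}^{\text{axis\_dim}}
    \text{labels}_{ij}\,\bigl(x_{ij} - \max_i - \text{logDiffMaxSum}_i\bigr).
```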
@@ -251,21 +252,26 @@ template <typename T>
 struct HardLabelSoftmaxWithCrossEntropyFunctor {
  public:
   HardLabelSoftmaxWithCrossEntropyFunctor(const int64_t* labels, T* loss,
-                                          T* log_softmax, int d, int axis_dim)
+                                          T* log_softmax, int64_t d,
+                                          int axis_dim)
       : labels_(labels),
         loss_(loss),
         log_softmax_(log_softmax),
         d_(d),
         axis_dim_(axis_dim) {}

-  __device__ void operator()(int idx) const {
+  __device__ void operator()(int64_t idx) const {
     // logits view as [n, axis_dim, remain], where d = axis_dim * remain
-    int remain = d_ / axis_dim_;
-    int idx_n = idx / d_;
-    int idx_axis = (idx % d_) / remain;
-    int idx_remain = idx % remain;
+    int64_t remain = d_ / axis_dim_;
+    int64_t idx_n = idx / d_;
+    int64_t idx_axis = (idx % d_) / remain;
+    int64_t idx_remain = idx % remain;
     // labels, loss view as [n, remain]
-    int idx_lbl = idx_n * remain + idx_remain;
+    int64_t idx_lbl = idx_n * remain + idx_remain;
+    PADDLE_ENFORCE(labels_[idx_lbl] >= 0 && labels_[idx_lbl] < d_,
+                   "The value of label[%ld] expected >= 0 and < %ld,"
+                   " but got %ld. Please check input value.",
+                   idx_lbl, d_, labels_[idx_lbl]);
     // It also would ignore labels not in range(class_num).
     if (idx_axis != labels_[idx_lbl]) {
       log_softmax_[idx] = exp_on_device(log_softmax_[idx]);
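Since `operator()` now takes `int64_t`, whatever element-wise launcher applies this functor must also iterate with a 64-bit index (in the real operator that is presumably Paddle's ForRange helper, which is not shown in these hunks). A hypothetical launcher, for illustration only:

```cuda
// Sketch: apply an element-wise functor such as
// HardLabelSoftmaxWithCrossEntropyFunctor over all n * d logits with a
// 64-bit index. The kernel and its name are illustrative, not Paddle's.
template <typename Functor>
__global__ void ApplyElementwise(Functor functor, int64_t limit) {
  for (int64_t i =
           static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       i < limit; i += static_cast<int64_t>(blockDim.x) * gridDim.x) {
    functor(i);  // calls the functor's __device__ operator()(int64_t)
  }
}
```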
@@ -280,7 +286,7 @@ struct HardLabelSoftmaxWithCrossEntropyFunctor {
   const int64_t* labels_;
   T* loss_;
   T* log_softmax_;
-  int d_;
+  int64_t d_;
   int axis_dim_;
 };

@@ -289,7 +295,7 @@ struct HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx {
289295 public:
290296 HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx (const int64_t * labels,
291297 T* loss, T* log_softmax,
292- int d, int axis_dim,
298+ int64_t d, int axis_dim,
293299 int ignore_idx)
294300 : labels_(labels),
295301 loss_ (loss),
@@ -298,14 +304,14 @@ struct HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx {
         axis_dim_(axis_dim),
         ignore_idx_(ignore_idx) {}

-  __device__ void operator()(int idx) const {
+  __device__ void operator()(int64_t idx) const {
     // logits view as [n, axis_dim, remain], where d = axis_dim * remain
-    int remain = d_ / axis_dim_;
-    int idx_n = idx / d_;
-    int idx_axis = (idx % d_) / remain;
-    int idx_remain = idx % remain;
+    int64_t remain = d_ / axis_dim_;
+    int64_t idx_n = idx / d_;
+    int64_t idx_axis = (idx % d_) / remain;
+    int64_t idx_remain = idx % remain;
     // labels, loss view as [n, remain]
-    int idx_lbl = idx_n * remain + idx_remain;
+    int64_t idx_lbl = idx_n * remain + idx_remain;
     if (idx_axis != labels_[idx_lbl] || idx_axis == ignore_idx_) {
       log_softmax_[idx] = exp_on_device(log_softmax_[idx]);
     } else {
@@ -319,21 +325,21 @@ struct HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx {
   const int64_t* labels_;
   T* loss_;
   T* log_softmax_;
-  int d_;
+  int64_t d_;
   int axis_dim_;
   int ignore_idx_;
 };

 template <typename T>
 static void HardLabelSoftmaxWithCrossEntropy(
     const platform::CUDADeviceContext& ctx, const T* logits_data,
-    const int64_t* labels_data, T* loss_data, T* softmax_data, int n, int d,
-    int axis_dim, int ignore_idx) {
+    const int64_t* labels_data, T* loss_data, T* softmax_data, int64_t n,
+    int64_t d, int axis_dim, int ignore_idx) {
   constexpr int kMaxBlockDim = 512;
-  int block_dim = axis_dim >= kMaxBlockDim
-                      ? kMaxBlockDim
-                      : (1 << static_cast<int>(std::log2(axis_dim)));
-  int grid_dim = n * d / axis_dim;
+  int64_t block_dim = axis_dim >= kMaxBlockDim
+                          ? kMaxBlockDim
+                          : (1 << static_cast<int>(std::log2(axis_dim)));
+  int64_t grid_dim = n * d / axis_dim;
   auto stream = ctx.stream();

 #define CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim)    \
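The launcher rounds the block size down to a power of two no larger than axis_dim (capped at 512) and uses one block per (row, remain) slot, i.e. grid_dim = n * remain. A small host-side check of that arithmetic, with made-up values:

```cuda
#include <cassert>
#include <cmath>
#include <cstdint>

// Illustration of the launch-size arithmetic above; sizes are made up.
int main() {
  constexpr int kMaxBlockDim = 512;
  const int axis_dim = 100;                // number of classes
  const int64_t n = 3000000, d = 100 * 4;  // d = axis_dim * remain, remain = 4

  int64_t block_dim = axis_dim >= kMaxBlockDim
                          ? kMaxBlockDim
                          : (1 << static_cast<int>(std::log2(axis_dim)));
  int64_t grid_dim = n * d / axis_dim;  // == n * remain

  assert(block_dim == 64);    // largest power of two <= 100
  assert(grid_dim == n * 4);  // one block per (row, remain) slot
  return 0;
}
```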
@@ -372,16 +378,14 @@ static void HardLabelSoftmaxWithCrossEntropy(
 }

 template <typename T>
-static void SoftmaxWithCrossEntropyFusedKernel(const T* logits_data,
-                                               const T* labels_data,
-                                               T* softmax_data, T* loss_data,
-                                               int n, int d, int axis_dim,
-                                               cudaStream_t stream) {
+static void SoftmaxWithCrossEntropyFusedKernel(
+    const T* logits_data, const T* labels_data, T* softmax_data, T* loss_data,
+    int64_t n, int64_t d, int axis_dim, cudaStream_t stream) {
   constexpr int kMaxBlockDim = 512;
-  int block_dim = axis_dim >= kMaxBlockDim
-                      ? kMaxBlockDim
-                      : (1 << static_cast<int>(std::log2(axis_dim)));
-  int grid_dim = n * d / axis_dim;
+  int64_t block_dim = axis_dim >= kMaxBlockDim
+                          ? kMaxBlockDim
+                          : (1 << static_cast<int>(std::log2(axis_dim)));
+  int64_t grid_dim = n * d / axis_dim;

 #define CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim)               \
   case BlockDim:                                                             \
@@ -430,8 +434,8 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
     const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
     int axis_dim = logits->dims()[axis];

-    const int n = SizeToAxis(axis, logits->dims());
-    const int d = SizeFromAxis(axis, logits->dims());
+    const int64_t n = SizeToAxis(axis, logits->dims());
+    const int64_t d = SizeFromAxis(axis, logits->dims());

     auto* softmax_data = softmax->mutable_data<T>(context.GetPlace());
     auto* loss_data = loss->mutable_data<T>(context.GetPlace());
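`SizeToAxis` and `SizeFromAxis` come from the softmax op headers and are simply products of the dimensions before and from `axis`; storing the results in `int64_t` is what lets `n * d` exceed 2^31 - 1. A minimal sketch of their semantics, using a plain vector in place of `framework::DDim` (the `Sketch` names are mine):

```cuda
#include <cstdint>
#include <vector>

// Sketch of the helpers' semantics; the real ones operate on framework::DDim.
inline int64_t SizeToAxisSketch(int axis, const std::vector<int64_t>& dims) {
  int64_t size = 1;
  for (int i = 0; i < axis; ++i) size *= dims[i];  // dims before `axis`
  return size;
}

inline int64_t SizeFromAxisSketch(int axis, const std::vector<int64_t>& dims) {
  int64_t size = 1;
  for (size_t i = axis; i < dims.size(); ++i) size *= dims[i];  // dims from `axis` on
  return size;
}

// e.g. dims = {8192, 4096, 130000}, axis = 2:
//   n = 8192 * 4096 = 33554432, d = 130000, and n * d no longer fits in int.
```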
@@ -500,24 +504,24 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
     const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
     int axis_dim = logit_grad->dims()[axis];

-    const int n = SizeToAxis(axis, logit_grad->dims());
-    const int d = SizeFromAxis(axis, logit_grad->dims());
-    const int remain = d / axis_dim;
+    const int64_t n = SizeToAxis(axis, logit_grad->dims());
+    const int64_t d = SizeFromAxis(axis, logit_grad->dims());
+    const int64_t remain = d / axis_dim;

     int block = 512;
     auto stream = context.cuda_device_context().stream();
     auto ignore_index = context.Attr<int>("ignore_index");
     if (context.Attr<bool>("soft_label")) {
-      int grid = (n * d + block - 1) / block;
+      int64_t grid = (n * d + block - 1) / block;
       const T* label_data = labels->data<T>();
       SoftCrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>(
           logit_grad_data, loss_grad_data, label_data, n, d, remain);
     } else {
-      int grid = (n * remain + block - 1) / block;
+      int64_t grid = (n * remain + block - 1) / block;
       const int64_t* label_data = labels->data<int64_t>();
       CrossEntropyGrad<T><<<grid, block, 0, stream>>>(
           logit_grad_data, label_data, n, d, remain, ignore_index);
-      int num = n * d;
+      int64_t num = n * d;
       grid = (num + block - 1) / block;
       Scale<T><<<grid, block, 0, stream>>>(logit_grad_data, loss_grad_data, num,
                                            d, remain, label_data, ignore_index);
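This last hunk is where the overflow actually bites: with 32-bit arithmetic, `n * d` (and the grid computation built on it) wraps once the flattened logits exceed 2^31 - 1 elements. A short host-side illustration with made-up sizes:

```cuda
#include <cstdint>
#include <cstdio>

int main() {
  // Made-up sizes: n * d = 2^32 elements, more than a 32-bit int can hold.
  const int64_t n = 33554432, d = 128;
  const int block = 512;

  const int64_t num = n * d;                       // 4294967296; an int would overflow
  const int64_t grid = (num + block - 1) / block;  // 8388608 blocks

  std::printf("num=%lld grid=%lld\n", static_cast<long long>(num),
              static_cast<long long>(grid));
  return 0;
}
```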