Skip to content

Commit 91dfddf

Browse files
committed
Made some minor changes
1 parent c260064 commit 91dfddf

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

paddle/fluid/operators/log_softmax_op.cu

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -388,7 +388,7 @@ __global__ void ComputeLogSoftmaxBackwardInWarp(const T *output,
388 388
constexpr int warp_iter = near_greater_power_of_two / kernel_warp_size;
389 389
int batch_id = blockDim.y * blockIdx.x + threadIdx.y;
390 390

391 -
int thread_in_warp_idx = threadIdx.x % kernel_warp_size;
391 +
int thread_in_warp_idx = threadIdx.x;
392 392

393 393
// 1.read data from global memory to registers
394 394
AccT output_register[warp_iter];
@@ -464,13 +464,13 @@ class LogSoftmaxGradKernel<platform::CUDADeviceContext, T>
464 464
public:
465 465
void Compute(const framework::ExecutionContext &context) const override {
466 466
const auto *out = context.Input<framework::Tensor>("Out");
467 -
const auto *g_out =
467 +
const auto *d_out =
468 468
context.Input<framework::Tensor>(framework::GradVarName("Out"));
469 -
auto *g_x = context.Output<framework::Tensor>(framework::GradVarName("X"));
469 +
auto *d_x = context.Output<framework::Tensor>(framework::GradVarName("X"));
470 470

471 471
const auto *out_data = out->data<T>();
472 -
const auto *g_out_data = g_out->data<T>();
473 -
auto *g_x_data = g_x->mutable_data<T>(context.GetPlace());
472 +
const auto *d_out_data = d_out->data<T>();
473 +
auto *d_x_data = d_x->mutable_data<T>(context.GetPlace());
474 474

475 475
const int rank = out->dims().size();
476 476
const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
@@ -485,11 +485,11 @@ class LogSoftmaxGradKernel<platform::CUDADeviceContext, T>
485 485

486 486
if (inner_size == 1 && dim_size <= 1024 && dim_size * sizeof(T) <= 4096) {
487 487
LaunchSoftmaxBackwardForLastAxis<T, MPDType>(
488 -
g_x_data, g_out_data, out_data, dim_size, outer_size, stream);
488 +
d_x_data, d_out_data, out_data, dim_size, outer_size, stream);
489 489
} else {
490 490
LogSoftmaxGradFunctor<platform::CUDADeviceContext, T>()(
491 491
context.template device_context<platform::CUDADeviceContext>(), out,
492 -
g_out, g_x, axis);
492 +
d_out, d_x, axis);
493 493
}
494 494
}
495 495
};

0 commit comments

Comments
 (0)