Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 9 additions & 26 deletions paddle/phi/kernels/gpu/cross_entropy_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -713,14 +713,16 @@ template <typename T>
static void SoftmaxWithCrossEntropySoftLabel(const GPUContext& dev_ctx,
const int rank,
const int axis,
const T* logits_data,
const DenseTensor& logits,
const T* labels_data,
T* softmax_data,
DenseTensor* softmax,
T* loss_data,
int N,
int dim,
int D) {
constexpr int kMaxBlockDim = 512;
auto* logits_data = logits.data<T>();
auto* softmax_data = softmax->data<T>();
int64_t block_dim = dim >= kMaxBlockDim
? kMaxBlockDim
: (1 << static_cast<int>(std::log2(dim)));
Expand Down Expand Up @@ -762,13 +764,7 @@ static void SoftmaxWithCrossEntropySoftLabel(const GPUContext& dev_ctx,
GPUDNNDataLayout layout = GPUDNNDataLayout::kNCHW;
#ifdef PADDLE_WITH_HIP
miopenTensorDescriptor_t descp = desc.descriptor<T>(layout, tensor_dims);
#else
cudnnTensorDescriptor_t descp = desc.descriptor<T>(layout, tensor_dims);
#endif

auto handle = dev_ctx.cudnn_handle();

#ifdef PADDLE_WITH_HIP
auto mode = axis == rank - 1 ? MIOPEN_SOFTMAX_MODE_INSTANCE
: MIOPEN_SOFTMAX_MODE_CHANNEL;
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::miopenSoftmaxForward_V2(
Expand All @@ -782,18 +778,8 @@ static void SoftmaxWithCrossEntropySoftLabel(const GPUContext& dev_ctx,
MIOPEN_SOFTMAX_LOG,
mode));
#else
auto mode = axis == rank - 1 ? CUDNN_SOFTMAX_MODE_INSTANCE
: CUDNN_SOFTMAX_MODE_CHANNEL;
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnSoftmaxForward(
handle,
CUDNN_SOFTMAX_LOG,
mode,
phi::backends::gpu::CudnnDataType<T>::kOne(),
descp,
logits_data,
phi::backends::gpu::CudnnDataType<T>::kZero(),
descp,
softmax_data));
SoftmaxForwardCUDAKernelDriver<T, true>(dev_ctx, logits, axis, softmax);
softmax_data = softmax->data<T>();
#endif

const int kDimLog2 = static_cast<int>(Log2Ceil(dim));
Expand Down Expand Up @@ -1170,7 +1156,7 @@ static void SoftmaxWithCrossEntropyHardLabel(const GPUContext& dev_ctx,
VLOG(7) << "rank=" << rank << ", axis = " << axis << ", N = " << N
<< ", dim = " << dim << ", D = " << D;
auto* logits_data = logits.data<T>();
auto* softmax_data = dev_ctx.template Alloc<T>(softmax);
auto* softmax_data = softmax->data<T>();
auto stream = dev_ctx.stream();
constexpr int max_dim = 320;
if (D == 1) {
Expand Down Expand Up @@ -1216,8 +1202,6 @@ static void SoftmaxWithCrossEntropyHardLabel(const GPUContext& dev_ctx,
MIOPEN_SOFTMAX_LOG,
mode));
#else
auto mode = axis == rank - 1 ? CUDNN_SOFTMAX_MODE_INSTANCE
: CUDNN_SOFTMAX_MODE_CHANNEL;
SoftmaxForwardCUDAKernelDriver<T, true>(dev_ctx, logits, axis, softmax);
softmax_data = softmax->data<T>();
#endif
Expand Down Expand Up @@ -1352,14 +1336,13 @@ void CrossEntropyWithSoftmaxCUDAKernel(const GPUContext& dev_ctx,
}

if (soft_label) {
auto* logits_data = logits.data<T>();
auto* labels_data = label.data<T>();
SoftmaxWithCrossEntropySoftLabel<T>(dev_ctx,
rank,
axis_v,
logits_data,
logits,
labels_data,
softmax_data,
softmax,
loss_data,
n,
axis_dim,
Expand Down
Loading