@@ -25,6 +25,7 @@ void KLDivLossGradKernel(const Context& dev_ctx,
2525 const DenseTensor& label,
2626 const DenseTensor& d_out,
2727 const std::string& reduction,
28+ bool log_target,
2829 DenseTensor* d_x) {
2930 using XPUType = typename XPUTypeTrait<T>::Type;
3031 dev_ctx.template Alloc <T>(d_x);
@@ -33,12 +34,33 @@ void KLDivLossGradKernel(const Context& dev_ctx,
3334 }
3435
3536 int r = XPU_SUCCESS;
36- r = xpu::kldiv_loss_grad (dev_ctx.x_context (),
37- reinterpret_cast <const XPUType*>(label.data <T>()),
38- reinterpret_cast <const XPUType*>(d_out.data <T>()),
39- reinterpret_cast <XPUType*>(d_x->data <T>()),
40- d_x->numel ());
41- PADDLE_ENFORCE_XDNN_SUCCESS (r, " kldiv_loss_grad" );
37+
38+ if (log_target) {
39+ xpu::ctx_guard RAII_GUARD (dev_ctx.x_context ());
40+ XPUType* label_exp = RAII_GUARD.alloc_l3_or_gm <XPUType>(label.numel ());
41+ PADDLE_ENFORCE_XDNN_NOT_NULL (label_exp);
42+
43+ r = xpu::exp (dev_ctx.x_context (),
44+ reinterpret_cast <const XPUType*>(label.data <T>()),
45+ label_exp,
46+ label.numel ());
47+ PADDLE_ENFORCE_XDNN_SUCCESS (r, " exp" );
48+
49+ r = xpu::kldiv_loss_grad (dev_ctx.x_context (),
50+ reinterpret_cast <const XPUType*>(label_exp),
51+ reinterpret_cast <const XPUType*>(d_out.data <T>()),
52+ reinterpret_cast <XPUType*>(d_x->data <T>()),
53+ d_x->numel ());
54+ PADDLE_ENFORCE_XDNN_SUCCESS (r, " kldiv_loss_grad" );
55+ } else {
56+ r = xpu::kldiv_loss_grad (dev_ctx.x_context (),
57+ reinterpret_cast <const XPUType*>(label.data <T>()),
58+ reinterpret_cast <const XPUType*>(d_out.data <T>()),
59+ reinterpret_cast <XPUType*>(d_x->data <T>()),
60+ d_x->numel ());
61+ PADDLE_ENFORCE_XDNN_SUCCESS (r, " kldiv_loss_grad" );
62+ }
63+
4264 if (" none" != reduction) {
4365 PADDLE_THROW (phi::errors::Unavailable (
4466 " Not supported reduction [%s] in kldiv_loss_grad" , reduction));
0 commit comments