
Commit e0ba1bf

NKNaN authored and co63oc committed
[Hackathon 6th No.29] Feature enhancement for paddle.nn.functional.kl_div (PaddlePaddle#63860)
* update kldiv_loss
* update xpu
* fix xpu test
* add docs code example and add test case
* fix code example
* fix code example
* fix code example
1 parent dbec70d commit e0ba1bf

15 files changed

Lines changed: 206 additions & 47 deletions
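In short, the op gains a log_target flag so the label may be supplied either as probabilities (the old behavior) or already in log space. A minimal usage sketch, assuming the Python wrapper paddle.nn.functional.kl_div exposes the new flag as a keyword argument the way the op signatures below suggest; the tensor values are illustrative:

import paddle
import paddle.nn.functional as F

x = paddle.log(paddle.to_tensor([[0.2, 0.8], [0.5, 0.5]]))  # input log-probabilities
p = paddle.to_tensor([[0.3, 0.7], [0.4, 0.6]])              # target probabilities

loss = F.kl_div(x, p, reduction='mean')                     # existing behavior
loss_log = F.kl_div(x, paddle.log(p), reduction='mean',
                    log_target=True)                        # new log-space path

# Both calls should agree up to floating-point error.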


paddle/phi/api/yaml/backward.yaml

Lines changed: 2 additions & 2 deletions
@@ -1299,8 +1299,8 @@
     func : inverse_grad

 - backward_op : kldiv_loss_grad
-  forward : kldiv_loss(Tensor x, Tensor label, str reduction="mean") -> Tensor(out)
-  args : (Tensor x, Tensor label, Tensor out_grad, str reduction)
+  forward : kldiv_loss(Tensor x, Tensor label, str reduction="mean", bool log_target = false) -> Tensor(out)
+  args : (Tensor x, Tensor label, Tensor out_grad, str reduction, bool log_target)
   output : Tensor(x_grad)
   infer_meta :
     func : UnchangedInferMeta

paddle/phi/api/yaml/op_version.yaml

Lines changed: 8 additions & 0 deletions
@@ -326,6 +326,14 @@
           comment : "The arg 'dispensable' of Input 'Scale' is changed: from 'False' to 'True'."
           default : "true"

+- op : kldiv_loss
+  version :
+    - checkpoint : Upgrade kldiv_loss, add a new attribute [log_target]
+      action :
+        - add_attr : log_target
+          comment : In order to specify whether 'label' is passed in log space.
+          default : "false"
+
 - op : lamb
   version :
     - checkpoint : Upgrade lamb, add two new outputs [Beta1PowOut] and [Beta2PowOut].

paddle/phi/api/yaml/ops.yaml

Lines changed: 1 addition & 1 deletion
@@ -1587,7 +1587,7 @@
   interfaces : paddle::dialect::InferSymbolicShapeInterface

 - op : kldiv_loss
-  args : (Tensor x, Tensor label, str reduction = "mean")
+  args : (Tensor x, Tensor label, str reduction = "mean", bool log_target = false)
   output : Tensor(out)
   infer_meta :
     func : KLDivInferMeta

paddle/phi/infermeta/binary.cc

Lines changed: 1 addition & 0 deletions
@@ -98,6 +98,7 @@ void AllValueCompareInferMeta(const MetaTensor& x,
 void KLDivInferMeta(const MetaTensor& x,
                     const MetaTensor& label,
                     const std::string& reduction,
+                    bool log_target,
                     MetaTensor* out,
                     MetaConfig config) {
   auto dim_x = x.dims();

paddle/phi/infermeta/binary.h

Lines changed: 1 addition & 0 deletions
@@ -42,6 +42,7 @@ void AllValueCompareInferMeta(const MetaTensor& x,
 void KLDivInferMeta(const MetaTensor& x,
                     const MetaTensor& label,
                     const std::string& reduction,
+                    bool log_target,
                     MetaTensor* out,
                     MetaConfig config = MetaConfig());

paddle/phi/kernels/impl/kldiv_loss_grad_kernel_impl.h

Lines changed: 13 additions & 6 deletions
@@ -23,13 +23,19 @@ namespace phi {
 using Array1 = Eigen::DSizes<int64_t, 1>;
 template <typename T>
 struct KLDivLossBackward {
-  HOSTDEVICE KLDivLossBackward() {}
+  bool log_target = false;
+
+  HOSTDEVICE KLDivLossBackward(bool logTarget) : log_target(logTarget) {}

   HOSTDEVICE T operator()(const T& target, const T& grad) const {
-    if (target <= 0) {
-      return 0;
+    if (log_target) {
+      return static_cast<T>(-1.) * std::exp(target) * grad;
     } else {
-      return static_cast<T>(-1.) * grad;
+      if (target <= 0) {
+        return 0;
+      } else {
+        return static_cast<T>(-1.) * target * grad;
+      }
     }
   }
 };
@@ -40,6 +46,7 @@ void KLDivLossGradKernel(const Context& dev_ctx,
                          const DenseTensor& label,
                          const DenseTensor& d_out,
                          const std::string& reduction,
+                         bool log_target,
                          DenseTensor* d_x) {
   auto& place = *dev_ctx.eigen_device();
   auto* target = &label;
@@ -58,9 +65,9 @@ void KLDivLossGradKernel(const Context& dev_ctx,
   auto loss_grad_t = phi::EigenVector<T>::Flatten(*loss_grad);

   auto loss_grad_expand = loss_grad_t.broadcast(Array1(expand));
-  auto grad_t = target_t * loss_grad_expand;
+  auto grad_t = loss_grad_expand;
   input_grad_t.device(place) =
-      target_t.binaryExpr(grad_t, KLDivLossBackward<T>());
+      target_t.binaryExpr(grad_t, KLDivLossBackward<T>(log_target));

   if ("mean" == reduction) {
     input_grad_t.device(place) = input_grad_t / static_cast<T>(numel);
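The functor change is easier to read outside Eigen. For a pointwise loss target * (log(target) - x) the gradient is d loss / d x = -target, and when the label is stored as t = log(target) the loss exp(t) * (t - x) gives d loss / d x = -exp(t). The multiplication by target that previously happened in grad_t = target_t * loss_grad_expand now lives inside the functor. A NumPy sketch of the same arithmetic (a reference mirror, not Paddle's API):

import numpy as np

def kldiv_loss_grad_ref(target, d_out, log_target=False, reduction="none"):
    if log_target:
        # label is log(p): d/dx of exp(t) * (t - x) is -exp(t)
        d_x = -np.exp(target) * d_out
    else:
        # label is p: d/dx of p * (log(p) - x) is -p, defined as 0 where p <= 0
        d_x = np.where(target > 0, -target * d_out, 0.0)
    if reduction == "mean":
        d_x = d_x / d_x.size  # mirrors the division by numel in the hunk above
    return d_x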

paddle/phi/kernels/impl/kldiv_loss_kernel_impl.h

Lines changed: 13 additions & 5 deletions
@@ -24,21 +24,29 @@ namespace phi {
 using Array1 = Eigen::DSizes<int64_t, 1>;
 template <typename T>
 struct KLDivLossForward {
-  HOSTDEVICE KLDivLossForward() {}
+  bool log_target = false;
+
+  HOSTDEVICE KLDivLossForward(bool logTarget) : log_target(logTarget) {}

   HOSTDEVICE T operator()(const T& target, const T& input) const {
-    if (target <= 0) {
-      return 0;
+    if (log_target) {
+      return std::exp(target) * (target - input);
     } else {
-      return target * (std::log(target) - input);
+      if (target <= 0) {
+        return 0;
+      } else {
+        return target * (std::log(target) - input);
+      }
     }
   }
 };
+
 template <typename T, typename Context>
 void KLDivLossKernel(const Context& dev_ctx,
                      const DenseTensor& x,
                      const DenseTensor& label,
                      const std::string& reduction,
+                     bool log_target,
                      DenseTensor* out) {
   auto& place = *(dev_ctx.eigen_device());
   auto* input = &x;
@@ -51,7 +59,7 @@ void KLDivLossKernel(const Context& dev_ctx,
   auto input_t = phi::EigenVector<T>::Flatten(*input);
   auto target_t = phi::EigenVector<T>::Flatten(*target);
   auto loss_t = phi::EigenVector<T>::Flatten(*loss);
-  auto output = target_t.binaryExpr(input_t, KLDivLossForward<T>());
+  auto output = target_t.binaryExpr(input_t, KLDivLossForward<T>(log_target));
   if ("none" == reduction) {
     loss_t.device(place) = output;
   } else if ("batchmean" == reduction) {
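Written out, the forward functor computes the pointwise KL term target * (log(target) - x) when the label holds probabilities (with the convention that the term is 0 where target <= 0), and exp(t) * (t - x) when it holds t = log(target); x is expected to be a log-probability in both cases. A NumPy sketch of the element-wise step and the two reductions visible in this hunk (a reference mirror, not Paddle's API; the kernel handles the remaining reductions in branches not shown here):

import numpy as np

def kldiv_loss_ref(x, target, log_target=False, reduction="none"):
    if log_target:
        loss = np.exp(target) * (target - x)   # target already in log space
    else:
        safe = np.clip(target, 1e-30, None)    # avoid log(0) warnings
        loss = np.where(target > 0, target * (np.log(safe) - x), 0.0)
    if reduction == "batchmean":
        return loss.sum() / x.shape[0]
    return loss  # "none"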

paddle/phi/kernels/kldiv_loss_grad_kernel.h

Lines changed: 1 addition & 0 deletions
@@ -24,5 +24,6 @@ void KLDivLossGradKernel(const Context& dev_ctx,
                          const DenseTensor& label,
                          const DenseTensor& d_out,
                          const std::string& reduction,
+                         bool log_target,
                          DenseTensor* d_x);
 }  // namespace phi

paddle/phi/kernels/kldiv_loss_kernel.h

Lines changed: 1 addition & 0 deletions
@@ -26,5 +26,6 @@ void KLDivLossKernel(const Context& dev_ctx,
                      const DenseTensor& x,
                      const DenseTensor& label,
                      const std::string& reduction,
+                     bool log_target,
                      DenseTensor* out);
 }  // namespace phi

paddle/phi/kernels/xpu/kldiv_loss_grad_kernel.cc

Lines changed: 28 additions & 6 deletions
@@ -25,6 +25,7 @@ void KLDivLossGradKernel(const Context& dev_ctx,
                          const DenseTensor& label,
                          const DenseTensor& d_out,
                          const std::string& reduction,
+                         bool log_target,
                          DenseTensor* d_x) {
   using XPUType = typename XPUTypeTrait<T>::Type;
   dev_ctx.template Alloc<T>(d_x);
@@ -33,12 +34,33 @@ void KLDivLossGradKernel(const Context& dev_ctx,
   }

   int r = XPU_SUCCESS;
-  r = xpu::kldiv_loss_grad(dev_ctx.x_context(),
-                           reinterpret_cast<const XPUType*>(label.data<T>()),
-                           reinterpret_cast<const XPUType*>(d_out.data<T>()),
-                           reinterpret_cast<XPUType*>(d_x->data<T>()),
-                           d_x->numel());
-  PADDLE_ENFORCE_XDNN_SUCCESS(r, "kldiv_loss_grad");
+
+  if (log_target) {
+    xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
+    XPUType* label_exp = RAII_GUARD.alloc_l3_or_gm<XPUType>(label.numel());
+    PADDLE_ENFORCE_XDNN_NOT_NULL(label_exp);
+
+    r = xpu::exp(dev_ctx.x_context(),
+                 reinterpret_cast<const XPUType*>(label.data<T>()),
+                 label_exp,
+                 label.numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "exp");
+
+    r = xpu::kldiv_loss_grad(dev_ctx.x_context(),
+                             reinterpret_cast<const XPUType*>(label_exp),
+                             reinterpret_cast<const XPUType*>(d_out.data<T>()),
+                             reinterpret_cast<XPUType*>(d_x->data<T>()),
+                             d_x->numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "kldiv_loss_grad");
+  } else {
+    r = xpu::kldiv_loss_grad(dev_ctx.x_context(),
+                             reinterpret_cast<const XPUType*>(label.data<T>()),
+                             reinterpret_cast<const XPUType*>(d_out.data<T>()),
+                             reinterpret_cast<XPUType*>(d_x->data<T>()),
+                             d_x->numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "kldiv_loss_grad");
+  }
+
   if ("none" != reduction) {
     PADDLE_THROW(phi::errors::Unavailable(
         "Not supported reduction [%s] in kldiv_loss_grad", reduction));
