4 changes: 2 additions & 2 deletions paddle/phi/api/yaml/backward.yaml
@@ -1299,8 +1299,8 @@
func : inverse_grad

- backward_op : kldiv_loss_grad
forward : kldiv_loss(Tensor x, Tensor label, str reduction="mean") -> Tensor(out)
args : (Tensor x, Tensor label, Tensor out_grad, str reduction)
forward : kldiv_loss(Tensor x, Tensor label, str reduction="mean", bool log_target = false) -> Tensor(out)
args : (Tensor x, Tensor label, Tensor out_grad, str reduction, bool log_target)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
8 changes: 8 additions & 0 deletions paddle/phi/api/yaml/op_version.yaml
@@ -326,6 +326,14 @@
comment : "The arg 'dispensable' of Input 'Scale' is changed: from 'False' to 'True'."
default : "true"

- op : kldiv_loss
version :
- checkpoint : Upgrade kldiv_loss, add a new attribute [log_target]
action :
- add_attr : log_target
comment : In order to specify whether 'label' is passed in log space.
default : "false"

- op : lamb
version :
- checkpoint : Upgrade lamb, add two new outputs [Beta1PowOut] and [Beta2PowOut].
2 changes: 1 addition & 1 deletion paddle/phi/api/yaml/ops.yaml
@@ -1566,7 +1566,7 @@
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : kldiv_loss
args : (Tensor x, Tensor label, str reduction = "mean")
args : (Tensor x, Tensor label, str reduction = "mean", bool log_target = false)
output : Tensor(out)
infer_meta :
func : KLDivInferMeta
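For context, the new attribute only changes the element-wise formula. With x the log-probability input and y the label, the loss and its gradient with respect to x are as follows (the loss forms match the docstring formulas added later in this diff; the gradients follow by differentiation and match the reworked backward functor below). In the non-log branch both kernels additionally return 0 where y <= 0, as in the existing implementation.

```latex
l(x, y) =
\begin{cases}
  y\,(\log y - x), & \text{log\_target} = \text{false},\\
  e^{y}\,(y - x),  & \text{log\_target} = \text{true},
\end{cases}
\qquad
\frac{\partial l}{\partial x} =
\begin{cases}
  -\,y,     & \text{log\_target} = \text{false},\\
  -\,e^{y}, & \text{log\_target} = \text{true}.
\end{cases}
```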
1 change: 1 addition & 0 deletions paddle/phi/infermeta/binary.cc
@@ -98,6 +98,7 @@ void AllValueCompareInferMeta(const MetaTensor& x,
void KLDivInferMeta(const MetaTensor& x,
const MetaTensor& label,
const std::string& reduction,
bool log_target,
MetaTensor* out,
MetaConfig config) {
auto dim_x = x.dims();
1 change: 1 addition & 0 deletions paddle/phi/infermeta/binary.h
@@ -42,6 +42,7 @@ void AllValueCompareInferMeta(const MetaTensor& x,
void KLDivInferMeta(const MetaTensor& x,
const MetaTensor& label,
const std::string& reduction,
bool log_target,
MetaTensor* out,
MetaConfig config = MetaConfig());

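KLDivInferMeta itself only gains the extra bool; the shape rule is unchanged. A rough Python sketch of that rule, assumed from the reduction semantics and the docstring examples further down (kldiv_out_shape is an illustrative name, not a Paddle API):

```python
def kldiv_out_shape(x_shape, label_shape, reduction="mean"):
    """Hypothetical helper mirroring the expected output shape of kldiv_loss."""
    # x and label must agree element-wise.
    assert list(x_shape) == list(label_shape), "x and label must have the same shape"
    # 'none' keeps the element-wise shape; any reduction collapses to a scalar.
    return list(x_shape) if reduction == "none" else []

print(kldiv_out_shape([5, 20], [5, 20], "none"))  # [5, 20]
print(kldiv_out_shape([5, 20], [5, 20], "mean"))  # []
```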
19 changes: 13 additions & 6 deletions paddle/phi/kernels/impl/kldiv_loss_grad_kernel_impl.h
@@ -23,13 +23,19 @@ namespace phi {
using Array1 = Eigen::DSizes<int64_t, 1>;
template <typename T>
struct KLDivLossBackward {
HOSTDEVICE KLDivLossBackward() {}
bool log_target = false;

HOSTDEVICE KLDivLossBackward(bool logTarget) : log_target(logTarget) {}

HOSTDEVICE T operator()(const T& target, const T& grad) const {
if (target <= 0) {
return 0;
if (log_target) {
return static_cast<T>(-1.) * std::exp(target) * grad;
} else {
return static_cast<T>(-1.) * grad;
if (target <= 0) {
return 0;
} else {
return static_cast<T>(-1.) * target * grad;
}
}
}
};
@@ -40,6 +46,7 @@ void KLDivLossGradKernel(const Context& dev_ctx,
const DenseTensor& label,
const DenseTensor& d_out,
const std::string& reduction,
bool log_target,
DenseTensor* d_x) {
auto& place = *dev_ctx.eigen_device();
auto* target = &label;
@@ -58,9 +65,9 @@ void KLDivLossGradKernel(const Context& dev_ctx,
auto loss_grad_t = phi::EigenVector<T>::Flatten(*loss_grad);

auto loss_grad_expand = loss_grad_t.broadcast(Array1(expand));
auto grad_t = target_t * loss_grad_expand;
auto grad_t = loss_grad_expand;
input_grad_t.device(place) =
target_t.binaryExpr(grad_t, KLDivLossBackward<T>());
target_t.binaryExpr(grad_t, KLDivLossBackward<T>(log_target));

if ("mean" == reduction) {
input_grad_t.device(place) = input_grad_t / static_cast<T>(numel);
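The reworked backward functor folds the target factor into the per-element gradient (previously grad_t was pre-multiplied by target_t before binaryExpr). A minimal NumPy sketch of what it computes for reduction='none'; the 'mean' case is additionally divided by the element count, as in the kernel above:

```python
import numpy as np

def kldiv_loss_grad_ref(target, d_out, log_target=False):
    """Reference element-wise gradient w.r.t. x, mirroring KLDivLossBackward."""
    if log_target:
        return -np.exp(target) * d_out                 # target stored in log space
    return np.where(target > 0, -target * d_out, 0.0)  # clamp where target <= 0

target = np.array([0.2, 0.0, 0.5])
d_out = np.ones_like(target)
print(kldiv_loss_grad_ref(target, d_out))                        # [-0.2  0.  -0.5]
print(kldiv_loss_grad_ref(np.log(target + 1e-12), d_out, True))  # approximately the same
```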
18 changes: 13 additions & 5 deletions paddle/phi/kernels/impl/kldiv_loss_kernel_impl.h
@@ -24,21 +24,29 @@ namespace phi {
using Array1 = Eigen::DSizes<int64_t, 1>;
template <typename T>
struct KLDivLossForward {
HOSTDEVICE KLDivLossForward() {}
bool log_target = false;

HOSTDEVICE KLDivLossForward(bool logTarget) : log_target(logTarget) {}

HOSTDEVICE T operator()(const T& target, const T& input) const {
if (target <= 0) {
return 0;
if (log_target) {
return std::exp(target) * (target - input);
} else {
return target * (std::log(target) - input);
if (target <= 0) {
return 0;
} else {
return target * (std::log(target) - input);
}
}
}
};

template <typename T, typename Context>
void KLDivLossKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& label,
const std::string& reduction,
bool log_target,
DenseTensor* out) {
auto& place = *(dev_ctx.eigen_device());
auto* input = &x;
@@ -51,7 +59,7 @@ void KLDivLossKernel(const Context& dev_ctx,
auto input_t = phi::EigenVector<T>::Flatten(*input);
auto target_t = phi::EigenVector<T>::Flatten(*target);
auto loss_t = phi::EigenVector<T>::Flatten(*loss);
auto output = target_t.binaryExpr(input_t, KLDivLossForward<T>());
auto output = target_t.binaryExpr(input_t, KLDivLossForward<T>(log_target));
if ("none" == reduction) {
loss_t.device(place) = output;
} else if ("batchmean" == reduction) {
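A matching NumPy sketch of the forward path, including the reductions the kernel applies afterwards. This is a reference implementation of the formulas above, not a port of the Eigen code; the batchmean branch follows the usual sum-over-batch-size convention:

```python
import numpy as np

def kldiv_loss_ref(x, label, reduction="mean", log_target=False):
    """Reference KL divergence loss mirroring KLDivLossForward plus reduction."""
    if log_target:
        loss = np.exp(label) * (label - x)        # label given in log space
    else:
        safe = np.where(label > 0, label, 1.0)    # avoid log(0); masked out below
        loss = np.where(label > 0, label * (np.log(safe) - x), 0.0)
    if reduction == "none":
        return loss
    if reduction == "sum":
        return loss.sum()
    if reduction == "batchmean":
        return loss.sum() / x.shape[0]
    return loss.mean()                            # "mean"

x = np.log(np.full((2, 3), 1.0 / 3.0))            # log-probabilities
label = np.array([[0.2, 0.3, 0.5], [0.1, 0.0, 0.9]])
print(np.allclose(kldiv_loss_ref(x, label),
                  kldiv_loss_ref(x, np.log(label + 1e-12), log_target=True)))  # True
```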
1 change: 1 addition & 0 deletions paddle/phi/kernels/kldiv_loss_grad_kernel.h
@@ -24,5 +24,6 @@ void KLDivLossGradKernel(const Context& dev_ctx,
const DenseTensor& label,
const DenseTensor& d_out,
const std::string& reduction,
bool log_target,
DenseTensor* d_x);
} // namespace phi
1 change: 1 addition & 0 deletions paddle/phi/kernels/kldiv_loss_kernel.h
@@ -26,5 +26,6 @@ void KLDivLossKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& label,
const std::string& reduction,
bool log_target,
DenseTensor* out);
} // namespace phi
34 changes: 28 additions & 6 deletions paddle/phi/kernels/xpu/kldiv_loss_grad_kernel.cc
@@ -25,6 +25,7 @@ void KLDivLossGradKernel(const Context& dev_ctx,
const DenseTensor& label,
const DenseTensor& d_out,
const std::string& reduction,
bool log_target,
DenseTensor* d_x) {
using XPUType = typename XPUTypeTrait<T>::Type;
dev_ctx.template Alloc<T>(d_x);
@@ -33,12 +34,33 @@
}

int r = XPU_SUCCESS;
r = xpu::kldiv_loss_grad(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(label.data<T>()),
reinterpret_cast<const XPUType*>(d_out.data<T>()),
reinterpret_cast<XPUType*>(d_x->data<T>()),
d_x->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "kldiv_loss_grad");

if (log_target) {
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
XPUType* label_exp = RAII_GUARD.alloc_l3_or_gm<XPUType>(label.numel());
PADDLE_ENFORCE_XDNN_NOT_NULL(label_exp);

r = xpu::exp(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(label.data<T>()),
label_exp,
label.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "exp");

r = xpu::kldiv_loss_grad(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(label_exp),
reinterpret_cast<const XPUType*>(d_out.data<T>()),
reinterpret_cast<XPUType*>(d_x->data<T>()),
d_x->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "kldiv_loss_grad");
} else {
r = xpu::kldiv_loss_grad(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(label.data<T>()),
reinterpret_cast<const XPUType*>(d_out.data<T>()),
reinterpret_cast<XPUType*>(d_x->data<T>()),
d_x->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "kldiv_loss_grad");
}

if ("none" != reduction) {
PADDLE_THROW(phi::errors::Unavailable(
"Not supported reduction [%s] in kldiv_loss_grad", reduction));
34 changes: 28 additions & 6 deletions paddle/phi/kernels/xpu/kldiv_loss_kernel.cc
@@ -24,6 +24,7 @@ void KLDivLossKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& label,
const std::string& reduction,
bool log_target,
DenseTensor* out) {
using XPUType = typename XPUTypeTrait<T>::Type;
dev_ctx.template Alloc<T>(out);
@@ -32,12 +33,33 @@
}

int r = XPU_SUCCESS;
r = xpu::kldiv_loss(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()),
reinterpret_cast<const XPUType*>(label.data<T>()),
reinterpret_cast<XPUType*>(out->data<T>()),
out->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "kldiv_loss");

if (log_target) {
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
XPUType* label_exp = RAII_GUARD.alloc_l3_or_gm<XPUType>(label.numel());
PADDLE_ENFORCE_XDNN_NOT_NULL(label_exp);

r = xpu::exp(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(label.data<T>()),
label_exp,
label.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "exp");

r = xpu::kldiv_loss(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()),
reinterpret_cast<const XPUType*>(label_exp),
reinterpret_cast<XPUType*>(out->data<T>()),
out->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "kldiv_loss");
} else {
r = xpu::kldiv_loss(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()),
reinterpret_cast<const XPUType*>(label.data<T>()),
reinterpret_cast<XPUType*>(out->data<T>()),
out->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "kldiv_loss");
}

if ("none" != reduction) {
PADDLE_THROW(phi::errors::Unavailable(
"Not supported reduction [%s] in kldiv_loss", reduction));
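Both XPU kernels handle the flag the same way: the XDNN primitives (xpu::kldiv_loss, xpu::kldiv_loss_grad) take no log_target argument, so when log_target is true the label is first exponentiated into a scratch buffer and the existing primitives are reused unchanged. The identity being relied on, sketched in NumPy:

```python
import numpy as np

log_label = np.log(np.array([0.2, 0.3, 0.5]))  # label already in log space
x = np.log(np.full(3, 1.0 / 3.0))

lhs = np.exp(log_label) * (log_label - x)      # log_target=True semantics
rhs_label = np.exp(log_label)                  # what the kernel precomputes
rhs = rhs_label * (np.log(rhs_label) - x)      # ordinary kldiv_loss on exp(label)
print(np.allclose(lhs, rhs))                   # True
```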
21 changes: 18 additions & 3 deletions python/paddle/nn/functional/loss.py
@@ -1618,16 +1618,22 @@ def poisson_nll_loss(
return loss_out


def kl_div(input, label, reduction='mean', name=None):
def kl_div(input, label, reduction='mean', log_target=False, name=None):
r"""
Calculate the Kullback-Leibler divergence loss
between Input(X) and Input(Target). Notes that Input(X) is the
log-probability and Input(Target) is the probability.

KL divergence loss is calculated as follows:

If `log_target` is False:

$$l(x, y) = y * (\log(y) - x)$$

If `log_target` is True:

$$l(x, y) = \exp(y) * (y - x)$$

Here :math:`x` is input and :math:`y` is label.

If `reduction` is ``'none'``, the output loss is the same shape as the input, and the loss at each point is calculated separately. There is no reduction to the result.
@@ -1649,6 +1655,7 @@ def kl_div(input, label, reduction='mean', name=None):
if `reduction` is ``'sum'``, the reduced sum loss is returned;
if `reduction` is ``'none'``, no reduction will be applied.
Default is ``'mean'``.
log_target (bool, optional): Indicate whether `label` is passed in log space. Default is False.
Contributor: The example code below could also include a case with log_target=True.

Contributor Author: Updated.

name(str, optional): Name for the operation (optional, default is None). For more information,
please refer to :ref:`api_guide_Name`.

@@ -1689,6 +1696,14 @@ def kl_div(input, label, reduction='mean', name=None):
>>> print(pred_loss.shape)
[5, 20]

>>> # if label is in the log space, set log_target = True
>>> log_target = paddle.log(target)
>>> pred_loss_1 = F.kl_div(x, target, reduction='none')
>>> pred_loss_2 = F.kl_div(x, log_target, reduction='none', log_target=True)
>>> print(paddle.equal_all(pred_loss_1, pred_loss_2))
Tensor(shape=[], dtype=bool, place=Place(cpu), stop_gradient=True,
True)

"""
# ugly type promotion
if (
@@ -1703,7 +1718,7 @@ def kl_div(input, label, reduction='mean', name=None):
label = paddle.cast(label, 'float64')

if in_dynamic_or_pir_mode():
out = _C_ops.kldiv_loss(input, label, 'none')
out = _C_ops.kldiv_loss(input, label, 'none', log_target)
if reduction == 'mean':
out = paddle.mean(out)
elif reduction == 'sum':
@@ -1729,7 +1744,7 @@ def kl_div(input, label, reduction='mean', name=None):
type='kldiv_loss',
inputs={'X': input, 'Target': label},
outputs={'Loss': loss},
attrs={'reduction': 'none'},
attrs={'reduction': 'none', 'log_target': log_target},
)

if reduction == 'mean':
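Note that the wrapper always invokes the kernel with reduction='none' and performs the reduction in Python. A condensed, hypothetical sketch of the dynamic-mode branch (kl_div_sketch is an illustrative name; the batchmean handling shown here is assumed from the existing implementation, which is truncated above):

```python
import paddle
from paddle import _C_ops

def kl_div_sketch(input, label, reduction='mean', log_target=False):
    out = _C_ops.kldiv_loss(input, label, 'none', log_target)  # element-wise loss
    if reduction == 'mean':
        return paddle.mean(out)
    if reduction == 'sum':
        return paddle.sum(out)
    if reduction == 'batchmean':
        return paddle.sum(out) / input.shape[0]  # assumed: sum over the batch dim
    return out                                   # 'none'
```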
21 changes: 19 additions & 2 deletions python/paddle/nn/layer/loss.py
@@ -1034,8 +1034,14 @@ class KLDivLoss(Layer):

KL divergence loss is calculated as follows:

If `log_target` is False:

$$l(x, y) = y * (\log(y) - x)$$

If `log_target` is True:

$$l(x, y) = \exp(y) * (y - x)$$

Here :math:`x` is input and :math:`y` is label.

If `reduction` is ``'none'``, the output loss is the same shape as the input, and the loss at each point is calculated separately. There is no reduction to the result.
@@ -1054,6 +1060,7 @@ class KLDivLoss(Layer):
if `reduction` is ``'sum'``, the reduced sum loss is returned;
if `reduction` is ``'none'``, no reduction will be applied.
Default is ``'mean'``.
log_target (bool, optional): Indicate whether `label` is passed in log space. Default is False.
Contributor: The example code below could also include a case with log_target=True.

Contributor Author: Updated.


Shape:

@@ -1097,14 +1104,24 @@ class KLDivLoss(Layer):
>>> print(pred_loss.shape)
[5, 20]

>>> # if label is in the log space, set log_target = True
>>> log_target = paddle.log(target)
>>> kldiv_criterion_1 = nn.KLDivLoss(reduction='none')
>>> kldiv_criterion_2 = nn.KLDivLoss(reduction='none', log_target=True)
>>> pred_loss_1 = kldiv_criterion_1(x, target)
>>> pred_loss_2 = kldiv_criterion_2(x, log_target)
>>> print(paddle.equal_all(pred_loss_1, pred_loss_2))
Tensor(shape=[], dtype=bool, place=Place(cpu), stop_gradient=True,
True)
"""

def __init__(self, reduction='mean'):
def __init__(self, reduction='mean', log_target=False):
super().__init__()
self.reduction = reduction
self.log_target = log_target

def forward(self, input, label):
out = F.kl_div(input, label, self.reduction)
out = F.kl_div(input, label, self.reduction, self.log_target)
return out

