diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml
index 88922c3e42f1b8..7653440c5f1a1b 100644
--- a/paddle/phi/api/yaml/backward.yaml
+++ b/paddle/phi/api/yaml/backward.yaml
@@ -1299,8 +1299,8 @@
     func : inverse_grad

 - backward_op : kldiv_loss_grad
-  forward : kldiv_loss(Tensor x, Tensor label, str reduction="mean") -> Tensor(out)
-  args : (Tensor x, Tensor label, Tensor out_grad, str reduction)
+  forward : kldiv_loss(Tensor x, Tensor label, str reduction="mean", bool log_target = false) -> Tensor(out)
+  args : (Tensor x, Tensor label, Tensor out_grad, str reduction, bool log_target)
   output : Tensor(x_grad)
   infer_meta :
     func : UnchangedInferMeta
diff --git a/paddle/phi/api/yaml/op_version.yaml b/paddle/phi/api/yaml/op_version.yaml
index 3705e2949974df..7099460e4480a3 100644
--- a/paddle/phi/api/yaml/op_version.yaml
+++ b/paddle/phi/api/yaml/op_version.yaml
@@ -326,6 +326,14 @@
           comment : "The arg 'dispensable' of Input 'Scale' is changed: from 'False' to 'True'."
           default : "true"

+- op : kldiv_loss
+  version :
+    - checkpoint : Upgrade kldiv_loss, add a new attribute [log_target]
+      action :
+        - add_attr : log_target
+          comment : In order to specify whether 'label' is passed in log space.
+          default : "false"
+
 - op : lamb
   version :
     - checkpoint : Upgrade lamb, add two new outputs [Beta1PowOut] and [Beta2PowOut].
diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml
index c5c153b425eb45..e1b7b50532a928 100755
--- a/paddle/phi/api/yaml/ops.yaml
+++ b/paddle/phi/api/yaml/ops.yaml
@@ -1566,7 +1566,7 @@
     interfaces : paddle::dialect::InferSymbolicShapeInterface

 - op : kldiv_loss
-  args : (Tensor x, Tensor label, str reduction = "mean")
+  args : (Tensor x, Tensor label, str reduction = "mean", bool log_target = false)
   output : Tensor(out)
   infer_meta :
     func : KLDivInferMeta
diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc
index 93eedea914b216..0f909f4bbafc47 100644
--- a/paddle/phi/infermeta/binary.cc
+++ b/paddle/phi/infermeta/binary.cc
@@ -98,6 +98,7 @@ void AllValueCompareInferMeta(const MetaTensor& x,
 void KLDivInferMeta(const MetaTensor& x,
                     const MetaTensor& label,
                     const std::string& reduction,
+                    bool log_target,
                     MetaTensor* out,
                     MetaConfig config) {
   auto dim_x = x.dims();
diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h
index abddee824fe8df..0908bd00712c99 100644
--- a/paddle/phi/infermeta/binary.h
+++ b/paddle/phi/infermeta/binary.h
@@ -42,6 +42,7 @@ void AllValueCompareInferMeta(const MetaTensor& x,
 void KLDivInferMeta(const MetaTensor& x,
                     const MetaTensor& label,
                     const std::string& reduction,
+                    bool log_target,
                     MetaTensor* out,
                     MetaConfig config = MetaConfig());

diff --git a/paddle/phi/kernels/impl/kldiv_loss_grad_kernel_impl.h b/paddle/phi/kernels/impl/kldiv_loss_grad_kernel_impl.h
index a5e6c3d8fbfae1..dffb3f7b108559 100644
--- a/paddle/phi/kernels/impl/kldiv_loss_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/kldiv_loss_grad_kernel_impl.h
@@ -23,13 +23,19 @@ namespace phi {
 using Array1 = Eigen::DSizes<int64_t, 1>;
 template <typename T>
 struct KLDivLossBackward {
-  HOSTDEVICE KLDivLossBackward() {}
+  bool log_target = false;
+
+  HOSTDEVICE KLDivLossBackward(bool logTarget) : log_target(logTarget) {}

   HOSTDEVICE T operator()(const T& target, const T& grad) const {
-    if (target <= 0) {
-      return 0;
+    if (log_target) {
+      return static_cast<T>(-1.) * std::exp(target) * grad;
     } else {
-      return static_cast<T>(-1.) * grad;
+      if (target <= 0) {
+        return 0;
+      } else {
+        return static_cast<T>(-1.) * target * grad;
+      }
     }
   }
 };
@@ -40,6 +46,7 @@ void KLDivLossGradKernel(const Context& dev_ctx,
                          const DenseTensor& label,
                          const DenseTensor& d_out,
                          const std::string& reduction,
+                         bool log_target,
                          DenseTensor* d_x) {
   auto& place = *dev_ctx.eigen_device();
   auto* target = &label;
@@ -58,9 +65,9 @@ void KLDivLossGradKernel(const Context& dev_ctx,
   auto loss_grad_t = phi::EigenVector<T>::Flatten(*loss_grad);

   auto loss_grad_expand = loss_grad_t.broadcast(Array1(expand));
-  auto grad_t = target_t * loss_grad_expand;
+  auto grad_t = loss_grad_expand;
   input_grad_t.device(place) =
-      target_t.binaryExpr(grad_t, KLDivLossBackward<T>());
+      target_t.binaryExpr(grad_t, KLDivLossBackward<T>(log_target));

   if ("mean" == reduction) {
     input_grad_t.device(place) = input_grad_t / static_cast<T>(numel);
diff --git a/paddle/phi/kernels/impl/kldiv_loss_kernel_impl.h b/paddle/phi/kernels/impl/kldiv_loss_kernel_impl.h
index 4232e32597ed1c..6afbfe5d529786 100644
--- a/paddle/phi/kernels/impl/kldiv_loss_kernel_impl.h
+++ b/paddle/phi/kernels/impl/kldiv_loss_kernel_impl.h
@@ -24,21 +24,29 @@ namespace phi {
 using Array1 = Eigen::DSizes<int64_t, 1>;
 template <typename T>
 struct KLDivLossForward {
-  HOSTDEVICE KLDivLossForward() {}
+  bool log_target = false;
+
+  HOSTDEVICE KLDivLossForward(bool logTarget) : log_target(logTarget) {}

   HOSTDEVICE T operator()(const T& target, const T& input) const {
-    if (target <= 0) {
-      return 0;
+    if (log_target) {
+      return std::exp(target) * (target - input);
     } else {
-      return target * (std::log(target) - input);
+      if (target <= 0) {
+        return 0;
+      } else {
+        return target * (std::log(target) - input);
+      }
     }
   }
 };
+
 template <typename T, typename Context>
 void KLDivLossKernel(const Context& dev_ctx,
                      const DenseTensor& x,
                      const DenseTensor& label,
                      const std::string& reduction,
+                     bool log_target,
                      DenseTensor* out) {
   auto& place = *(dev_ctx.eigen_device());
   auto* input = &x;
@@ -51,7 +59,7 @@ void KLDivLossKernel(const Context& dev_ctx,
   auto input_t = phi::EigenVector<T>::Flatten(*input);
   auto target_t = phi::EigenVector<T>::Flatten(*target);
   auto loss_t = phi::EigenVector<T>::Flatten(*loss);
-  auto output = target_t.binaryExpr(input_t, KLDivLossForward<T>());
+  auto output = target_t.binaryExpr(input_t, KLDivLossForward<T>(log_target));
   if ("none" == reduction) {
     loss_t.device(place) = output;
   } else if ("batchmean" == reduction) {
diff --git a/paddle/phi/kernels/kldiv_loss_grad_kernel.h b/paddle/phi/kernels/kldiv_loss_grad_kernel.h
index 6e05c7992eb611..0a69bbeac68be0 100644
--- a/paddle/phi/kernels/kldiv_loss_grad_kernel.h
+++ b/paddle/phi/kernels/kldiv_loss_grad_kernel.h
@@ -24,5 +24,6 @@ void KLDivLossGradKernel(const Context& dev_ctx,
                          const DenseTensor& label,
                          const DenseTensor& d_out,
                          const std::string& reduction,
+                         bool log_target,
                          DenseTensor* d_x);
 }  // namespace phi
diff --git a/paddle/phi/kernels/kldiv_loss_kernel.h b/paddle/phi/kernels/kldiv_loss_kernel.h
index 7c6cc231c94806..78a05c1776e392 100644
--- a/paddle/phi/kernels/kldiv_loss_kernel.h
+++ b/paddle/phi/kernels/kldiv_loss_kernel.h
@@ -26,5 +26,6 @@ void KLDivLossKernel(const Context& dev_ctx,
                      const DenseTensor& x,
                      const DenseTensor& label,
                      const std::string& reduction,
+                     bool log_target,
                      DenseTensor* out);
 }  // namespace phi
diff --git a/paddle/phi/kernels/xpu/kldiv_loss_grad_kernel.cc b/paddle/phi/kernels/xpu/kldiv_loss_grad_kernel.cc
index 5d2c750a4dfa33..a81653f9f7aaf4 100644
--- a/paddle/phi/kernels/xpu/kldiv_loss_grad_kernel.cc
+++ b/paddle/phi/kernels/xpu/kldiv_loss_grad_kernel.cc
@@ -25,6 +25,7 @@ void KLDivLossGradKernel(const Context& dev_ctx,
                          const DenseTensor& label,
                          const DenseTensor& d_out,
                          const std::string& reduction,
+                         bool log_target,
                          DenseTensor* d_x) {
   using XPUType = typename XPUTypeTrait<T>::Type;
   dev_ctx.template Alloc<T>(d_x);
@@ -33,12 +34,33 @@ void KLDivLossGradKernel(const Context& dev_ctx,
   }

   int r = XPU_SUCCESS;
-  r = xpu::kldiv_loss_grad(dev_ctx.x_context(),
-                           reinterpret_cast<const XPUType*>(label.data<T>()),
-                           reinterpret_cast<const XPUType*>(d_out.data<T>()),
-                           reinterpret_cast<XPUType*>(d_x->data<T>()),
-                           d_x->numel());
-  PADDLE_ENFORCE_XDNN_SUCCESS(r, "kldiv_loss_grad");
+
+  if (log_target) {
+    xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
+    XPUType* label_exp = RAII_GUARD.alloc_l3_or_gm<XPUType>(label.numel());
+    PADDLE_ENFORCE_XDNN_NOT_NULL(label_exp);
+
+    r = xpu::exp(dev_ctx.x_context(),
+                 reinterpret_cast<const XPUType*>(label.data<T>()),
+                 label_exp,
+                 label.numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "exp");
+
+    r = xpu::kldiv_loss_grad(dev_ctx.x_context(),
+                             reinterpret_cast<const XPUType*>(label_exp),
+                             reinterpret_cast<const XPUType*>(d_out.data<T>()),
+                             reinterpret_cast<XPUType*>(d_x->data<T>()),
+                             d_x->numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "kldiv_loss_grad");
+  } else {
+    r = xpu::kldiv_loss_grad(dev_ctx.x_context(),
+                             reinterpret_cast<const XPUType*>(label.data<T>()),
+                             reinterpret_cast<const XPUType*>(d_out.data<T>()),
+                             reinterpret_cast<XPUType*>(d_x->data<T>()),
+                             d_x->numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "kldiv_loss_grad");
+  }
+
   if ("none" != reduction) {
     PADDLE_THROW(phi::errors::Unavailable(
         "Not supported reduction [%s] in kldiv_loss_grad", reduction));
diff --git a/paddle/phi/kernels/xpu/kldiv_loss_kernel.cc b/paddle/phi/kernels/xpu/kldiv_loss_kernel.cc
index 4ef917f008ab9e..2351d02cf4d1f3 100644
--- a/paddle/phi/kernels/xpu/kldiv_loss_kernel.cc
+++ b/paddle/phi/kernels/xpu/kldiv_loss_kernel.cc
@@ -24,6 +24,7 @@ void KLDivLossKernel(const Context& dev_ctx,
                      const DenseTensor& x,
                      const DenseTensor& label,
                      const std::string& reduction,
+                     bool log_target,
                      DenseTensor* out) {
   using XPUType = typename XPUTypeTrait<T>::Type;
   dev_ctx.template Alloc<T>(out);
@@ -32,12 +33,33 @@ void KLDivLossKernel(const Context& dev_ctx,
   }

   int r = XPU_SUCCESS;
-  r = xpu::kldiv_loss(dev_ctx.x_context(),
-                      reinterpret_cast<const XPUType*>(x.data<T>()),
-                      reinterpret_cast<const XPUType*>(label.data<T>()),
-                      reinterpret_cast<XPUType*>(out->data<T>()),
-                      out->numel());
-  PADDLE_ENFORCE_XDNN_SUCCESS(r, "kldiv_loss");
+
+  if (log_target) {
+    xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
+    XPUType* label_exp = RAII_GUARD.alloc_l3_or_gm<XPUType>(label.numel());
+    PADDLE_ENFORCE_XDNN_NOT_NULL(label_exp);
+
+    r = xpu::exp(dev_ctx.x_context(),
+                 reinterpret_cast<const XPUType*>(label.data<T>()),
+                 label_exp,
+                 label.numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "exp");
+
+    r = xpu::kldiv_loss(dev_ctx.x_context(),
+                        reinterpret_cast<const XPUType*>(x.data<T>()),
+                        reinterpret_cast<const XPUType*>(label_exp),
+                        reinterpret_cast<XPUType*>(out->data<T>()),
+                        out->numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "kldiv_loss");
+  } else {
+    r = xpu::kldiv_loss(dev_ctx.x_context(),
+                        reinterpret_cast<const XPUType*>(x.data<T>()),
+                        reinterpret_cast<const XPUType*>(label.data<T>()),
+                        reinterpret_cast<XPUType*>(out->data<T>()),
+                        out->numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "kldiv_loss");
+  }
+
   if ("none" != reduction) {
     PADDLE_THROW(phi::errors::Unavailable(
         "Not supported reduction [%s] in kldiv_loss", reduction));
diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py
index 3a44c20ace6fd6..2fe02ef0a2259d 100644
--- a/python/paddle/nn/functional/loss.py
+++ b/python/paddle/nn/functional/loss.py
@@ -1618,7 +1618,7 @@ def poisson_nll_loss(
     return loss_out


-def kl_div(input, label, reduction='mean', name=None):
+def kl_div(input, label, reduction='mean', log_target=False, name=None):
     r"""
     Calculate the Kullback-Leibler divergence loss
     between Input(X) and Input(Target). Notes that Input(X) is the
@@ -1626,8 +1626,14 @@ def kl_div(input, label, reduction='mean', name=None):

     KL divergence loss is calculated as follows:

+    If `log_target` is False:
+
     $$l(x, y) = y * (\log(y) - x)$$

+    If `log_target` is True:
+
+    $$l(x, y) = \exp(y) * (y - x)$$
+
     Here :math:`x` is input and :math:`y` is label.

     If `reduction` is ``'none'``, the output loss is the same shape as the input, and the loss at each point is calculated separately. There is no reduction to the result.
@@ -1649,6 +1655,7 @@ def kl_div(input, label, reduction='mean', name=None):
            if `reduction` is ``'sum'``, the reduced sum loss is returned;
            if `reduction` is ``'none'``, no reduction will be applied.
            Default is ``'mean'``.
+        log_target (bool, optional): Indicate whether `label` is passed in log space. Default is False.
         name(str, optional): Name for the operation (optional, default is None). For more information, please refer to
            :ref:`api_guide_Name`.

@@ -1689,6 +1696,15 @@ def kl_div(input, label, reduction='mean', name=None):
             >>> print(pred_loss.shape)
             [5, 20]

+            >>> # if label is in the log space, set log_target = True
+            >>> target = paddle.uniform(shape, min=0, max=10).astype('float32')
+            >>> log_target = paddle.log(target)
+            >>> pred_loss_1 = F.kl_div(x, target, reduction='none')
+            >>> pred_loss_2 = F.kl_div(x, log_target, reduction='none', log_target=True)
+            >>> print(paddle.allclose(pred_loss_1, pred_loss_2))
+            Tensor(shape=[], dtype=bool, place=Place(cpu), stop_gradient=True,
+                   True)
+
     """
     # ugly type promotion
     if (
@@ -1703,7 +1719,7 @@ def kl_div(input, label, reduction='mean', name=None):
         label = paddle.cast(label, 'float64')

     if in_dynamic_or_pir_mode():
-        out = _C_ops.kldiv_loss(input, label, 'none')
+        out = _C_ops.kldiv_loss(input, label, 'none', log_target)
         if reduction == 'mean':
             out = paddle.mean(out)
         elif reduction == 'sum':
@@ -1729,7 +1745,7 @@ def kl_div(input, label, reduction='mean', name=None):
         type='kldiv_loss',
         inputs={'X': input, 'Target': label},
         outputs={'Loss': loss},
-        attrs={'reduction': 'none'},
+        attrs={'reduction': 'none', 'log_target': log_target},
     )

     if reduction == 'mean':
diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py
index 1fd2501698c2f2..087dc10a58e586 100644
--- a/python/paddle/nn/layer/loss.py
+++ b/python/paddle/nn/layer/loss.py
@@ -1034,8 +1034,14 @@ class KLDivLoss(Layer):

    KL divergence loss is calculated as follows:

+    If `log_target` is False:
+
    $$l(x, y) = y * (\log(y) - x)$$

+    If `log_target` is True:
+
+    $$l(x, y) = \exp(y) * (y - x)$$
+
    Here :math:`x` is input and :math:`y` is label.

    If `reduction` is ``'none'``, the output loss is the same shape as the input, and the loss at each point is calculated separately. There is no reduction to the result.
@@ -1054,6 +1060,7 @@ class KLDivLoss(Layer):
            if `reduction` is ``'sum'``, the reduced sum loss is returned;
            if `reduction` is ``'none'``, no reduction will be applied.
            Default is ``'mean'``.
+        log_target (bool, optional): Indicate whether `label` is passed in log space. Default is False.

    Shape:

@@ -1097,14 +1104,25 @@ class KLDivLoss(Layer):
             >>> print(pred_loss.shape)
             [5, 20]

+            >>> # if label is in the log space, set log_target = True
+            >>> target = paddle.uniform(shape, min=0, max=10).astype('float32')
+            >>> log_target = paddle.log(target)
+            >>> kldiv_criterion_1 = nn.KLDivLoss(reduction='none')
+            >>> kldiv_criterion_2 = nn.KLDivLoss(reduction='none', log_target=True)
+            >>> pred_loss_1 = kldiv_criterion_1(x, target)
+            >>> pred_loss_2 = kldiv_criterion_2(x, log_target)
+            >>> print(paddle.allclose(pred_loss_1, pred_loss_2))
+            Tensor(shape=[], dtype=bool, place=Place(cpu), stop_gradient=True,
+                   True)
     """

-    def __init__(self, reduction='mean'):
+    def __init__(self, reduction='mean', log_target=False):
         super().__init__()
         self.reduction = reduction
+        self.log_target = log_target

     def forward(self, input, label):
-        out = F.kl_div(input, label, self.reduction)
+        out = F.kl_div(input, label, self.reduction, self.log_target)
         return out


diff --git a/test/deprecated/legacy_test/test_kldiv_loss_op.py b/test/deprecated/legacy_test/test_kldiv_loss_op.py
index 599b9764c984d9..dbf421ba263e49 100644
--- a/test/deprecated/legacy_test/test_kldiv_loss_op.py
+++ b/test/deprecated/legacy_test/test_kldiv_loss_op.py
@@ -21,9 +21,12 @@
 from paddle.pir_utils import test_with_pir_api


-def kldiv_loss(x, target, reduction):
-    output = target * (np.log(target) - x)
-    loss = np.where(target >= 0, output, np.zeros_like(x))
+def kldiv_loss(x, target, reduction, log_target=False):
+    if log_target:
+        loss = np.exp(target) * (target - x)
+    else:
+        output = target * (np.log(target) - x)
+        loss = np.where(target >= 0, output, np.zeros_like(x))

     if reduction == "batchmean":
         if len(x.shape) > 0:
@@ -46,13 +49,16 @@ def setUp(self):
         x = np.random.uniform(-10, 10, self.x_shape).astype('float64')
         target = np.random.uniform(-10, 10, self.x_shape).astype('float64')

-        self.attrs = {"reduction": self.reduction}
+        self.attrs = {
+            "reduction": self.reduction,
+            "log_target": self.log_target,
+        }
         self.inputs = {
             'X': x,
             'Target': target,
         }

-        loss = kldiv_loss(x, target, self.reduction)
+        loss = kldiv_loss(x, target, self.reduction, self.log_target)
         self.outputs = {'Loss': loss.astype('float64')}

     def test_check_output(self):
@@ -64,34 +70,52 @@ def test_check_grad(self):
     def initTestCase(self):
         self.x_shape = (4, 5, 5)
         self.reduction = 'batchmean'
+        self.log_target = False


 class TestKLDivLossOp2(TestKLDivLossOp):
     def initTestCase(self):
         self.x_shape = (3, 2, 7, 7)
         self.reduction = 'none'
+        self.log_target = False


 class TestKLDivLossOp3(TestKLDivLossOp):
     def initTestCase(self):
         self.x_shape = (2, 3, 5, 7, 9)
         self.reduction = 'mean'
+        self.log_target = False


 class TestKLDivLossOp4(TestKLDivLossOp):
     def initTestCase(self):
         self.x_shape = (5, 20)
         self.reduction = 'sum'
+        self.log_target = False
+
+
+class TestKLDivLossOp5(TestKLDivLossOp):
+    def initTestCase(self):
+        self.x_shape = (5, 20)
+        self.reduction = 'sum'
+        self.log_target = True
+
+
+class TestKLDivLossOp6(TestKLDivLossOp):
+    def initTestCase(self):
+        self.x_shape = (3, 2, 7, 7)
+        self.reduction = 'none'
+        self.log_target = True


 class TestKLDivLossDygraph(unittest.TestCase):
-    def run_kl_loss(self, reduction, shape=(5, 20)):
+    def run_kl_loss(self, reduction, shape=(5, 20), log_target=False):
         x = np.random.uniform(-10, 10, shape).astype('float64')
         target = np.random.uniform(-10, 10, shape).astype('float64')
-        gt_loss = kldiv_loss(x, target, reduction)
+        gt_loss = kldiv_loss(x, target, reduction, log_target)

         with paddle.base.dygraph.guard():
-            kldiv_criterion = paddle.nn.KLDivLoss(reduction)
+            kldiv_criterion = paddle.nn.KLDivLoss(reduction, log_target)
             pred_loss = kldiv_criterion(
                 paddle.to_tensor(x), paddle.to_tensor(target)
             )
@@ -112,6 +136,9 @@ def test_kl_loss_sum(self):
     def test_kl_loss_none(self):
         self.run_kl_loss('none')

+    def test_kl_loss_mean_logtarget(self):
+        self.run_kl_loss('mean', log_target=True)
+
     @test_with_pir_api
     def test_kl_loss_static_api(self):
         with paddle_static_guard():
@@ -121,6 +148,7 @@ def test_kl_loss_static_api(self):
             paddle.nn.functional.kl_div(input, label)
             paddle.nn.functional.kl_div(input, label, 'sum')
             paddle.nn.functional.kl_div(input, label, 'batchmean')
+            paddle.nn.functional.kl_div(input, label, 'batchmean', True)


 class TestKLDivLossTypePromotion(unittest.TestCase):
diff --git a/test/xpu/test_kldiv_loss_op_xpu.py b/test/xpu/test_kldiv_loss_op_xpu.py
index 28799909162876..e119adedbf0e68 100644
--- a/test/xpu/test_kldiv_loss_op_xpu.py
+++ b/test/xpu/test_kldiv_loss_op_xpu.py
@@ -27,9 +27,12 @@
 paddle.enable_static()


-def kldiv_loss(x, target, reduction):
-    output = target * (np.log(target) - x)
-    loss = np.where(target >= 0, output, np.zeros_like(x))
+def kldiv_loss(x, target, reduction, log_target=False):
+    if log_target:
+        loss = np.exp(target) * (target - x)
+    else:
+        output = target * (np.log(target) - x)
+        loss = np.where(target >= 0, output, np.zeros_like(x))

     if reduction == "batchmean":
         if len(x.shape) > 0:
@@ -59,13 +62,16 @@ def setUp(self):
             x = np.random.uniform(-10, 10, self.x_shape).astype('float32')
             target = np.random.uniform(-10, 10, self.x_shape).astype('float32')

-            self.attrs = {"reduction": self.reduction}
+            self.attrs = {
+                "reduction": self.reduction,
+                "log_target": self.log_target,
+            }
             self.inputs = {
                 'X': x,
                 'Target': target,
             }

-            loss = kldiv_loss(x, target, self.reduction)
+            loss = kldiv_loss(x, target, self.reduction, self.log_target)
             self.outputs = {'Loss': loss.astype('float32')}

         def test_check_output(self):
@@ -83,30 +89,46 @@ def test_check_grad(self):
         def initTestCase(self):
             self.x_shape = (4, 5, 5)
             self.reduction = 'none'
+            self.log_target = False

     class TestKLDivLossOp2(TestKLDivLossOp):
         def initTestCase(self):
             self.x_shape = (3, 2, 7, 7)
             self.reduction = 'none'
+            self.log_target = False

     class TestKLDivLossOp3(TestKLDivLossOp):
         def initTestCase(self):
             self.x_shape = (2, 3, 5, 7, 9)
             self.reduction = 'none'
+            self.log_target = False

     class TestKLDivLossOp4(TestKLDivLossOp):
         def initTestCase(self):
             self.x_shape = (5, 20)
             self.reduction = 'none'
+            self.log_target = False
+
+    class TestKLDivLossOp5(TestKLDivLossOp):
+        def initTestCase(self):
+            self.x_shape = (5, 20)
+            self.reduction = 'none'
+            self.log_target = True
+
+    class TestKLDivLossOp6(TestKLDivLossOp):
+        def initTestCase(self):
+            self.x_shape = (3, 2, 7, 7)
+            self.reduction = 'none'
+            self.log_target = True

 class TestKLDivLossDygraph(unittest.TestCase):
-    def run_kl_loss(self, reduction, shape=(5, 20)):
+    def run_kl_loss(self, reduction, shape=(5, 20), log_target=False):
         x = np.random.uniform(-10, 10, shape).astype('float32')
         target = np.random.uniform(-10, 10, shape).astype('float32')
-        gt_loss = kldiv_loss(x, target, reduction)
+        gt_loss = kldiv_loss(x, target, reduction, log_target)

         with paddle.base.dygraph.guard():
-            kldiv_criterion = paddle.nn.KLDivLoss(reduction)
+            kldiv_criterion = paddle.nn.KLDivLoss(reduction, log_target)
             pred_loss = kldiv_criterion(
                 paddle.to_tensor(x), paddle.to_tensor(target)
             )
@@ -117,11 +139,15 @@ def run_kl_loss(self, reduction, shape=(5, 20)):
     def test_kl_loss_none(self):
         self.run_kl_loss('none')

+    def test_kl_loss_mean_logtarget(self):
+        self.run_kl_loss('none', log_target=True)
+
     def test_kl_loss_static_api(self):
         input = paddle.static.data(name='input', shape=[5, 20])
         label = paddle.static.data(name='label', shape=[5, 20])

         paddle.nn.functional.kl_div(input, label)
+        paddle.nn.functional.kl_div(input, label, 'none', True)

 class TestKLDivLossTypePromotion(unittest.TestCase):
     def test_kl_div_promotion(self):
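Note (not part of the patch): the two modes introduced above are equivalent when the label is positive, since with `log_target=True` the label is interpreted as `log(y)` and `exp(log(y)) * (log(y) - x) = y * (log(y) - x)`. The following is a minimal NumPy sketch of that check, mirroring the reference `kldiv_loss` helper used in the tests; the function name is illustrative only.

import numpy as np


def kldiv_loss_reference(x, target, log_target=False):
    # x is the log-probability input, target is the label.
    if log_target:
        # label already given in log space: l(x, y) = exp(y) * (y - x)
        return np.exp(target) * (target - x)
    # label given as probabilities: l(x, y) = y * (log(y) - x),
    # zeroed where y <= 0 to match the non-log branch of the kernel.
    output = target * (np.log(target) - x)
    return np.where(target > 0, output, np.zeros_like(x))


x = np.log(np.full((5, 20), 1.0 / 20))        # uniform log-probabilities
y = np.random.dirichlet(np.ones(20), size=5)  # strictly positive labels
np.testing.assert_allclose(
    kldiv_loss_reference(x, y),
    kldiv_loss_reference(x, np.log(y), log_target=True),
)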