114 changes: 114 additions & 0 deletions paddle/operators/margin_rank_loss_op.cc
@@ -0,0 +1,114 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/margin_rank_loss_op.h"

namespace paddle {
namespace operators {

class MarginRankLossOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
// input check
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
"Input(Label) shouldn't be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X1"), "Input(X1) shouldn't be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X2"), "Input(X2) shouldn't be null.");
> Contributor: Should the output Var "Out" also be checked for null here?
> Author: Done.

auto label_dims = ctx.Input<framework::Tensor>("Label")->dims();
auto x1_dims = ctx.Input<framework::Tensor>("X1")->dims();
auto x2_dims = ctx.Input<framework::Tensor>("X2")->dims();
PADDLE_ENFORCE((label_dims == x1_dims) && (x1_dims == x2_dims) &&
(label_dims.size() == 2) && (label_dims[1] == 1),
"All inputs must be vector with the same size");
> Contributor:
>   • "All inputs must be a vector with the same size."
>   • If the message is a complete sentence, please add punctuation at the end of the sentence.
> Author: Done.

ctx.Output<framework::LoDTensor>("Activated")->Resize(label_dims);
ctx.Output<framework::LoDTensor>("Out")->Resize(label_dims);
> Collaborator: Make sure Activated and Out are not nullptr before resizing them.
> Author: Done.
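The final revision is not shown in this diff, so the following is only a
minimal sketch of the checks both reviewers asked for (they would sit before
the Resize calls above), assuming the InferShapeContext of this vintage
exposes an OutputVar accessor symmetric to InputVar:

    // Hypothetical illustration of the requested checks, not the committed code.
    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
                            "Output(Out) shouldn't be null.");
    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Activated"),
                            "Output(Activated) shouldn't be null.");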

> Contributor (@lcy-seso, Sep 25, 2017): Should the following code be added
> here? I am not sure, because for this operator the inputs X1, X2 and the
> output are always non-sequence data. In that case, is the line below still
> necessary? @qingqing01
>
>     ctx.ShareLoD("X1", /*->*/ "Out");
>
> Author: Maybe it is not necessary here, since the inputs carry no LoD to share.

}
};

template <typename T>
class MarginRankLossOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MarginRankLossOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X1", "The first variable to be ranked, row vector.");
> Collaborator: They are row vectors? Not column vectors?
> Author: A mistake; corrected.

AddInput("X2", "The second variable to be ranked, row vector.");
> Contributor: Please refine this comment as above.
> Author: Done.

AddInput("Label",
"The label indicating X1 ranked higher than X2 "
"or not, row vector.");
> Contributor:
>   • "a row vector": I think a better way is "a 2-D tensor with shape
>     [N x 1]" (N has already been explained above for X1).
>   • Please do not forget the article.
>   • Please add "NOTE: the label can only be +1 or -1." (if I understand it
>     right).
> Author: Done.

AddAttr<T>("margin", "(scalar, default 0) Margin for MarginRankLossOp.")
> Collaborator: In class MarginRankLossKernel, we can see that AttrType should
> be consistent with T, so maybe using T directly is a better practice?
> Author: Done.

.SetDefault(static_cast<T>(0));
> Collaborator: 0 should be cast to AttrType first, e.g. static_cast<AttrType>(0).
> Author: Done.

AddOutput("Activated",
> Contributor: Please fix the doc by following the (type, default value) usage style.
> Author: Done.

"Intermediate tensor to indicate whether each element of "
"Output(Out) is activated.")
.AsIntermediate();
AddOutput("Out", "The output loss of MarginRankLoss operator");
> Contributor: Please fix the doc by following the (type, default value) usage style.
> Author: Done.

AddComment(R"DOC(

MarginRankLoss operator measures the loss given a pair of inputs `X1` and `X2`
and a `Label`, where `Label = 1` indicates that `X1` is ranked higher than
`X2`, and `Label = -1` otherwise. The attribute `margin` makes the ranking
more robust: a pair only stops contributing to the loss once the higher-ranked
prediction exceeds the lower-ranked one by at least `margin`. The loss is
> Contributor: "MarginRankLoss operator measures the loss given a pair of
> inputs {X1, X2} and the Label with a margin, where Label = 1 indicates X1
> is ranked higher than X2, otherwise Label = -1. The attribute margin helps
> to make predictions more robust: if the lower-ranked item's prediction plus
> the margin exceeds that of the higher-ranked item, the pair contributes to
> the final loss; otherwise, it does not."
> Author: Done.


loss(X1, X2, Label) = max(0, -Label * (X1 - X2) + margin)
> Contributor: From this equation, I think you should add "The label can only
> be +1 or -1" to the comment on "Label".
> Author: Done.


For batch input, `X1`, `X2` and `Label` all have the same shape [batch_size x 1].

)DOC");
}
};
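For reference, the gradients implemented by MarginRankLossGradKernel in
margin_rank_loss_op.h below follow directly from this loss (taking the
subgradient 0 at the kink, with `Activated` caching the indicator):

    d loss / d X1 = -Label * 1{loss > 0}
    d loss / d X2 = +Label * 1{loss > 0}

so d(X1) = -d(Out) * Activated * Label and d(X2) = d(Out) * Activated * Label,
which is exactly what the backward kernel computes.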

class MarginRankLossGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
"Input(Label) shouldn't be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X1"), "Input(X1) shouldn't be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X2"), "Input(X2) shouldn't be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
"Input(Out@GRAD) shouldn't be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Activated"),
"Intermediate(Activated) shouldn't be null.");
auto dims = ctx.Input<framework::Tensor>("X1")->dims();
auto *x1_grad =
ctx.Output<framework::LoDTensor>(framework::GradVarName("X1"));
auto *x2_grad =
ctx.Output<framework::LoDTensor>(framework::GradVarName("X2"));
if (x1_grad) {
x1_grad->Resize(dims);
}
if (x2_grad) {
x2_grad->Resize(dims);
}
}
};

} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;

REGISTER_OP(margin_rank_loss, ops::MarginRankLossOp,
ops::MarginRankLossOpMaker<float>, margin_rank_loss_grad,
ops::MarginRankLossGradOp);
REGISTER_OP_CPU_KERNEL(
margin_rank_loss,
ops::MarginRankLossKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
margin_rank_loss_grad,
ops::MarginRankLossGradKernel<paddle::platform::CPUPlace, float>);
24 changes: 24 additions & 0 deletions paddle/operators/margin_rank_loss_op.cu
@@ -0,0 +1,24 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/margin_rank_loss_op.h"

namespace ops = paddle::operators;

REGISTER_OP_GPU_KERNEL(
margin_rank_loss,
ops::MarginRankLossKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
margin_rank_loss_grad,
ops::MarginRankLossGradKernel<paddle::platform::GPUPlace, float>);
106 changes: 106 additions & 0 deletions paddle/operators/margin_rank_loss_op.h
@@ -0,0 +1,106 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"

namespace paddle {
namespace operators {

template <typename T>
struct ReLU {
HOSTDEVICE T operator()(const T& val) const {
return val < 0 ? static_cast<T>(0) : val;
> Collaborator: return val < 0 ? static_cast<T>(0) : val;
> Author: Done.

}
};

template <typename T>
struct Heaviside {
HOSTDEVICE T operator()(const T& val) const {
return static_cast<T>(val > 0 ? 1 : 0);
}
> Collaborator: return static_cast<T>(val > 0 ? 1 : 0);
> Author: Done.

};

template <typename Place, typename T, typename AttrType = T>
class MarginRankLossKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto* out_t = ctx.Output<framework::Tensor>("Out");
auto* act_t = ctx.Output<framework::Tensor>("Activated");

auto* label_t = ctx.Input<framework::Tensor>("Label");
auto* x1_t = ctx.Input<framework::Tensor>("X1");
auto* x2_t = ctx.Input<framework::Tensor>("X2");

out_t->mutable_data<T>(ctx.GetPlace());
act_t->mutable_data<T>(ctx.GetPlace());

auto margin = static_cast<T>(ctx.Attr<AttrType>("margin"));
auto out = framework::EigenVector<T>::Flatten(*out_t);
auto act = framework::EigenVector<T>::Flatten(*act_t);

auto label = framework::EigenVector<T>::Flatten(*label_t);
auto x1 = framework::EigenVector<T>::Flatten(*x1_t);
auto x2 = framework::EigenVector<T>::Flatten(*x2_t);

auto& dev = ctx.GetEigenDevice<Place>();
out.device(dev) = (-label * (x1 - x2) + margin).unaryExpr(ReLU<T>());
act.device(dev) = out.unaryExpr(Heaviside<T>());
}
};

template <typename Place, typename T>
class MarginRankLossGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto* d_x1_t =
ctx.Output<framework::LoDTensor>(framework::GradVarName("X1"));
auto* d_x2_t =
ctx.Output<framework::LoDTensor>(framework::GradVarName("X2"));

auto* act_t = ctx.Input<framework::Tensor>("Activated");
auto* d_out_t = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* label_t = ctx.Input<framework::Tensor>("Label");

auto d_out = framework::EigenVector<T>::Flatten(*d_out_t);
auto act = framework::EigenVector<T>::Flatten(*act_t);
auto label = framework::EigenVector<T>::Flatten(*label_t);
auto& dev = ctx.GetEigenDevice<Place>();

// compute d_x1
if (d_x1_t) {
d_x1_t->mutable_data<T>(ctx.GetPlace());
auto d_x1 = framework::EigenVector<T>::Flatten(*d_x1_t);
d_x1.device(dev) = -d_out * act * label;
}
// compute d_x2
if (d_x2_t) {
d_x2_t->mutable_data<T>(ctx.GetPlace());
auto d_x2 = framework::EigenVector<T>::Flatten(*d_x2_t);
d_x2.device(dev) = d_out * act * label;
}
}
};
} // namespace operators
} // namespace paddle
39 changes: 39 additions & 0 deletions python/paddle/v2/framework/tests/test_margin_rank_loss_op.py
@@ -0,0 +1,39 @@
import unittest
import numpy as np
from op_test import OpTest


class TestMarginRankLossOp(OpTest):
def setUp(self):
self.op_type = "margin_rank_loss"
batch_size = 5
margin = 0.5
# label_i is in {-1, 1}
label = 2 * np.random.randint(
0, 2, size=(batch_size, 1)).astype("float32") - 1
x1 = np.random.random((batch_size, 1)).astype("float32")
x2 = np.random.random((batch_size, 1)).astype("float32")
# loss = max(0, -label * (x1 - x2) + margin)
loss = -label * (x1 - x2) + margin
loss = np.where(loss > 0, loss, 0)
act = np.where(loss > 0, 1., 0.)

self.attrs = {'margin': margin}
self.inputs = {'Label': label, 'X1': x1, 'X2': x2}
self.outputs = {'Activated': act, 'Out': loss}

def test_check_output(self):
self.check_output()

def test_check_grad(self):
self.check_grad(["X1", "X2"], "Out")

def test_check_grad_ignore_x1(self):
self.check_grad(["X2"], "Out", no_grad_set=set('X1'))

def test_check_grad_ignore_x2(self):
self.check_grad(["X1"], "Out", no_grad_set=set('X2'))


if __name__ == '__main__':
unittest.main()
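As a quick sanity check on the reference computation above (a worked example,
not part of the PR): with margin = 0.5, label = +1, x1 = 0.2 and x2 = 0.6,

    loss = max(0, -1 * (0.2 - 0.6) + 0.5) = max(0, 0.9) = 0.9,  act = 1

while flipping the label to -1 gives loss = max(0, (0.2 - 0.6) + 0.5) = 0.1.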