Skip to content

Commit b9336e6

Browse files
authored
Adding support for the sigmoid_cross_entropy_with_logits operator (#4448)
* Adding support for the sigmoid_cross_entropy_with_logits operator * Fixing a typo in the cuda file * Adding Python documentation for sigmoid_cross_entropy_with_logits_op * Correcting typos in documentation * Adding unit tests for sigmoid_cross_entropy_with_logits_op * Addressing code review feedback
1 parent ecef2e6 commit b9336e6

File tree

4 files changed

+315
-0
lines changed

4 files changed

+315
-0
lines changed
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/operators/sigmoid_cross_entropy_with_logits_op.h"
16+
17+
namespace paddle {
18+
namespace operators {
19+
20+
using framework::Tensor;
21+
22+
class SigmoidCrossEntropyWithLogitsOp : public framework::OperatorWithKernel {
23+
public:
24+
using framework::OperatorWithKernel::OperatorWithKernel;
25+
26+
protected:
27+
void InferShape(framework::InferShapeContextBase* ctx) const override {
28+
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
29+
PADDLE_ENFORCE(ctx->HasInput("Labels"),
30+
"Input(Labels) should be not null.");
31+
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should be not null.");
32+
33+
auto x_dims = ctx->GetInputDim("X");
34+
auto labels_dims = ctx->GetInputDim("Labels");
35+
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2.");
36+
PADDLE_ENFORCE_EQ(labels_dims.size(), 2,
37+
"Input(Labels)'s rank should be 2.");
38+
PADDLE_ENFORCE_EQ(x_dims[0], labels_dims[0],
39+
"The 1st dimension of Input(X) and Input(Labels) should "
40+
"be equal.");
41+
PADDLE_ENFORCE_EQ(x_dims[1], labels_dims[1],
42+
"The 2nd dimension of Input(X) and Input(Labels) should "
43+
"be equal.");
44+
45+
ctx->SetOutputDim("Out", x_dims);
46+
ctx->ShareLoD("X", /*->*/ "Out");
47+
}
48+
};
49+
50+
class SigmoidCrossEntropyWithLogitsGradOp
51+
: public framework::OperatorWithKernel {
52+
public:
53+
using framework::OperatorWithKernel::OperatorWithKernel;
54+
55+
protected:
56+
void InferShape(framework::InferShapeContextBase* ctx) const override {
57+
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
58+
PADDLE_ENFORCE(ctx->HasInput("Labels"),
59+
"Input(Labels) should be not null.");
60+
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
61+
"Input(Out@GRAD) shoudl be not null.");
62+
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
63+
"Output(X@GRAD) should be not null.");
64+
65+
auto x_dims = ctx->GetInputDim("X");
66+
auto labels_dims = ctx->GetInputDim("Labels");
67+
auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out"));
68+
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2.");
69+
PADDLE_ENFORCE_EQ(labels_dims.size(), 2,
70+
"Input(Labels)'s rank should be 2.");
71+
PADDLE_ENFORCE_EQ(dout_dims.size(), 2,
72+
"Input(Out@Grad)'s rank should be 2.");
73+
PADDLE_ENFORCE_EQ(x_dims[0], labels_dims[0],
74+
"The 1st dimension of Input(X) and Input(Labels) should "
75+
"be equal.");
76+
PADDLE_ENFORCE_EQ(x_dims[1], labels_dims[1],
77+
"The 2nd dimension of Input(X) and Input(Labels) should "
78+
"be equal.");
79+
PADDLE_ENFORCE_EQ(x_dims[0], dout_dims[0],
80+
"The 1st dimension of Input(X) and Input(Out@Grad) "
81+
"should be equal.");
82+
PADDLE_ENFORCE_EQ(x_dims[1], dout_dims[1],
83+
"The 2nd dimension of Input(X) and Input(Out@Grad) "
84+
"should be equal.");
85+
86+
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
87+
}
88+
};
89+
90+
class SigmoidCrossEntropyWithLogitsOpMaker
91+
: public framework::OpProtoAndCheckerMaker {
92+
public:
93+
SigmoidCrossEntropyWithLogitsOpMaker(framework::OpProto* proto,
94+
framework::OpAttrChecker* op_checker)
95+
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
96+
AddInput("X",
97+
"(Tensor, default Tensor<float>), a 2-D tensor with shape N x D, "
98+
"where N is the batch size and D is the number of classes. "
99+
"This input is a tensor of logits computed by the previous "
100+
" operator. Logits are unscaled log probabilities given as "
101+
"log(p/(1-p)).");
102+
AddInput("Labels",
103+
"(Tensor, default Tensor<float>), a 2-D tensor of the same type "
104+
"and shape as X. This input is a tensor of probabalistic labels "
105+
"for each logit");
106+
AddOutput("Out",
107+
"(Tensor, default Tensor<float>), a 2-D tensor with shape N x D "
108+
" of elementwise logistic losses.");
109+
AddComment(R"DOC(
110+
SigmoidCrossEntropyWithLogits Operator.
111+
112+
This measures the elementwise probability error in discrete classification tasks
113+
in which each class is independent. This can be thought of as predicting labels
114+
for a data-point that are not mutually exclusive. For example, a news article
115+
can be about politics, technology or sports at the same time or none of these.
116+
117+
The logistic loss is given as follows:
118+
119+
loss = -Labels * log(sigmoid(X)) - (1 - Labels) * log(1 - sigmoid(X))
120+
121+
We know that sigmoid(X) = (1 / (1 + exp(-X))). By substituting this we get
122+
123+
loss = X - X * Labels + log(1 + exp(-X))
124+
125+
For stability and to prevent overflow of exp(-X) when X < 0,
126+
we can reformulate the loss as follows:
127+
128+
loss = max(X, 0) - X * Labels + log(1 + exp(-abs(X)))
129+
130+
Both the input `X` and `Labels` can carry the LoD (Level of Details) information.
131+
However the output only shares the LoD with input `X`.
132+
)DOC");
133+
}
134+
};
135+
136+
} // namespace operators
137+
} // namespace paddle
138+
139+
namespace ops = paddle::operators;
140+
REGISTER_OP(sigmoid_cross_entropy_with_logits,
141+
ops::SigmoidCrossEntropyWithLogitsOp,
142+
ops::SigmoidCrossEntropyWithLogitsOpMaker,
143+
sigmoid_cross_entropy_with_logits_grad,
144+
ops::SigmoidCrossEntropyWithLogitsGradOp);
145+
REGISTER_OP_CPU_KERNEL(sigmoid_cross_entropy_with_logits,
146+
ops::SigmoidCrossEntropyWithLogitsKernel<
147+
paddle::platform::CPUPlace, float>);
148+
REGISTER_OP_CPU_KERNEL(sigmoid_cross_entropy_with_logits_grad,
149+
ops::SigmoidCrossEntropyWithLogitsGradKernel<
150+
paddle::platform::CPUPlace, float>);
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#define EIGEN_USE_GPU
16+
#include "paddle/operators/sigmoid_cross_entropy_with_logits_op.h"
17+
18+
namespace ops = paddle::operators;
19+
REGISTER_OP_GPU_KERNEL(sigmoid_cross_entropy_with_logits,
20+
ops::SigmoidCrossEntropyWithLogitsKernel<
21+
paddle::platform::GPUPlace, float>);
22+
REGISTER_OP_GPU_KERNEL(sigmoid_cross_entropy_with_logits_grad,
23+
ops::SigmoidCrossEntropyWithLogitsGradKernel<
24+
paddle::platform::GPUPlace, float>);
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
#include "paddle/framework/eigen.h"
17+
#include "paddle/framework/op_registry.h"
18+
19+
namespace paddle {
20+
namespace operators {
21+
22+
// Out = max(X, 0) - X * Labels + log(1 + exp(-abs(X)))
23+
template <typename Place, typename T>
24+
class SigmoidCrossEntropyWithLogitsKernel : public framework::OpKernel {
25+
public:
26+
void Compute(const framework::ExecutionContext &context) const override {
27+
const framework::Tensor *X = context.Input<framework::Tensor>("X");
28+
const framework::Tensor *Labels =
29+
context.Input<framework::Tensor>("Labels");
30+
framework::Tensor *Out = context.Output<framework::Tensor>("Out");
31+
Out->mutable_data<T>(context.GetPlace());
32+
33+
auto x = framework::EigenVector<T>::Flatten(*X);
34+
auto labels = framework::EigenVector<T>::Flatten(*Labels);
35+
auto out = framework::EigenVector<T>::Flatten(*Out);
36+
auto place = context.GetEigenDevice<Place>();
37+
38+
// term1 = max(x, 0)
39+
auto term1 = x.cwiseMax(static_cast<T>(0));
40+
// term2 = x * labels
41+
auto term2 = x * labels;
42+
// term3 = log(1 + exp(-abs(x)))
43+
auto term3 = (static_cast<T>(1) + (-(x.abs())).exp()).log();
44+
45+
out.device(place) = term1 - term2 + term3;
46+
}
47+
};
48+
49+
// dX = sigmoid(X) - labels
50+
template <typename Place, typename T>
51+
class SigmoidCrossEntropyWithLogitsGradKernel : public framework::OpKernel {
52+
public:
53+
void Compute(const framework::ExecutionContext &context) const override {
54+
const framework::Tensor *X = context.Input<framework::Tensor>("X");
55+
const framework::Tensor *Labels =
56+
context.Input<framework::Tensor>("Labels");
57+
const framework::Tensor *dOut =
58+
context.Input<framework::Tensor>(framework::GradVarName("Out"));
59+
framework::Tensor *dX =
60+
context.Output<framework::Tensor>(framework::GradVarName("X"));
61+
dX->mutable_data<T>(context.GetPlace());
62+
63+
auto x = framework::EigenVector<T>::Flatten(*X);
64+
auto labels = framework::EigenVector<T>::Flatten(*Labels);
65+
auto dout = framework::EigenVector<T>::Flatten(*dOut);
66+
auto dx = framework::EigenVector<T>::Flatten(*dX);
67+
auto place = context.GetEigenDevice<Place>();
68+
69+
auto sigmoid_x = static_cast<T>(1) / (static_cast<T>(1) + (-x).exp());
70+
dx.device(place) = dout * (sigmoid_x - labels);
71+
}
72+
};
73+
74+
} // namespace operators
75+
} // namespace paddle
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import numpy as np
2+
from op_test import OpTest
3+
from scipy.special import logit
4+
from scipy.special import expit
5+
6+
7+
class TestSigmoidCrossEntropyWithLogitsOp1(OpTest):
8+
'''Test sigmoid_cross_entropy_with_logit_op with binary labels
9+
'''
10+
11+
def setUp(self):
12+
self.op_type = "sigmoid_cross_entropy_with_logits"
13+
batch_size = 64
14+
num_classes = 20
15+
self.inputs = {
16+
'X': logit(
17+
np.random.uniform(0, 1, (batch_size, num_classes))
18+
.astype("float32")),
19+
'Labels': np.random.randint(0, 2, (batch_size, num_classes))
20+
.astype("float32")
21+
}
22+
23+
# Fw Pass is implemented as elementwise sigmoid followed by
24+
# elementwise logistic loss
25+
# Labels * -log(sigmoid(X)) + (1 - labels) * -log(1 - sigmoid(X))
26+
sigmoid_X = expit(self.inputs['X'])
27+
term1 = self.inputs['Labels'] * np.log(sigmoid_X)
28+
term2 = (1 - self.inputs['Labels']) * np.log(1 - sigmoid_X)
29+
self.outputs = {'Out': -term1 - term2}
30+
31+
def test_check_output(self):
32+
self.check_output()
33+
34+
def test_check_grad(self):
35+
self.check_grad(['X'], 'Out')
36+
37+
38+
class TestSigmoidCrossEntropyWithLogitsOp2(OpTest):
39+
'''Test sigmoid_cross_entropy_with_logit_op with probabalistic labels
40+
'''
41+
42+
def setUp(self):
43+
self.op_type = "sigmoid_cross_entropy_with_logits"
44+
batch_size = 64
45+
num_classes = 20
46+
self.inputs = {
47+
'X': logit(
48+
np.random.uniform(0, 1, (batch_size, num_classes))
49+
.astype("float32")),
50+
'Labels': np.random.uniform(0, 1, (batch_size, num_classes))
51+
.astype("float32")
52+
}
53+
54+
# Fw Pass is implemented as elementwise sigmoid followed by
55+
# elementwise logistic loss
56+
# Labels * -log(sigmoid(X)) + (1 - labels) * -log(1 - sigmoid(X))
57+
sigmoid_X = expit(self.inputs['X'])
58+
term1 = self.inputs['Labels'] * np.log(sigmoid_X)
59+
term2 = (1 - self.inputs['Labels']) * np.log(1 - sigmoid_X)
60+
self.outputs = {'Out': -term1 - term2}
61+
62+
def test_check_output(self):
63+
self.check_output()
64+
65+
def test_check_grad(self):
66+
self.check_grad(['X'], 'Out')

0 commit comments

Comments
 (0)