Skip to content

Commit 87da154

Browse files
committed
FIx sigmoid_xe_with_logits_op compile
1 parent 63469da commit 87da154

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

paddle/operators/sigmoid_cross_entropy_with_logits_op.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ namespace operators {
2121

2222
// Out = max(X, 0) - X * Labels + log(1 + exp(-abs(X)))
2323
template <typename Place, typename T>
24-
class SigmoidCrossEntropyWithLogitsKernel : public framework::OpKernel {
24+
class SigmoidCrossEntropyWithLogitsKernel : public framework::OpKernel<T> {
2525
public:
2626
void Compute(const framework::ExecutionContext &context) const override {
2727
const framework::Tensor *X = context.Input<framework::Tensor>("X");
@@ -48,7 +48,7 @@ class SigmoidCrossEntropyWithLogitsKernel : public framework::OpKernel {
4848

4949
// dX = sigmoid(X) - labels
5050
template <typename Place, typename T>
51-
class SigmoidCrossEntropyWithLogitsGradKernel : public framework::OpKernel {
51+
class SigmoidCrossEntropyWithLogitsGradKernel : public framework::OpKernel<T> {
5252
public:
5353
void Compute(const framework::ExecutionContext &context) const override {
5454
const framework::Tensor *X = context.Input<framework::Tensor>("X");

0 commit comments

Comments
 (0)