Skip to content

Commit 28db149

Browse files
authored
Merge pull request #3179 from gangliao/eigen_refine
Refine compute code in operators
2 parents c46aed5 + 43528c4 commit 28db149

5 files changed

Lines changed: 32 additions & 13 deletions

File tree

paddle/operators/add_op.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,13 @@ class AddKernel : public OpKernel {
2828

2929
output->mutable_data<T>(context.GetPlace());
3030

31-
EigenVector<T>::Flatten(*output).device(context.GetEigenDevice<Place>()) =
32-
framework::EigenVector<T>::Flatten(*input0) +
33-
framework::EigenVector<T>::Flatten(*input1);
31+
auto X = EigenVector<T>::Flatten(*input0);
32+
auto Y = EigenVector<T>::Flatten(*input1);
33+
auto Z = EigenVector<T>::Flatten(*output);
34+
35+
auto place = context.GetEigenDevice<Place>();
36+
37+
Z.device(place) = X + Y;
3438
}
3539
};
3640

paddle/operators/mean_op.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,11 @@ class MeanKernel : public OpKernel {
2727

2828
output->mutable_data<T>(context.GetPlace());
2929

30-
EigenScalar<T>::From(*output).device(context.GetEigenDevice<Place>()) =
31-
EigenVector<T>::Flatten(*input).mean();
30+
auto X = EigenVector<T>::Flatten(*input);
31+
auto y = EigenScalar<T>::From(*output);
32+
auto place = context.GetEigenDevice<Place>();
33+
34+
y.device(place) = X.mean();
3235
}
3336
};
3437

paddle/operators/mul_op.h

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,18 @@ class MulKernel : public OpKernel {
2626
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair = {
2727
{Eigen::IndexPair<Eigen::DenseIndex>(1, 0)}};
2828

29+
auto input0 = context.Input<Tensor>("X");
30+
auto input1 = context.Input<Tensor>("Y");
2931
auto output = context.Output<Tensor>(0);
32+
3033
output->mutable_data<T>(context.GetPlace());
3134

32-
EigenMatrix<T>::From(*output).device(context.GetEigenDevice<Place>()) =
33-
EigenMatrix<T>::From(*context.Input<Tensor>("X"))
34-
.contract(EigenMatrix<T>::From(*context.Input<Tensor>("Y")),
35-
dim_pair);
35+
auto X = EigenMatrix<T>::From(*input0);
36+
auto Y = EigenMatrix<T>::From(*input1);
37+
auto Z = EigenMatrix<T>::From(*output);
38+
auto place = context.GetEigenDevice<Place>();
39+
40+
Z.device(place) = X.contract(Y, dim_pair);
3641
}
3742
};
3843
} // namespace operators

paddle/operators/sgd_op.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,12 @@ class SGDOpKernel : public OpKernel {
2929

3030
param_out->mutable_data<T>(ctx.GetPlace());
3131

32-
EigenVector<T>::Flatten(*param_out).device(ctx.GetEigenDevice<Place>()) =
33-
EigenVector<T>::Flatten(*param) - lr * EigenVector<T>::Flatten(*grad);
32+
auto p = EigenVector<T>::Flatten(*param);
33+
auto g = EigenVector<T>::Flatten(*grad);
34+
auto o = EigenVector<T>::Flatten(*param_out);
35+
auto place = ctx.GetEigenDevice<Place>();
36+
37+
o.device(place) = p - lr * g;
3438
}
3539
};
3640

paddle/operators/sigmoid_op.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,11 @@ class SigmoidKernel : public OpKernel {
2727
auto output = context.Output<Tensor>(0);
2828
output->mutable_data<T>(context.GetPlace());
2929

30-
EigenVector<T>::Flatten(*output).device(context.GetEigenDevice<Place>()) =
31-
1.0 / (1.0 + (-1.0 * EigenVector<T>::Flatten(*input)).exp());
30+
auto X = EigenVector<T>::Flatten(*input);
31+
auto Y = EigenVector<T>::Flatten(*output);
32+
auto place = context.GetEigenDevice<Place>();
33+
34+
Y.device(place) = 1.0 / (1.0 + (-1.0 * X).exp());
3235
}
3336
};
3437
} // namespace operators

0 commit comments

Comments
 (0)