File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -28,9 +28,13 @@ class AddKernel : public OpKernel {
2828
2929 output->mutable_data <T>(context.GetPlace ());
3030
31- EigenVector<T>::Flatten (*output).device (context.GetEigenDevice <Place>()) =
32- framework::EigenVector<T>::Flatten (*input0) +
33- framework::EigenVector<T>::Flatten (*input1);
31+ auto X = EigenVector<T>::Flatten (*input0);
32+ auto Y = EigenVector<T>::Flatten (*input1);
33+ auto Z = EigenVector<T>::Flatten (*output);
34+
35+ auto place = context.GetEigenDevice <Place>();
36+
37+ Z.device (place) = X + Y;
3438 }
3539};
3640
Original file line number Diff line number Diff line change @@ -27,8 +27,11 @@ class MeanKernel : public OpKernel {
2727
2828 output->mutable_data <T>(context.GetPlace ());
2929
30- EigenScalar<T>::From (*output).device (context.GetEigenDevice <Place>()) =
31- EigenVector<T>::Flatten (*input).mean ();
30+ auto X = EigenVector<T>::Flatten (*input);
31+ auto y = EigenScalar<T>::From (*output);
32+ auto place = context.GetEigenDevice <Place>();
33+
34+ y.device (place) = X.mean ();
3235 }
3336};
3437
Original file line number Diff line number Diff line change @@ -26,13 +26,18 @@ class MulKernel : public OpKernel {
2626 Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1 > dim_pair = {
2727 {Eigen::IndexPair<Eigen::DenseIndex>(1 , 0 )}};
2828
29+ auto input0 = context.Input <Tensor>(" X" );
30+ auto input1 = context.Input <Tensor>(" Y" );
2931 auto output = context.Output <Tensor>(0 );
32+
3033 output->mutable_data <T>(context.GetPlace ());
3134
32- EigenMatrix<T>::From (*output).device (context.GetEigenDevice <Place>()) =
33- EigenMatrix<T>::From (*context.Input <Tensor>(" X" ))
34- .contract (EigenMatrix<T>::From (*context.Input <Tensor>(" Y" )),
35- dim_pair);
35+ auto X = EigenMatrix<T>::From (*input0);
36+ auto Y = EigenMatrix<T>::From (*input1);
37+ auto Z = EigenMatrix<T>::From (*output);
38+ auto place = context.GetEigenDevice <Place>();
39+
40+ Z.device (place) = X.contract (Y, dim_pair);
3641 }
3742};
3843} // namespace operators
Original file line number Diff line number Diff line change @@ -29,8 +29,12 @@ class SGDOpKernel : public OpKernel {
2929
3030 param_out->mutable_data <T>(ctx.GetPlace ());
3131
32- EigenVector<T>::Flatten (*param_out).device (ctx.GetEigenDevice <Place>()) =
33- EigenVector<T>::Flatten (*param) - lr * EigenVector<T>::Flatten (*grad);
32+ auto p = EigenVector<T>::Flatten (*param);
33+ auto g = EigenVector<T>::Flatten (*grad);
34+ auto o = EigenVector<T>::Flatten (*param_out);
35+ auto place = ctx.GetEigenDevice <Place>();
36+
37+ o.device (place) = p - lr * g;
3438 }
3539};
3640
Original file line number Diff line number Diff line change @@ -27,8 +27,11 @@ class SigmoidKernel : public OpKernel {
2727 auto output = context.Output <Tensor>(0 );
2828 output->mutable_data <T>(context.GetPlace ());
2929
30- EigenVector<T>::Flatten (*output).device (context.GetEigenDevice <Place>()) =
31- 1.0 / (1.0 + (-1.0 * EigenVector<T>::Flatten (*input)).exp ());
30+ auto X = EigenVector<T>::Flatten (*input);
31+ auto Y = EigenVector<T>::Flatten (*output);
32+ auto place = context.GetEigenDevice <Place>();
33+
34+ Y.device (place) = 1.0 / (1.0 + (-1.0 * X).exp ());
3235 }
3336};
3437} // namespace operators
You can’t perform that action at this time.
0 commit comments