@@ -13,19 +13,15 @@ See the License for the specific language governing permissions and
1313limitations under the License. */
1414
1515#pragma once
16- #include " paddle/framework/eigen.h"
1716#include " paddle/framework/op_registry.h"
17+ #include " paddle/operators/math/cos_sim_functor.h"
18+ #include " paddle/operators/math/math_function.h"
19+ #include " paddle/platform/for_range.h"
1820
1921namespace paddle {
2022namespace operators {
2123
2224using Tensor = framework::Tensor;
23- template <typename T, int MajorType = Eigen::RowMajor,
24- typename IndexType = Eigen::DenseIndex>
25- using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
26- template <typename T, int MajorType = Eigen::RowMajor,
27- typename IndexType = Eigen::DenseIndex>
28- using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
2925
3026template <typename DeviceContext, typename T>
3127class CosSimKernel : public framework ::OpKernel<T> {
@@ -41,28 +37,25 @@ class CosSimKernel : public framework::OpKernel<T> {
4137 out_x_norm->mutable_data <T>(context.GetPlace ());
4238 out_y_norm->mutable_data <T>(context.GetPlace ());
4339
44- // convert Tensor to Eigen Tensor
4540 int rows_x = in_x->dims ()[0 ];
4641 int rows_y = in_y->dims ()[0 ];
47- auto x = EigenMatrix<T>::Reshape (*in_x, 1 );
48- auto y = EigenMatrix<T>::Reshape (*in_y, 1 );
49- auto z = EigenVector<T>::Flatten (*out_z);
50- auto x_norm = EigenVector<T>::Flatten (*out_x_norm);
51- auto y_norm = EigenVector<T>::Flatten (*out_y_norm);
5242
53- // compute
54- auto & place =
55- *context.template device_context <DeviceContext>().eigen_device ();
56- auto row_along = Eigen::array<int , 1 >({{1 }});
57- x_norm.device (place) = x.square ().sum (row_along).sqrt ();
58- y_norm.device (place) = y.square ().sum (row_along).sqrt ();
43+ int cols = framework::product (in_x->dims ()) / rows_x;
44+
5945 if (rows_x == rows_y) {
60- auto xy = (x * y).sum (Eigen::array<int , 1 >({{1 }}));
61- z.device (place) = xy / x_norm / y_norm;
46+ math::CosSimFunctor<T, true > functor (
47+ in_x->data <T>(), in_y->data <T>(), out_x_norm->data <T>(),
48+ out_y_norm->data <T>(), out_z->data <T>(), cols);
49+ platform::ForRange<DeviceContext> for_range (
50+ static_cast <const DeviceContext&>(context.device_context ()), rows_x);
51+ for_range (functor);
6252 } else {
63- Eigen::DSizes<int , 2 > bcast (rows_x, 1 );
64- auto xy = (x * y.broadcast (bcast)).sum (row_along);
65- z.device (place) = xy / x_norm / y_norm.broadcast (bcast);
53+ math::CosSimFunctor<T, false > functor (
54+ in_x->data <T>(), in_y->data <T>(), out_x_norm->data <T>(),
55+ out_y_norm->data <T>(), out_z->data <T>(), cols);
56+ platform::ForRange<DeviceContext> for_range (
57+ static_cast <const DeviceContext&>(context.device_context ()), rows_x);
58+ for_range (functor);
6659 }
6760 }
6861};
@@ -81,62 +74,54 @@ class CosSimGradKernel : public framework::OpKernel<T> {
8174 auto * out_grad_y = context.Output <Tensor>(framework::GradVarName (" Y" ));
8275 auto * in_grad_z = context.Input <Tensor>(framework::GradVarName (" Out" ));
8376
84- // convert Tensor to Eigen Tensor
85- auto x = EigenMatrix<T>::Reshape (*in_x, 1 );
86- auto y = EigenMatrix<T>::Reshape (*in_y, 1 );
87- auto z = EigenMatrix<T>::Reshape (*in_z, 1 );
88- auto x_norm = EigenMatrix<T>::Reshape (*in_x_norm, 1 );
89- auto y_norm = EigenMatrix<T>::Reshape (*in_y_norm, 1 );
90- auto dz = EigenMatrix<T>::Reshape (*in_grad_z, 1 );
91-
9277 // compute gradient
9378 int rows_x = in_x->dims ()[0 ];
9479 int rows_y = in_y->dims ()[0 ];
9580 int cols = framework::product (in_x->dims ()) / rows_x;
96- Eigen::DSizes<int , 2 > bcast_cols (1 , cols);
97- auto z_bcast = z.broadcast (bcast_cols);
98- auto dz_bcast = dz.broadcast (bcast_cols);
99- auto x_snorm_bcast = x_norm.square ().eval ().broadcast (bcast_cols);
100- auto & place =
101- *context.template device_context <DeviceContext>().eigen_device ();
81+
10282 if (rows_x == rows_y) {
103- auto y_snorm_bcast = y_norm.square ().eval ().broadcast (bcast_cols);
104- auto norm_prod_bcast = (x_norm * y_norm).eval ().broadcast (bcast_cols);
105- // compute dx
10683 if (out_grad_x) {
107- out_grad_x->mutable_data <T>(context.GetPlace ());
108- auto dx = EigenMatrix<T>::Reshape (*out_grad_x, 1 );
109- auto grad = y / norm_prod_bcast - z_bcast * x / x_snorm_bcast;
110- dx.device (place) = dz_bcast * grad;
84+ math::CosSimGradFunctor<T> functor (
85+ in_x_norm->data <T>(), in_y_norm->data <T>(), in_x->data <T>(),
86+ in_y->data <T>(), in_z->data <T>(), in_grad_z->data <T>(),
87+ out_grad_x->mutable_data <T>(context.GetPlace ()), cols);
88+ platform::ForRange<DeviceContext> for_range (
89+ static_cast <const DeviceContext&>(context.device_context ()),
90+ rows_x);
91+ for_range (functor);
11192 }
112- // compute dy
11393 if (out_grad_y) {
114- out_grad_y->mutable_data <T>(context.GetPlace ());
115- auto dy = EigenMatrix<T>::Reshape (*out_grad_y, 1 );
116- auto grad = x / norm_prod_bcast - z_bcast * y / y_snorm_bcast;
117- dy.device (place) = dz_bcast * grad;
94+ math::CosSimGradFunctor<T> functor (
95+ in_y_norm->data <T>(), in_x_norm->data <T>(), in_y->data <T>(),
96+ in_x->data <T>(), in_z->data <T>(), in_grad_z->data <T>(),
97+ out_grad_y->mutable_data <T>(context.GetPlace ()), cols);
98+ platform::ForRange<DeviceContext> for_range (
99+ static_cast <const DeviceContext&>(context.device_context ()),
100+ rows_x);
101+ for_range (functor);
118102 }
119103 } else {
120- Eigen::DSizes<int , 2 > bcast_rows (rows_x, 1 );
121- Eigen::DSizes<int , 2 > bcast_rows_cols (rows_x, cols);
122- auto y_bcast = y.broadcast (bcast_rows);
123- auto y_snorm_bcast = y_norm.square ().eval ().broadcast (bcast_rows_cols);
124- auto norm_prod_bcast = (x_norm * y_norm.eval ().broadcast (bcast_rows))
125- .eval ()
126- .broadcast (bcast_cols);
127- // compute dx
128104 if (out_grad_x) {
129- out_grad_x->mutable_data <T>(context.GetPlace ());
130- auto dx = EigenMatrix<T>::Reshape (*out_grad_x, 1 );
131- auto grad = y_bcast / norm_prod_bcast - z_bcast * x / x_snorm_bcast;
132- dx.device (place) = dz_bcast * grad;
105+ math::CosSimDxFunctor<T> functor (
106+ in_x_norm->data <T>(), in_y_norm->data <T>(), in_x->data <T>(),
107+ in_y->data <T>(), in_z->data <T>(), in_grad_z->data <T>(),
108+ out_grad_x->mutable_data <T>(context.GetPlace ()), cols);
109+ platform::ForRange<DeviceContext> for_range (
110+ static_cast <const DeviceContext&>(context.device_context ()),
111+ rows_x);
112+ for_range (functor);
133113 }
134- // compute dy
135114 if (out_grad_y) {
136115 out_grad_y->mutable_data <T>(context.GetPlace ());
137- auto dy = EigenVector<T>::Flatten (*out_grad_y);
138- auto grad = x / norm_prod_bcast - z_bcast * y_bcast / y_snorm_bcast;
139- dy.device (place) = (dz_bcast * grad).sum (Eigen::array<int , 1 >({{0 }}));
116+ math::SetConstant<DeviceContext, T> set_zero;
117+ auto & dev_ctx = context.template device_context <DeviceContext>();
118+ set_zero (dev_ctx, out_grad_y, static_cast <T>(0 ));
119+
120+ math::CosSimDyFunctor<DeviceContext, T> functor;
121+ functor (dev_ctx, in_x_norm->data <T>(), in_y_norm->data <T>(),
122+ in_x->data <T>(), in_y->data <T>(), in_z->data <T>(),
123+ in_grad_z->data <T>(), static_cast <size_t >(rows_x),
124+ static_cast <size_t >(cols), out_grad_y->data <T>());
140125 }
141126 }
142127 }
0 commit comments