1515#pragma once
1616#include " paddle/framework/eigen.h"
1717#include " paddle/framework/op_registry.h"
18- #include " paddle/operators/elementwise_add_op .h"
18+ #include " paddle/operators/elementwise_op_function .h"
1919
2020namespace paddle {
2121namespace operators {
@@ -28,27 +28,73 @@ template <typename T, int MajorType = Eigen::RowMajor,
2828 typename IndexType = Eigen::DenseIndex>
2929using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
3030
31- template <typename T, typename DeviceContext>
32- void Function_forward (T* out, T* x_norm, T* y_norm,
33- ElementIterator<T, DeviceContext>& x,
34- ElementIterator<T, DeviceContext>& y, int row, int col) {
35- for (int i = 0 ; i < row; ++i) {
36- T xx = 0 ;
// Walks two sequences in lock-step and invokes `callback(*it1, *it2)` once for
// every position in [begin1, last1).
//
// Only the first sequence has an explicit end; the second iterator is simply
// advanced once per step. The loop condition uses `<`, so IT1 must support
// ordering comparison (e.g. raw pointers or random-access iterators). If
// begin1 is not less than last1 the callback is never invoked.
template <typename IT1, typename IT2, typename Callback>
static void ForEachZip(IT1 begin1, IT1 last1, IT2 begin2, Callback callback) {
  // This method could be implemented in CUDA
  for (; begin1 < last1; ++begin1, ++begin2) {
    callback(*begin1, *begin2);
  }
}
38+
// Per-row forward functor for cosine similarity.
//
// x is read as rows of `cols` contiguous elements. The functor is invoked with
// references to elements of the x_norm / y_norm output buffers; the row index
// is recovered from the address of the referenced element relative to the
// buffer start passed to the constructor.
//
// For the selected row it writes:
//   x_norm[row]   = sqrt(sum_i x[i] * x[i])
//   y_norm[y_row] = sqrt(sum_i y[i] * y[i])
//   z[row]        = (sum_i x[i] * y[i]) / (x_norm[row] * y_norm[y_row])
//
// same_row == true : y is indexed by the row derived from the y_norm
//                    reference (one y row per call).
// same_row == false: only the first `cols` elements of y and y_norm[0] are
//                    used; the y_norm reference argument is ignored. The y
//                    norm is recomputed and rewritten on every call.
//
// There is no guard against a zero norm: for floating-point T a zero-length
// row makes the division produce a non-finite value.
template <typename T, bool same_row>
struct CosSimFunctor {
  CosSimFunctor(const T* x, const T* y, T* x_norm, T* y_norm, T* z, int cols)
      : x_norm_(x_norm),
        y_norm_(y_norm),
        x_(x),
        y_(y),
        z_(z),
        cols_(static_cast<size_t>(cols)) {}

  inline void operator()(T& x_norm, T& y_norm) const {
    const size_t x_offset = &x_norm - x_norm_;
    // In broadcast mode the y_norm reference is not inspected at all, so no
    // pointer arithmetic is performed on it.
    const size_t y_offset =
        same_row ? static_cast<size_t>(&y_norm - y_norm_) : 0;

    const T* x = x_ + cols_ * x_offset;
    const T* y = y_ + cols_ * y_offset;

    T xx = 0;
    T yy = 0;
    T xy = 0;
    for (size_t i = 0; i < cols_; ++i) {
      xx += x[i] * x[i];
      yy += y[i] * y[i];
      xy += x[i] * y[i];
    }
    xx = sqrt(xx);
    yy = sqrt(yy);
    x_norm_[x_offset] = xx;
    y_norm_[y_offset] = yy;
    z_[x_offset] = xy / (xx * yy);
  }

  T* x_norm_;
  T* y_norm_;
  const T* x_;
  const T* y_;
  T* z_;
  const size_t cols_;
};
5298
5399template <typename DeviceContext, typename T>
54100class CosSimKernel : public framework ::OpKernel<T> {
@@ -68,58 +114,140 @@ class CosSimKernel : public framework::OpKernel<T> {
68114 int rows_y = in_y->dims ()[0 ];
69115
70116 int cols = framework::product (in_x->dims ()) / rows_x;
71- auto x_iter = ElementIterator<T, DeviceContext>(in_x->data <T>(), rows_x,
72- cols, rows_x, cols);
73- auto y_iter = ElementIterator<T, DeviceContext>(in_y->data <T>(), rows_y,
74- cols, rows_x, cols);
75-
76- Function_forward (out_z->data <T>(), out_x_norm->data <T>(),
77- out_y_norm->data <T>(), x_iter, y_iter, rows_x, cols);
78- //
79- // // convert Tensor to Eigen Tensor
80- // // int rows_x = in_x->dims()[0];
81- // // int rows_y = in_y->dims()[0];
82- // auto x = EigenMatrix<T>::Reshape(*in_x, 1);
83- // auto y = EigenMatrix<T>::Reshape(*in_y, 1);
84- // auto z = EigenVector<T>::Flatten(*out_z);
85- // auto x_norm = EigenVector<T>::Flatten(*out_x_norm);
86- // auto y_norm = EigenVector<T>::Flatten(*out_y_norm);
87- //
88- // // compute
89- // auto& place =
90- // *context.template device_context<DeviceContext>().eigen_device();
91- // auto row_along = Eigen::array<int, 1>({{1}});
92- // x_norm.device(place) = x.square().sum(row_along).sqrt();
93- // y_norm.device(place) = y.square().sum(row_along).sqrt();
94- // if (rows_x == rows_y) {
95- // auto xy = (x * y).sum(Eigen::array<int, 1>({{1}}));
96- // z.device(place) = xy / x_norm / y_norm;
97- // } else {
98- // Eigen::DSizes<int, 2> bcast(rows_x, 1);
99- // auto xy = (x * y.broadcast(bcast)).sum(row_along);
100- // z.device(place) = xy / x_norm / y_norm.broadcast(bcast);
101- // }
117+
118+ if (rows_x == rows_y) {
119+ CosSimFunctor<T, true > functor (
120+ in_x->data <T>(), in_y->data <T>(), out_x_norm->data <T>(),
121+ out_y_norm->data <T>(), out_z->data <T>(), cols);
122+ ForEachZip (out_x_norm->data <T>(), out_x_norm->data <T>() + rows_x,
123+ out_y_norm->data <T>(), functor);
124+ } else {
125+ CosSimFunctor<T, false > functor (
126+ in_x->data <T>(), in_y->data <T>(), out_x_norm->data <T>(),
127+ out_y_norm->data <T>(), out_z->data <T>(), cols);
128+ ForEachZip (out_x_norm->data <T>(), out_x_norm->data <T>() + rows_x,
129+ out_y_norm->data <T>(), functor);
130+ }
102131 }
103132};
104133
105- template <typename T, typename DeviceContext>
106- void Function_element (T* result, ElementIterator<T, DeviceContext> dz,
107- ElementIterator<T, DeviceContext> y,
108- ElementIterator<T, DeviceContext> x_norm,
109- ElementIterator<T, DeviceContext> y_norm,
110- ElementIterator<T, DeviceContext> z,
111- ElementIterator<T, DeviceContext> x, int num, int block) {
112- for (int i = 0 ; i < num; ++i) {
113- result[i % block] += (*dz) * ((*y) / ((*x_norm) * (*y_norm)) -
114- (*z) * (*x) / ((*x_norm) * (*x_norm)));
115- ++dz;
116- ++y;
117- ++x_norm;
118- ++y_norm;
119- ++z;
120- ++x;
// Per-row gradient functor for the case where both operands have one row per
// output element.
//
// The functor is invoked with references to elements of the x_norm / y_norm
// buffers given to the constructor; the row indices are recovered from the
// addresses of those elements. For the selected row it overwrites `cols`
// elements of dx with
//
//   dx[i] = dz[row] * (y[i] / (x_norm[row] * y_norm[y_row])
//                      - z[row] * x[i] / (x_norm[row] * x_norm[row]))
//
// z and dz are indexed by the x row. Swapping the x/y arguments at
// construction yields the gradient with respect to the other operand.
template <typename T>
struct CosSimGradFunctor {
  CosSimGradFunctor(const T* x_norm, const T* y_norm, const T* x, const T* y,
                    const T* z, const T* dz, T* dx, int cols)
      : x_norm_(x_norm),
        y_norm_(y_norm),
        x_(x),
        y_(y),
        z_(z),
        dz_(dz),
        dx_(dx),
        cols_(static_cast<size_t>(cols)) {}

  void operator()(const T& x_norm, const T& y_norm) const {
    const size_t x_offset = &x_norm - x_norm_;
    const size_t y_offset = &y_norm - y_norm_;

    // Row-invariant factors, computed once outside the element loop.
    const auto x_norm_square = x_norm_[x_offset] * x_norm_[x_offset];
    const auto xy_norm_prod = x_norm_[x_offset] * y_norm_[y_offset];
    const auto dz = dz_[x_offset];
    const auto z = z_[x_offset];

    auto* dx = dx_ + cols_ * x_offset;
    auto* x = x_ + cols_ * x_offset;
    auto* y = y_ + cols_ * y_offset;

    for (size_t i = 0; i < cols_; ++i) {
      dx[i] = dz * (y[i] / xy_norm_prod - z * x[i] / x_norm_square);
    }
  }

  const T* x_norm_;
  const T* y_norm_;
  const T* x_;
  const T* y_;
  const T* z_;
  const T* dz_;
  T* dx_;
  const size_t cols_;
};
175+
// Per-row gradient functor with respect to x for the broadcast case, where
// only the first `cols` elements of y and y_norm[0] are used for every x row.
//
// The functor is invoked with a reference to an element of the x_norm buffer
// given to the constructor; the row index is recovered from that element's
// address. The second call argument is ignored. For the selected row it
// overwrites `cols` elements of dx with
//
//   dx[i] = dz[row] * (y[i] / (x_norm[row] * y_norm[0])
//                      - z[row] * x[i] / (x_norm[row] * x_norm[row]))
template <typename T>
struct CosSimDxFunctor {
  CosSimDxFunctor(const T* x_norm, const T* y_norm, const T* x, const T* y,
                  const T* z, const T* dz, T* dx, int cols)
      : x_norm_(x_norm),
        y_norm_(y_norm),
        x_(x),
        y_(y),
        z_(z),
        dz_(dz),
        dx_(dx),
        cols_(static_cast<size_t>(cols)) {}

  void operator()(const T& x_norm, const T& /*y_norm*/) const {
    const size_t x_offset = &x_norm - x_norm_;

    // Row-invariant factors, computed once outside the element loop.
    const auto x_norm_square = x_norm_[x_offset] * x_norm_[x_offset];
    const auto xy_norm_prod = x_norm_[x_offset] * y_norm_[0];
    const auto dz = dz_[x_offset];
    const auto z = z_[x_offset];

    auto* dx = dx_ + cols_ * x_offset;
    auto* x = x_ + cols_ * x_offset;

    for (size_t i = 0; i < cols_; ++i) {
      dx[i] = dz * (y_[i] / xy_norm_prod - z * x[i] / x_norm_square);
    }
  }

  const T* x_norm_;
  const T* y_norm_;
  const T* x_;
  const T* y_;
  const T* z_;
  const T* dz_;
  T* dx_;
  const size_t cols_;
};
214+
// Per-row gradient functor with respect to y for the broadcast case, where
// only the first `cols` elements of y and y_norm[0] are used for every x row.
//
// The functor is invoked with a reference to an element of the x_norm buffer
// given to the constructor; the row index is recovered from that element's
// address. The second call argument is ignored. Each call ADDS the selected
// row's contribution into the same `cols` elements of dy:
//
//   dy[i] += dz[row] * (x[i] / (x_norm[row] * y_norm[0])
//                       - z[row] * y[i] / (y_norm[0] * y_norm[0]))
//
// Because the update is `+=`, the final value depends on what dy held before
// the first call: the caller must zero dy first to obtain the plain sum over
// rows. Every call writes the same dy elements, so calls for different rows
// must not run concurrently without synchronization.
template <typename T>
struct CosSimDyFunctor {
  CosSimDyFunctor(const T* x_norm, const T* y_norm, const T* x, const T* y,
                  const T* z, const T* dz, T* dy, int cols)
      : x_norm_(x_norm),
        y_norm_(y_norm),
        x_(x),
        y_(y),
        z_(z),
        dz_(dz),
        dy_(dy),
        cols_(static_cast<size_t>(cols)) {}

  void operator()(const T& x_norm, const T& /*y_norm*/) const {
    const size_t x_offset = &x_norm - x_norm_;

    // Row-invariant factors, computed once outside the element loop.
    const auto y_norm_square = y_norm_[0] * y_norm_[0];
    const auto xy_norm_prod = x_norm_[x_offset] * y_norm_[0];
    const auto dz = dz_[x_offset];
    const auto z = z_[x_offset];
    auto* x = x_ + cols_ * x_offset;

    for (size_t i = 0; i < cols_; ++i) {
      dy_[i] += dz * (x[i] / xy_norm_prod - z * y_[i] / y_norm_square);
    }
  }

  const T* x_norm_;
  const T* y_norm_;
  const T* x_;
  const T* y_;
  const T* z_;
  const T* dz_;
  T* dy_;
  const size_t cols_;
};
123251
124252template <typename DeviceContext, typename T>
125253class CosSimGradKernel : public framework ::OpKernel<T> {
@@ -140,45 +268,40 @@ class CosSimGradKernel : public framework::OpKernel<T> {
140268 int rows_y = in_y->dims ()[0 ];
141269 int cols = framework::product (in_x->dims ()) / rows_x;
142270
143- // ////////////////////////////
144- // ##
145- auto x_iter = ElementIterator<T, DeviceContext>(in_x->data <T>(), rows_x,
146- cols, rows_x, cols);
147- auto y_iter = ElementIterator<T, DeviceContext>(in_y->data <T>(), rows_y,
148- cols, rows_x, cols);
149- auto z_iter = ElementIterator<T, DeviceContext>(in_z->data <T>(), rows_x, 1 ,
150- rows_x, cols);
151- auto dz_iter = ElementIterator<T, DeviceContext>(in_grad_z->data <T>(),
152- rows_x, 1 , rows_x, cols);
153- auto x_norm_iter = ElementIterator<T, DeviceContext>(
154- in_x_norm->data <T>(), rows_x, 1 , rows_x, cols);
155- auto y_norm_iter = ElementIterator<T, DeviceContext>(
156- in_y_norm->data <T>(), rows_y, 1 , rows_x, cols);
157- // ##
158- // ////////////////////////////
159- // compute dx
160- if (out_grad_x) {
161- out_grad_x->mutable_data <T>(context.GetPlace ());
162-
163- // ////////////////////////////
164- // ##
165- Function_element (out_grad_x->data <T>(), dz_iter, y_iter, x_norm_iter,
166- y_norm_iter, z_iter, x_iter, rows_x * cols,
167- rows_x * cols);
168- // ##
169- // ////////////////////////////
170- }
171- // compute dy
172- if (out_grad_y) {
173- out_grad_y->mutable_data <T>(context.GetPlace ());
174-
175- // ////////////////////////////
176- // ##
177- Function_element (out_grad_y->data <T>(), dz_iter, x_iter, y_norm_iter,
178- x_norm_iter, z_iter, y_iter, rows_x * cols,
179- rows_y * cols);
180- // ##
181- // ////////////////////////////
271+ if (rows_x == rows_y) {
272+ if (out_grad_x) {
273+ CosSimGradFunctor<T> functor (
274+ in_x_norm->data <T>(), in_y_norm->data <T>(), in_x->data <T>(),
275+ in_y->data <T>(), in_z->data <T>(), in_grad_z->data <T>(),
276+ out_grad_x->mutable_data <T>(context.GetPlace ()), cols);
277+ ForEachZip (in_x_norm->data <T>(), in_x_norm->data <T>() + rows_x,
278+ in_y_norm->data <T>(), functor);
279+ }
280+ if (out_grad_y) {
281+ CosSimGradFunctor<T> functor (
282+ in_y_norm->data <T>(), in_x_norm->data <T>(), in_y->data <T>(),
283+ in_x->data <T>(), in_z->data <T>(), in_grad_z->data <T>(),
284+ out_grad_y->mutable_data <T>(context.GetPlace ()), cols);
285+ ForEachZip (in_y_norm->data <T>(), in_y_norm->data <T>() + rows_x,
286+ in_x_norm->data <T>(), functor);
287+ }
288+ } else {
289+ if (out_grad_x) {
290+ CosSimDxFunctor<T> functor (
291+ in_x_norm->data <T>(), in_y_norm->data <T>(), in_x->data <T>(),
292+ in_y->data <T>(), in_z->data <T>(), in_grad_z->data <T>(),
293+ out_grad_x->mutable_data <T>(context.GetPlace ()), cols);
294+ ForEachZip (in_x_norm->data <T>(), in_x_norm->data <T>() + rows_x,
295+ in_y_norm->data <T>(), functor);
296+ }
297+ if (out_grad_y) {
298+ CosSimDyFunctor<T> functor (
299+ in_x_norm->data <T>(), in_y_norm->data <T>(), in_x->data <T>(),
300+ in_y->data <T>(), in_z->data <T>(), in_grad_z->data <T>(),
301+ out_grad_y->mutable_data <T>(context.GetPlace ()), cols);
302+ ForEachZip (in_x_norm->data <T>(), in_x_norm->data <T>() + rows_x,
303+ in_y_norm->data <T>(), functor);
304+ }
182305 }
183306 }
184307};
0 commit comments