Skip to content

Commit bcf0b56

Browse files
committed
refine iterator
1 parent 784740d commit bcf0b56

File tree

2 files changed

+229
-161
lines changed

2 files changed

+229
-161
lines changed

paddle/operators/cos_sim_op.h

Lines changed: 229 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
#pragma once
1616
#include "paddle/framework/eigen.h"
1717
#include "paddle/framework/op_registry.h"
18-
#include "paddle/operators/elementwise_add_op.h"
18+
#include "paddle/operators/elementwise_op_function.h"
1919

2020
namespace paddle {
2121
namespace operators {
@@ -28,27 +28,73 @@ template <typename T, int MajorType = Eigen::RowMajor,
2828
typename IndexType = Eigen::DenseIndex>
2929
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
3030

31-
template <typename T, typename DeviceContext>
32-
void Function_forward(T* out, T* x_norm, T* y_norm,
33-
ElementIterator<T, DeviceContext>& x,
34-
ElementIterator<T, DeviceContext>& y, int row, int col) {
35-
for (int i = 0; i < row; ++i) {
36-
T xx = 0;
31+
// Walks two sequences in lockstep and passes each pair of elements to
// `callback`. Only the first sequence is bounded, by [begin1, last1); the
// second one is simply advanced the same number of steps starting at begin2.
//
// This method could be implemented in CUDA
template <typename IT1, typename IT2, typename Callback>
static void ForEachZip(IT1 begin1, IT1 last1, IT2 begin2, Callback callback) {
  while (begin1 < last1) {
    callback(*begin1, *begin2);
    ++begin1;
    ++begin2;
  }
}
38+
39+
// Forward functor for cosine similarity, evaluated one row at a time.
//
// operator() receives references that point into the x-norm and y-norm
// buffers handed to the constructor; the row index is recovered from the
// distance between each reference's address and the buffer base. For that row
// it stores ||x||, ||y|| and z = <x, y> / (||x|| * ||y||).
//
// same_row == true : row r of x is paired with row r of y.
// same_row == false: y is a single row shared by every row of x; its norm is
//                    recomputed on each call and always stored in y_norm[0].
template <typename T, bool same_row>
struct CosSimFunctor {
  CosSimFunctor(const T* x, const T* y, T* x_norm, T* y_norm, T* z, int cols)
      : x_norm_(x_norm),
        y_norm_(y_norm),
        x_(x),
        y_(y),
        z_(z),
        cols_(static_cast<size_t>(cols)) {}

  inline void operator()(T& x_norm, T& y_norm) const {
    const size_t row = &x_norm - x_norm_;
    // In shared-y mode both the y data and its norm slot are pinned to row 0.
    const size_t y_row = same_row ? static_cast<size_t>(&y_norm - y_norm_) : 0;

    const T* x_data = x_ + cols_ * row;
    const T* y_data = y_ + cols_ * y_row;

    T sum_xx = 0;
    T sum_yy = 0;
    T sum_xy = 0;
    for (size_t k = 0; k < cols_; ++k) {
      sum_xx += x_data[k] * x_data[k];
      sum_yy += y_data[k] * y_data[k];
      sum_xy += x_data[k] * y_data[k];
    }

    const T norm_x = sqrt(sum_xx);
    const T norm_y = sqrt(sum_yy);
    x_norm_[row] = norm_x;
    y_norm_[y_row] = norm_y;
    z_[row] = sum_xy / (norm_x * norm_y);
  }

  T* x_norm_;
  T* y_norm_;
  const T* x_;
  const T* y_;
  T* z_;
  const size_t cols_;
};
5298

5399
template <typename DeviceContext, typename T>
54100
class CosSimKernel : public framework::OpKernel<T> {
@@ -68,58 +114,140 @@ class CosSimKernel : public framework::OpKernel<T> {
68114
int rows_y = in_y->dims()[0];
69115

70116
int cols = framework::product(in_x->dims()) / rows_x;
71-
auto x_iter = ElementIterator<T, DeviceContext>(in_x->data<T>(), rows_x,
72-
cols, rows_x, cols);
73-
auto y_iter = ElementIterator<T, DeviceContext>(in_y->data<T>(), rows_y,
74-
cols, rows_x, cols);
75-
76-
Function_forward(out_z->data<T>(), out_x_norm->data<T>(),
77-
out_y_norm->data<T>(), x_iter, y_iter, rows_x, cols);
78-
//
79-
// // convert Tensor to Eigen Tensor
80-
//// int rows_x = in_x->dims()[0];
81-
//// int rows_y = in_y->dims()[0];
82-
// auto x = EigenMatrix<T>::Reshape(*in_x, 1);
83-
// auto y = EigenMatrix<T>::Reshape(*in_y, 1);
84-
// auto z = EigenVector<T>::Flatten(*out_z);
85-
// auto x_norm = EigenVector<T>::Flatten(*out_x_norm);
86-
// auto y_norm = EigenVector<T>::Flatten(*out_y_norm);
87-
//
88-
// // compute
89-
// auto& place =
90-
// *context.template device_context<DeviceContext>().eigen_device();
91-
// auto row_along = Eigen::array<int, 1>({{1}});
92-
// x_norm.device(place) = x.square().sum(row_along).sqrt();
93-
// y_norm.device(place) = y.square().sum(row_along).sqrt();
94-
// if (rows_x == rows_y) {
95-
// auto xy = (x * y).sum(Eigen::array<int, 1>({{1}}));
96-
// z.device(place) = xy / x_norm / y_norm;
97-
// } else {
98-
// Eigen::DSizes<int, 2> bcast(rows_x, 1);
99-
// auto xy = (x * y.broadcast(bcast)).sum(row_along);
100-
// z.device(place) = xy / x_norm / y_norm.broadcast(bcast);
101-
// }
117+
118+
if (rows_x == rows_y) {
119+
CosSimFunctor<T, true> functor(
120+
in_x->data<T>(), in_y->data<T>(), out_x_norm->data<T>(),
121+
out_y_norm->data<T>(), out_z->data<T>(), cols);
122+
ForEachZip(out_x_norm->data<T>(), out_x_norm->data<T>() + rows_x,
123+
out_y_norm->data<T>(), functor);
124+
} else {
125+
CosSimFunctor<T, false> functor(
126+
in_x->data<T>(), in_y->data<T>(), out_x_norm->data<T>(),
127+
out_y_norm->data<T>(), out_z->data<T>(), cols);
128+
ForEachZip(out_x_norm->data<T>(), out_x_norm->data<T>() + rows_x,
129+
out_y_norm->data<T>(), functor);
130+
}
102131
}
103132
};
104133

105-
template <typename T, typename DeviceContext>
106-
void Function_element(T* result, ElementIterator<T, DeviceContext> dz,
107-
ElementIterator<T, DeviceContext> y,
108-
ElementIterator<T, DeviceContext> x_norm,
109-
ElementIterator<T, DeviceContext> y_norm,
110-
ElementIterator<T, DeviceContext> z,
111-
ElementIterator<T, DeviceContext> x, int num, int block) {
112-
for (int i = 0; i < num; ++i) {
113-
result[i % block] += (*dz) * ((*y) / ((*x_norm) * (*y_norm)) -
114-
(*z) * (*x) / ((*x_norm) * (*x_norm)));
115-
++dz;
116-
++y;
117-
++x_norm;
118-
++y_norm;
119-
++z;
120-
++x;
134+
// Backward functor for the case where x and y have the same number of rows.
//
// operator() receives references into the two norm buffers passed to the
// constructor and derives the row indices from their addresses. For that row
// it writes, for every column i:
//   dx[i] = dz * (y[i] / (|x| * |y|) - z * x[i] / |x|^2)
// Swapping the x/y arguments at construction yields the gradient w.r.t. y.
template <typename T>
struct CosSimGradFunctor {
  CosSimGradFunctor(const T* x_norm, const T* y_norm, const T* x, const T* y,
                    const T* z, const T* dz, T* dx, int cols)
      : x_norm_(x_norm),
        y_norm_(y_norm),
        x_(x),
        y_(y),
        z_(z),
        dz_(dz),
        dx_(dx),
        cols_(static_cast<size_t>(cols)) {}

  void operator()(const T& x_norm, const T& y_norm) const {
    const size_t row = &x_norm - x_norm_;
    const size_t other_row = &y_norm - y_norm_;

    const T* x_data = x_ + cols_ * row;
    const T* y_data = y_ + cols_ * other_row;
    T* grad_out = dx_ + cols_ * row;

    // Per-row scalars shared by every column.
    const auto norm_sq = x_norm_[row] * x_norm_[row];
    const auto norm_prod = x_norm_[row] * y_norm_[other_row];
    const auto upstream = dz_[row];
    const auto cosine = z_[row];

    for (size_t k = 0; k < cols_; ++k) {
      grad_out[k] =
          upstream * (y_data[k] / norm_prod - cosine * x_data[k] / norm_sq);
    }
  }

  const T* x_norm_;
  const T* y_norm_;
  const T* x_;
  const T* y_;
  const T* z_;
  const T* dz_;
  T* dx_;
  const size_t cols_;
};
175+
176+
// Backward functor producing dx when y is a single row shared by all rows of
// x. The row index is derived from the address of the x-norm reference; y and
// its norm are always read from index 0. For every column i of that row:
//   dx[i] = dz * (y[i] / (|x| * |y|) - z * x[i] / |x|^2)
template <typename T>
struct CosSimDxFunctor {
  CosSimDxFunctor(const T* x_norm, const T* y_norm, const T* x, const T* y,
                  const T* z, const T* dz, T* dx, int cols)
      : x_norm_(x_norm),
        y_norm_(y_norm),
        x_(x),
        y_(y),
        z_(z),
        dz_(dz),
        dx_(dx),
        cols_(static_cast<size_t>(cols)) {}

  // The second argument is accepted for call compatibility but not used.
  void operator()(const T& x_norm, const T& /*y_norm*/) const {
    const size_t row = &x_norm - x_norm_;

    const T* x_data = x_ + cols_ * row;
    T* grad_out = dx_ + cols_ * row;

    // Per-row scalars shared by every column.
    const auto norm_sq = x_norm_[row] * x_norm_[row];
    const auto norm_prod = x_norm_[row] * y_norm_[0];
    const auto upstream = dz_[row];
    const auto cosine = z_[row];

    for (size_t k = 0; k < cols_; ++k) {
      grad_out[k] =
          upstream * (y_[k] / norm_prod - cosine * x_data[k] / norm_sq);
    }
  }

  const T* x_norm_;
  const T* y_norm_;
  const T* x_;
  const T* y_;
  const T* z_;
  const T* dz_;
  T* dx_;
  const size_t cols_;
};
214+
215+
// Backward functor producing dy when y is a single row shared by all rows of
// x. Each call handles one row of x (the index is derived from the address of
// the x-norm reference) and contributes, for every column i:
//   dz * (x[i] / (|x| * |y|) - z * y[i] / |y|^2)
// to dy[i]; y and its norm are always read from index 0.
//
// dy is a reduction over all rows of x. The destination buffer is not
// required to be zero-initialised: the call for row 0 overwrites dy and every
// later row accumulates into it. Previously every row used `+=`, so the result
// was added onto whatever the buffer happened to contain. Consequently the
// rows must be visited sequentially starting with row 0, which is how
// ForEachZip drives this functor.
template <typename T>
struct CosSimDyFunctor {
  CosSimDyFunctor(const T* x_norm, const T* y_norm, const T* x, const T* y,
                  const T* z, const T* dz, T* dy, int cols)
      : x_norm_(x_norm),
        y_norm_(y_norm),
        x_(x),
        y_(y),
        z_(z),
        dz_(dz),
        dy_(dy),
        cols_(static_cast<size_t>(cols)) {}

  // The second argument is accepted for call compatibility but not used.
  void operator()(const T& x_norm, const T& /*y_norm*/) const {
    const size_t x_offset = &x_norm - x_norm_;

    const auto y_norm_square = y_norm_[0] * y_norm_[0];
    const auto xy_norm_prod = x_norm_[x_offset] * y_norm_[0];
    const auto dz = dz_[x_offset];
    const auto z = z_[x_offset];
    const T* x = x_ + cols_ * x_offset;

    // Row 0 starts the reduction, so it must not read the old contents of dy.
    const bool first_row = (x_offset == 0);

    for (size_t i = 0; i < cols_; ++i) {
      const T term = dz * (x[i] / xy_norm_prod - z * y_[i] / y_norm_square);
      if (first_row) {
        dy_[i] = term;
      } else {
        dy_[i] += term;
      }
    }
  }

  const T* x_norm_;
  const T* y_norm_;
  const T* x_;
  const T* y_;
  const T* z_;
  const T* dz_;
  T* dy_;
  const size_t cols_;
};
123251

124252
template <typename DeviceContext, typename T>
125253
class CosSimGradKernel : public framework::OpKernel<T> {
@@ -140,45 +268,40 @@ class CosSimGradKernel : public framework::OpKernel<T> {
140268
int rows_y = in_y->dims()[0];
141269
int cols = framework::product(in_x->dims()) / rows_x;
142270

143-
//////////////////////////////
144-
// ##
145-
auto x_iter = ElementIterator<T, DeviceContext>(in_x->data<T>(), rows_x,
146-
cols, rows_x, cols);
147-
auto y_iter = ElementIterator<T, DeviceContext>(in_y->data<T>(), rows_y,
148-
cols, rows_x, cols);
149-
auto z_iter = ElementIterator<T, DeviceContext>(in_z->data<T>(), rows_x, 1,
150-
rows_x, cols);
151-
auto dz_iter = ElementIterator<T, DeviceContext>(in_grad_z->data<T>(),
152-
rows_x, 1, rows_x, cols);
153-
auto x_norm_iter = ElementIterator<T, DeviceContext>(
154-
in_x_norm->data<T>(), rows_x, 1, rows_x, cols);
155-
auto y_norm_iter = ElementIterator<T, DeviceContext>(
156-
in_y_norm->data<T>(), rows_y, 1, rows_x, cols);
157-
// ##
158-
//////////////////////////////
159-
// compute dx
160-
if (out_grad_x) {
161-
out_grad_x->mutable_data<T>(context.GetPlace());
162-
163-
//////////////////////////////
164-
// ##
165-
Function_element(out_grad_x->data<T>(), dz_iter, y_iter, x_norm_iter,
166-
y_norm_iter, z_iter, x_iter, rows_x * cols,
167-
rows_x * cols);
168-
// ##
169-
//////////////////////////////
170-
}
171-
// compute dy
172-
if (out_grad_y) {
173-
out_grad_y->mutable_data<T>(context.GetPlace());
174-
175-
//////////////////////////////
176-
// ##
177-
Function_element(out_grad_y->data<T>(), dz_iter, x_iter, y_norm_iter,
178-
x_norm_iter, z_iter, y_iter, rows_x * cols,
179-
rows_y * cols);
180-
// ##
181-
//////////////////////////////
271+
if (rows_x == rows_y) {
272+
if (out_grad_x) {
273+
CosSimGradFunctor<T> functor(
274+
in_x_norm->data<T>(), in_y_norm->data<T>(), in_x->data<T>(),
275+
in_y->data<T>(), in_z->data<T>(), in_grad_z->data<T>(),
276+
out_grad_x->mutable_data<T>(context.GetPlace()), cols);
277+
ForEachZip(in_x_norm->data<T>(), in_x_norm->data<T>() + rows_x,
278+
in_y_norm->data<T>(), functor);
279+
}
280+
if (out_grad_y) {
281+
CosSimGradFunctor<T> functor(
282+
in_y_norm->data<T>(), in_x_norm->data<T>(), in_y->data<T>(),
283+
in_x->data<T>(), in_z->data<T>(), in_grad_z->data<T>(),
284+
out_grad_y->mutable_data<T>(context.GetPlace()), cols);
285+
ForEachZip(in_y_norm->data<T>(), in_y_norm->data<T>() + rows_x,
286+
in_x_norm->data<T>(), functor);
287+
}
288+
} else {
289+
if (out_grad_x) {
290+
CosSimDxFunctor<T> functor(
291+
in_x_norm->data<T>(), in_y_norm->data<T>(), in_x->data<T>(),
292+
in_y->data<T>(), in_z->data<T>(), in_grad_z->data<T>(),
293+
out_grad_x->mutable_data<T>(context.GetPlace()), cols);
294+
ForEachZip(in_x_norm->data<T>(), in_x_norm->data<T>() + rows_x,
295+
in_y_norm->data<T>(), functor);
296+
}
297+
if (out_grad_y) {
298+
CosSimDyFunctor<T> functor(
299+
in_x_norm->data<T>(), in_y_norm->data<T>(), in_x->data<T>(),
300+
in_y->data<T>(), in_z->data<T>(), in_grad_z->data<T>(),
301+
out_grad_y->mutable_data<T>(context.GetPlace()), cols);
302+
ForEachZip(in_x_norm->data<T>(), in_x_norm->data<T>() + rows_x,
303+
in_y_norm->data<T>(), functor);
304+
}
182305
}
183306
}
184307
};

0 commit comments

Comments
 (0)