Skip to content

Commit 49df2a7

Browse files
committed
refine gradient function
1 parent bcf0b56 commit 49df2a7

File tree

1 file changed

+33
-68
lines changed

1 file changed

+33
-68
lines changed

paddle/operators/cos_sim_op.h

Lines changed: 33 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -13,24 +13,16 @@
1313
limitations under the License. */
1414

1515
#pragma once
16-
#include "paddle/framework/eigen.h"
1716
#include "paddle/framework/op_registry.h"
1817
#include "paddle/operators/elementwise_op_function.h"
1918

2019
namespace paddle {
2120
namespace operators {
2221

2322
using Tensor = framework::Tensor;
24-
template <typename T, int MajorType = Eigen::RowMajor,
25-
typename IndexType = Eigen::DenseIndex>
26-
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
27-
template <typename T, int MajorType = Eigen::RowMajor,
28-
typename IndexType = Eigen::DenseIndex>
29-
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
3023

3124
template <typename IT1, typename IT2, typename Callback>
3225
static void ForEachZip(IT1 begin1, IT1 last1, IT2 begin2, Callback callback) {
33-
// This method could be implemented in CUDA
3426
for (; begin1 < last1; ++begin1, ++begin2) {
3527
callback(*begin1, *begin2);
3628
}
@@ -66,15 +58,8 @@ struct CosSimFunctor {
6658
x_norm_[x_offset] = xx;
6759
y_norm_[y_offset] = yy;
6860
z_[x_offset] = xy / (xx * yy);
69-
} else {
61+
} else { // This can be wrote in a better way.
7062
auto* y = y_;
71-
// if (yy == -1) {
72-
// yy = 0;
73-
// for (size_t i = 0; i < cols_; ++i) {
74-
// yy += y[i] * y[i];
75-
// }
76-
// y_norm[0] = sqrt(yy);
77-
// }
7863
for (size_t i = 0; i < cols_; ++i) {
7964
xx += x[i] * x[i];
8065
yy += y[i] * y[i]; // only need
@@ -144,22 +129,25 @@ struct CosSimGradFunctor {
144129
dx_(dx),
145130
cols_(static_cast<size_t>(cols)) {}
146131

147-
void operator()(const T& x_norm, const T& y_norm) const {
132+
inline void operator()(const T& x_norm, const T& y_norm) const {
148133
size_t x_offset = &x_norm - x_norm_;
149134
size_t y_offset = &y_norm - y_norm_;
150135

151136
auto x_norm_square = x_norm_[x_offset] * x_norm_[x_offset];
152-
// auto y_norm_square = y_norm_[y_offset] * y_norm_[y_offset];
153137
auto xy_norm_prod = x_norm_[x_offset] * y_norm_[y_offset];
154138
auto dz = dz_[x_offset];
139+
auto z = z_[x_offset];
155140

156141
auto* dx = dx_ + cols_ * x_offset;
157142
auto* x = x_ + cols_ * x_offset;
143+
158144
auto* y = y_ + cols_ * y_offset;
159-
auto z = z_[x_offset];
160145

146+
auto reciprocal_xy_norm_prod = 1 / xy_norm_prod;
147+
auto reciprocal_x_norm_square = 1 / x_norm_square;
161148
for (size_t i = 0; i < cols_; ++i) {
162-
dx[i] = dz * (y[i] / xy_norm_prod - z * x[i] / x_norm_square);
149+
dx[i] = dz * (y[i] * reciprocal_xy_norm_prod -
150+
z * x[i] * reciprocal_x_norm_square);
163151
}
164152
}
165153

@@ -173,69 +161,45 @@ struct CosSimGradFunctor {
173161
const size_t cols_;
174162
};
175163

176-
template <typename T>
164+
template <typename T, bool Dx>
177165
struct CosSimDxFunctor {
178166
CosSimDxFunctor(const T* x_norm, const T* y_norm, const T* x, const T* y,
179-
const T* z, const T* dz, T* dx, int cols)
167+
const T* z, const T* dz, T* dx, T* dy, int cols)
180168
: x_norm_(x_norm),
181169
y_norm_(y_norm),
182170
x_(x),
183171
y_(y),
184172
z_(z),
185173
dz_(dz),
186174
dx_(dx),
187-
cols_(static_cast<size_t>(cols)) {}
188-
189-
void operator()(const T& x_norm, const T& y_norm) const {
190-
size_t x_offset = &x_norm - x_norm_;
191-
192-
auto x_norm_square = x_norm_[x_offset] * x_norm_[x_offset];
193-
auto xy_norm_prod = x_norm_[x_offset] * y_norm_[0];
194-
auto dz = dz_[x_offset];
195-
auto z = z_[x_offset];
196-
197-
auto* dx = dx_ + cols_ * x_offset;
198-
auto* x = x_ + cols_ * x_offset;
199-
200-
for (size_t i = 0; i < cols_; ++i) {
201-
dx[i] = dz * (y_[i] / xy_norm_prod - z * x[i] / x_norm_square);
202-
}
203-
}
204-
205-
const T* x_norm_;
206-
const T* y_norm_;
207-
const T* x_;
208-
const T* y_;
209-
const T* z_;
210-
const T* dz_;
211-
T* dx_;
212-
const size_t cols_;
213-
};
214-
215-
template <typename T>
216-
struct CosSimDyFunctor {
217-
CosSimDyFunctor(const T* x_norm, const T* y_norm, const T* x, const T* y,
218-
const T* z, const T* dz, T* dy, int cols)
219-
: x_norm_(x_norm),
220-
y_norm_(y_norm),
221-
x_(x),
222-
y_(y),
223-
z_(z),
224-
dz_(dz),
225175
dy_(dy),
226176
cols_(static_cast<size_t>(cols)) {}
227177

228-
void operator()(const T& x_norm, const T& y_norm) const {
178+
inline void operator()(const T& x_norm, const T& y_norm) const {
229179
size_t x_offset = &x_norm - x_norm_;
230180

231-
auto y_norm_square = y_norm_[0] * y_norm_[0];
232181
auto xy_norm_prod = x_norm_[x_offset] * y_norm_[0];
233182
auto dz = dz_[x_offset];
234183
auto z = z_[x_offset];
235184
auto* x = x_ + cols_ * x_offset;
185+
auto reciprocal_xy_norm_prod = 1 / xy_norm_prod;
236186

237-
for (size_t i = 0; i < cols_; ++i) {
238-
dy_[i] += dz * (x[i] / xy_norm_prod - z * y_[i] / y_norm_square);
187+
if (Dx) {
188+
auto x_norm_square = x_norm_[x_offset] * x_norm_[x_offset];
189+
auto* dx = dx_ + cols_ * x_offset;
190+
auto* x = x_ + cols_ * x_offset;
191+
auto reciprocal_x_norm_square = 1 / x_norm_square;
192+
for (size_t i = 0; i < cols_; ++i) {
193+
dx[i] = dz * (y_[i] * reciprocal_xy_norm_prod -
194+
z * x[i] * reciprocal_x_norm_square);
195+
}
196+
} else {
197+
auto y_norm_square = y_norm_[0] * y_norm_[0];
198+
auto reciprocal_y_norm_square = 1 / y_norm_square;
199+
for (size_t i = 0; i < cols_; ++i) {
200+
dy_[i] += dz * (x[i] * reciprocal_xy_norm_prod -
201+
z * y_[i] * reciprocal_y_norm_square);
202+
}
239203
}
240204
}
241205

@@ -245,6 +209,7 @@ struct CosSimDyFunctor {
245209
const T* y_;
246210
const T* z_;
247211
const T* dz_;
212+
T* dx_;
248213
T* dy_;
249214
const size_t cols_;
250215
};
@@ -287,17 +252,17 @@ class CosSimGradKernel : public framework::OpKernel<T> {
287252
}
288253
} else {
289254
if (out_grad_x) {
290-
CosSimDxFunctor<T> functor(
255+
CosSimDxFunctor<T, true> functor(
291256
in_x_norm->data<T>(), in_y_norm->data<T>(), in_x->data<T>(),
292257
in_y->data<T>(), in_z->data<T>(), in_grad_z->data<T>(),
293-
out_grad_x->mutable_data<T>(context.GetPlace()), cols);
258+
out_grad_x->mutable_data<T>(context.GetPlace()), nullptr, cols);
294259
ForEachZip(in_x_norm->data<T>(), in_x_norm->data<T>() + rows_x,
295260
in_y_norm->data<T>(), functor);
296261
}
297262
if (out_grad_y) {
298-
CosSimDyFunctor<T> functor(
263+
CosSimDxFunctor<T, false> functor(
299264
in_x_norm->data<T>(), in_y_norm->data<T>(), in_x->data<T>(),
300-
in_y->data<T>(), in_z->data<T>(), in_grad_z->data<T>(),
265+
in_y->data<T>(), in_z->data<T>(), in_grad_z->data<T>(), nullptr,
301266
out_grad_y->mutable_data<T>(context.GetPlace()), cols);
302267
ForEachZip(in_x_norm->data<T>(), in_x_norm->data<T>() + rows_x,
303268
in_y_norm->data<T>(), functor);

0 commit comments

Comments
 (0)