Merged
Changes from 7 commits
22 changes: 22 additions & 0 deletions paddle/operators/cos_sim_op.cc
@@ -149,6 +149,28 @@ class CosSimOpGrad : public framework::OperatorWithKernel {
}
};

template <typename T>
struct CosSimDyFunctor<platform::CPUDeviceContext, T> {
inline void operator()(const platform::CPUDeviceContext& ctx, const T* x_norm,
const T* y_norm, const T* x, const T* y, const T* z,
const T* dz, const size_t rows, const size_t cols,
T* dy) const {
for (size_t offset = 0; offset < rows; ++offset) {
auto xy_norm_prod = x_norm[offset] * y_norm[0];
auto dz_data = dz[offset];
auto z_data = z[offset];
auto* x_data = x + cols * offset;
auto reciprocal_xy_norm_prod = 1 / xy_norm_prod;

auto y_norm_square = y_norm[0] * y_norm[0];
auto reciprocal_y_norm_square = 1 / y_norm_square;
for (size_t i = 0; i < cols; ++i) {
dy[i] += dz_data * (x_data[i] * reciprocal_xy_norm_prod -
z_data * y[i] * reciprocal_y_norm_square);
}
}
}
};
} // namespace operators
} // namespace paddle
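Note: a sketch of the math this CPU specialization implements, assuming (as in the broadcast branch of the grad kernel) that y is a single row shared by all rows of x. With

z_i = \frac{x_i \cdot y}{\lVert x_i \rVert\,\lVert y \rVert},

the shared row receives gradient contributions from every row of x, which is why the inner loop accumulates with dy[i] +=:

dy_j = \sum_i dz_i \left( \frac{x_{ij}}{\lVert x_i \rVert\,\lVert y \rVert} - z_i\,\frac{y_j}{\lVert y \rVert^2} \right)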

45 changes: 45 additions & 0 deletions paddle/operators/cos_sim_op.cu
@@ -14,6 +14,51 @@ limitations under the License. */

#define EIGEN_USE_GPU
#include "paddle/operators/cos_sim_op.h"
#include "paddle/platform/cuda_helper.h"

namespace paddle {
namespace operators {

template <typename T>
__global__ void CosSimDyKernel(const T* x_norm, const T* y_norm, const T* x,
const T* y, const T* z, const T* dz,
const size_t rows, const size_t cols, T* dy) {
int grid_size = blockDim.x * gridDim.x;
T y_norm_data = y_norm[0];
for (int offset = blockIdx.x * blockDim.x + threadIdx.x; offset < rows;
offset += grid_size) {
T xy_norm_prod = x_norm[offset] * y_norm_data;
T dz_data = dz[offset];
T z_data = z[offset];
const T* x_data = x + cols * offset;
T reciprocal_xy_norm_prod = 1 / xy_norm_prod;

T y_norm_square = y_norm_data * y_norm_data;
T reciprocal_y_norm_square = 1 / y_norm_square;
for (size_t i = 0; i < cols; ++i) {
T dy_data = dz_data * (x_data[i] * reciprocal_xy_norm_prod -
z_data * y[i] * reciprocal_y_norm_square);
platform::CudaAtomicAdd(dy + i, dy_data);
}
}
}

template <typename T>
struct CosSimDyFunctor<platform::CUDADeviceContext, T> {
inline void operator()(const platform::CUDADeviceContext& ctx,
const T* x_norm, const T* y_norm, const T* x,
const T* y, const T* z, const T* dz, const size_t rows,
const size_t cols, T* dy) const {
const int block_size = 512;
dim3 threads(block_size, 1);
dim3 grid((rows + block_size - 1) / block_size, 1);
CosSimDyKernel<T><<<grid, threads, 0, ctx.stream()>>>(
x_norm, y_norm, x, y, z, dz, rows, cols, dy);
}
};

} // namespace operators
} // namespace paddle
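Note: two launch details worth calling out. The kernel strides over rows with blockIdx.x and gridDim.x, so the row blocks are laid out along the x-dimension of the grid. And because every row atomically accumulates into the same dy buffer with CudaAtomicAdd, dy must be zeroed before the kernel runs; CosSimGradKernel does this with math::SetConstant (see cos_sim_op.h below). A minimal equivalent sketch, assuming dy is a device buffer of cols floats or doubles (ZeroDy is a hypothetical helper, not part of this PR):

#include <cuda_runtime.h>

// Zero-fill dy on the kernel's stream. All-zero bytes represent 0 for
// IEEE float/double, so a byte-wise memset is a valid zero-fill here.
template <typename T>
void ZeroDy(T* dy, size_t cols, cudaStream_t stream) {
  cudaMemsetAsync(dy, 0, cols * sizeof(T), stream);
}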

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
258 changes: 191 additions & 67 deletions paddle/operators/cos_sim_op.h
@@ -13,19 +13,67 @@ See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/platform/for_range.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename T, bool same_row>
struct CosSimFunctor {
CosSimFunctor(const T* x, const T* y, T* x_norm, T* y_norm, T* z, int cols)
: x_norm_(x_norm),
y_norm_(y_norm),
x_(x),
y_(y),
z_(z),
cols_(static_cast<size_t>(cols)) {}

inline HOSTDEVICE void operator()(size_t offset) const {
Contributor: Changing `offset` to `i` (and the inner loop variable to `j`) may be clearer? Or `row_id` and `col_id`.

Contributor Author: Done. I have replaced `offset` with `row_id`.

auto* x = x_ + cols_ * offset;
T xx = 0, xy = 0, yy = 0;
if (same_row) {
auto* y = y_ + cols_ * offset;
T tmp_x, tmp_y;
for (size_t i = 0; i < cols_; ++i) {
tmp_x = x[i];
tmp_y = y[i];
xx += tmp_x * tmp_x;
yy += tmp_y * tmp_y;
xy += tmp_x * tmp_y;
}
xx = sqrt(xx);
yy = sqrt(yy);
y_norm_[offset] = yy;
x_norm_[offset] = xx;
z_[offset] = xy / (xx * yy);
} else {  // This could be written in a better way.
T tmp_x, tmp_y;
for (size_t i = 0; i < cols_; ++i) {
tmp_x = x[i];
tmp_y = y_[i];
xx += tmp_x * tmp_x;
yy += tmp_y * tmp_y;
xy += tmp_x * tmp_y;
}
xx = sqrt(xx);
yy = sqrt(yy);
if (offset == 0) y_norm_[0] = yy;
x_norm_[offset] = xx;
z_[offset] = xy / (xx * yy);
}
}

T* x_norm_;
T* y_norm_;
const T* x_;
const T* y_;
T* z_;
const size_t cols_;
};
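Note: for reference, the value this functor computes per row is plain cosine similarity; when same_row is false, y_ holds a single row that is broadcast against every row of x, and only one y norm is written:

z_i = \frac{x_i \cdot y_i}{\lVert x_i \rVert\,\lVert y_i \rVert} \quad (\text{same\_row} = \text{true}), \qquad z_i = \frac{x_i \cdot y}{\lVert x_i \rVert\,\lVert y \rVert} \quad (\text{same\_row} = \text{false})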

template <typename DeviceContext, typename T>
class CosSimKernel : public framework::OpKernel<T> {
@@ -41,30 +89,114 @@ class CosSimKernel : public framework::OpKernel<T> {
out_x_norm->mutable_data<T>(context.GetPlace());
out_y_norm->mutable_data<T>(context.GetPlace());

int rows_x = in_x->dims()[0];
int rows_y = in_y->dims()[0];

int cols = framework::product(in_x->dims()) / rows_x;

if (rows_x == rows_y) {
CosSimFunctor<T, true> functor(
in_x->data<T>(), in_y->data<T>(), out_x_norm->data<T>(),
out_y_norm->data<T>(), out_z->data<T>(), cols);
platform::ForRange<DeviceContext> for_range(
static_cast<const DeviceContext&>(context.device_context()), rows_x);
for_range(functor);
} else {
CosSimFunctor<T, false> functor(
in_x->data<T>(), in_y->data<T>(), out_x_norm->data<T>(),
out_y_norm->data<T>(), out_z->data<T>(), cols);
platform::ForRange<DeviceContext> for_range(
static_cast<const DeviceContext&>(context.device_context()), rows_x);
for_range(functor);
}
}
};

template <typename T>
struct CosSimGradFunctor {
CosSimGradFunctor(const T* x_norm, const T* y_norm, const T* x, const T* y,
const T* z, const T* dz, T* dx, int cols)
: x_norm_(x_norm),
y_norm_(y_norm),
x_(x),
y_(y),
z_(z),
dz_(dz),
dx_(dx),
cols_(static_cast<size_t>(cols)) {}

inline HOSTDEVICE void operator()(size_t offset) const {
auto x_norm_square = x_norm_[offset] * x_norm_[offset];
auto xy_norm_prod = x_norm_[offset] * y_norm_[offset];
auto dz = dz_[offset];
auto z = z_[offset];

auto* dx = dx_ + cols_ * offset;
auto* x = x_ + cols_ * offset;
auto* y = y_ + cols_ * offset;

auto reciprocal_xy_norm_prod = 1 / xy_norm_prod;
auto reciprocal_x_norm_square = 1 / x_norm_square;
for (size_t i = 0; i < cols_; ++i) {
dx[i] = dz * (y[i] * reciprocal_xy_norm_prod -
z * x[i] * reciprocal_x_norm_square);
}
}

const T* x_norm_;
const T* y_norm_;
const T* x_;
const T* y_;
const T* z_;
const T* dz_;
T* dx_;
const size_t cols_;
};
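Note: this functor implements the x-gradient of cosine similarity for the equal-rows case; because the formula is symmetric in x and y, CosSimGradKernel below reuses it with the x- and y-arguments swapped to produce dy:

dx_{ij} = dz_i \left( \frac{y_{ij}}{\lVert x_i \rVert\,\lVert y_i \rVert} - z_i\,\frac{x_{ij}}{\lVert x_i \rVert^2} \right)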

template <typename T>
struct CosSimDxFunctor {
CosSimDxFunctor(const T* x_norm, const T* y_norm, const T* x, const T* y,
const T* z, const T* dz, T* dx, int cols)
: x_norm_(x_norm),
y_norm_(y_norm),
x_(x),
y_(y),
z_(z),
dz_(dz),
dx_(dx),
cols_(static_cast<size_t>(cols)) {}

inline HOSTDEVICE void operator()(size_t offset) const {
auto xy_norm_prod = x_norm_[offset] * y_norm_[0];
auto dz = dz_[offset];
auto z = z_[offset];
auto* x = x_ + cols_ * offset;
auto reciprocal_xy_norm_prod = 1 / xy_norm_prod;
auto x_norm_square = x_norm_[offset] * x_norm_[offset];
auto* dx = dx_ + cols_ * offset;
auto reciprocal_x_norm_square = 1 / x_norm_square;

for (size_t i = 0; i < cols_; ++i) {
dx[i] = dz * (y_[i] * reciprocal_xy_norm_prod -
z * x[i] * reciprocal_x_norm_square);
}
}
const T* x_norm_;
const T* y_norm_;
const T* x_;
const T* y_;
const T* z_;
const T* dz_;
T* dx_;
const size_t cols_;
};

template <typename DeviceContext, typename T>
struct CosSimDyFunctor {
inline void operator()(const DeviceContext& ctx, const T* x_norm,
const T* y_norm, const T* x, const T* y, const T* z,
const T* dz, const size_t rows, const size_t cols,
T* dy) const;
};
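Note on the design choice: dx is elementwise per row, so it fits the per-row HOSTDEVICE functor pattern used above with platform::ForRange. The broadcast-case dy is instead a reduction of every row into one shared row, so the primary template is only declared here and each device supplies its own specialization: a serial accumulating loop for platform::CPUDeviceContext (cos_sim_op.cc) and a grid-stride kernel using CudaAtomicAdd for platform::CUDADeviceContext (cos_sim_op.cu).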

template <typename DeviceContext, typename T>
@@ -81,62 +213,54 @@ class CosSimGradKernel : public framework::OpKernel<T> {
auto* out_grad_y = context.Output<Tensor>(framework::GradVarName("Y"));
auto* in_grad_z = context.Input<Tensor>(framework::GradVarName("Out"));

int rows_x = in_x->dims()[0];
int rows_y = in_y->dims()[0];
int cols = framework::product(in_x->dims()) / rows_x;

if (rows_x == rows_y) {
// compute dx
if (out_grad_x) {
CosSimGradFunctor<T> functor(
in_x_norm->data<T>(), in_y_norm->data<T>(), in_x->data<T>(),
in_y->data<T>(), in_z->data<T>(), in_grad_z->data<T>(),
out_grad_x->mutable_data<T>(context.GetPlace()), cols);
platform::ForRange<DeviceContext> for_range(
static_cast<const DeviceContext&>(context.device_context()),
rows_x);
for_range(functor);
}
// compute dy
if (out_grad_y) {
CosSimGradFunctor<T> functor(
in_y_norm->data<T>(), in_x_norm->data<T>(), in_y->data<T>(),
in_x->data<T>(), in_z->data<T>(), in_grad_z->data<T>(),
out_grad_y->mutable_data<T>(context.GetPlace()), cols);
platform::ForRange<DeviceContext> for_range(
static_cast<const DeviceContext&>(context.device_context()),
rows_x);
for_range(functor);
}
} else {
// compute dx
if (out_grad_x) {
CosSimDxFunctor<T> functor(
in_x_norm->data<T>(), in_y_norm->data<T>(), in_x->data<T>(),
in_y->data<T>(), in_z->data<T>(), in_grad_z->data<T>(),
out_grad_x->mutable_data<T>(context.GetPlace()), cols);
platform::ForRange<DeviceContext> for_range(
static_cast<const DeviceContext&>(context.device_context()),
rows_x);
for_range(functor);
}
// compute dy
if (out_grad_y) {
out_grad_y->mutable_data<T>(context.GetPlace());
math::SetConstant<DeviceContext, T> set_zero;
auto& dev_ctx = context.template device_context<DeviceContext>();
set_zero(dev_ctx, out_grad_y, static_cast<T>(0));

CosSimDyFunctor<DeviceContext, T> functor;
functor(dev_ctx, in_x_norm->data<T>(), in_y_norm->data<T>(),
in_x->data<T>(), in_y->data<T>(), in_z->data<T>(),
in_grad_z->data<T>(), static_cast<size_t>(rows_x),
static_cast<size_t>(cols), out_grad_y->data<T>());
}
}
}