Skip to content

Commit f58fe6d

Browse files
author
chengduo
authored
Merge pull request #6601 from chengduoZH/profiling/cosine_op
Refine cos-sim-op
2 parents 0bd7f97 + 812c5f6 commit f58fe6d

6 files changed

Lines changed: 332 additions & 66 deletions

File tree

paddle/operators/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,7 @@ op_library(lstm_op DEPS sequence2batch lstm_compute)
229229
op_library(conv_transpose_op DEPS vol2col)
230230
op_library(gru_op DEPS sequence2batch gru_compute)
231231
op_library(recurrent_op DEPS executor)
232+
op_library(cos_sim_op DEPS cos_sim_functor)
232233
# FIXME(typhoonzero): save/load depends lodtensor serialization functions
233234
op_library(save_op DEPS lod_tensor)
234235
op_library(load_op DEPS lod_tensor)

paddle/operators/cos_sim_op.h

Lines changed: 51 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,15 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#pragma once
16-
#include "paddle/framework/eigen.h"
1716
#include "paddle/framework/op_registry.h"
17+
#include "paddle/operators/math/cos_sim_functor.h"
18+
#include "paddle/operators/math/math_function.h"
19+
#include "paddle/platform/for_range.h"
1820

1921
namespace paddle {
2022
namespace operators {
2123

2224
using Tensor = framework::Tensor;
23-
template <typename T, int MajorType = Eigen::RowMajor,
24-
typename IndexType = Eigen::DenseIndex>
25-
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
26-
template <typename T, int MajorType = Eigen::RowMajor,
27-
typename IndexType = Eigen::DenseIndex>
28-
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
2925

3026
template <typename DeviceContext, typename T>
3127
class CosSimKernel : public framework::OpKernel<T> {
@@ -41,28 +37,25 @@ class CosSimKernel : public framework::OpKernel<T> {
4137
out_x_norm->mutable_data<T>(context.GetPlace());
4238
out_y_norm->mutable_data<T>(context.GetPlace());
4339

44-
// convert Tensor to Eigen Tensor
4540
int rows_x = in_x->dims()[0];
4641
int rows_y = in_y->dims()[0];
47-
auto x = EigenMatrix<T>::Reshape(*in_x, 1);
48-
auto y = EigenMatrix<T>::Reshape(*in_y, 1);
49-
auto z = EigenVector<T>::Flatten(*out_z);
50-
auto x_norm = EigenVector<T>::Flatten(*out_x_norm);
51-
auto y_norm = EigenVector<T>::Flatten(*out_y_norm);
5242

53-
// compute
54-
auto& place =
55-
*context.template device_context<DeviceContext>().eigen_device();
56-
auto row_along = Eigen::array<int, 1>({{1}});
57-
x_norm.device(place) = x.square().sum(row_along).sqrt();
58-
y_norm.device(place) = y.square().sum(row_along).sqrt();
43+
int cols = framework::product(in_x->dims()) / rows_x;
44+
5945
if (rows_x == rows_y) {
60-
auto xy = (x * y).sum(Eigen::array<int, 1>({{1}}));
61-
z.device(place) = xy / x_norm / y_norm;
46+
math::CosSimFunctor<T, true> functor(
47+
in_x->data<T>(), in_y->data<T>(), out_x_norm->data<T>(),
48+
out_y_norm->data<T>(), out_z->data<T>(), cols);
49+
platform::ForRange<DeviceContext> for_range(
50+
static_cast<const DeviceContext&>(context.device_context()), rows_x);
51+
for_range(functor);
6252
} else {
63-
Eigen::DSizes<int, 2> bcast(rows_x, 1);
64-
auto xy = (x * y.broadcast(bcast)).sum(row_along);
65-
z.device(place) = xy / x_norm / y_norm.broadcast(bcast);
53+
math::CosSimFunctor<T, false> functor(
54+
in_x->data<T>(), in_y->data<T>(), out_x_norm->data<T>(),
55+
out_y_norm->data<T>(), out_z->data<T>(), cols);
56+
platform::ForRange<DeviceContext> for_range(
57+
static_cast<const DeviceContext&>(context.device_context()), rows_x);
58+
for_range(functor);
6659
}
6760
}
6861
};
@@ -81,62 +74,54 @@ class CosSimGradKernel : public framework::OpKernel<T> {
8174
auto* out_grad_y = context.Output<Tensor>(framework::GradVarName("Y"));
8275
auto* in_grad_z = context.Input<Tensor>(framework::GradVarName("Out"));
8376

84-
// convert Tensor to Eigen Tensor
85-
auto x = EigenMatrix<T>::Reshape(*in_x, 1);
86-
auto y = EigenMatrix<T>::Reshape(*in_y, 1);
87-
auto z = EigenMatrix<T>::Reshape(*in_z, 1);
88-
auto x_norm = EigenMatrix<T>::Reshape(*in_x_norm, 1);
89-
auto y_norm = EigenMatrix<T>::Reshape(*in_y_norm, 1);
90-
auto dz = EigenMatrix<T>::Reshape(*in_grad_z, 1);
91-
9277
// compute gradident
9378
int rows_x = in_x->dims()[0];
9479
int rows_y = in_y->dims()[0];
9580
int cols = framework::product(in_x->dims()) / rows_x;
96-
Eigen::DSizes<int, 2> bcast_cols(1, cols);
97-
auto z_bcast = z.broadcast(bcast_cols);
98-
auto dz_bcast = dz.broadcast(bcast_cols);
99-
auto x_snorm_bcast = x_norm.square().eval().broadcast(bcast_cols);
100-
auto& place =
101-
*context.template device_context<DeviceContext>().eigen_device();
81+
10282
if (rows_x == rows_y) {
103-
auto y_snorm_bcast = y_norm.square().eval().broadcast(bcast_cols);
104-
auto norm_prod_bcast = (x_norm * y_norm).eval().broadcast(bcast_cols);
105-
// compute dx
10683
if (out_grad_x) {
107-
out_grad_x->mutable_data<T>(context.GetPlace());
108-
auto dx = EigenMatrix<T>::Reshape(*out_grad_x, 1);
109-
auto grad = y / norm_prod_bcast - z_bcast * x / x_snorm_bcast;
110-
dx.device(place) = dz_bcast * grad;
84+
math::CosSimGradFunctor<T> functor(
85+
in_x_norm->data<T>(), in_y_norm->data<T>(), in_x->data<T>(),
86+
in_y->data<T>(), in_z->data<T>(), in_grad_z->data<T>(),
87+
out_grad_x->mutable_data<T>(context.GetPlace()), cols);
88+
platform::ForRange<DeviceContext> for_range(
89+
static_cast<const DeviceContext&>(context.device_context()),
90+
rows_x);
91+
for_range(functor);
11192
}
112-
// compute dy
11393
if (out_grad_y) {
114-
out_grad_y->mutable_data<T>(context.GetPlace());
115-
auto dy = EigenMatrix<T>::Reshape(*out_grad_y, 1);
116-
auto grad = x / norm_prod_bcast - z_bcast * y / y_snorm_bcast;
117-
dy.device(place) = dz_bcast * grad;
94+
math::CosSimGradFunctor<T> functor(
95+
in_y_norm->data<T>(), in_x_norm->data<T>(), in_y->data<T>(),
96+
in_x->data<T>(), in_z->data<T>(), in_grad_z->data<T>(),
97+
out_grad_y->mutable_data<T>(context.GetPlace()), cols);
98+
platform::ForRange<DeviceContext> for_range(
99+
static_cast<const DeviceContext&>(context.device_context()),
100+
rows_x);
101+
for_range(functor);
118102
}
119103
} else {
120-
Eigen::DSizes<int, 2> bcast_rows(rows_x, 1);
121-
Eigen::DSizes<int, 2> bcast_rows_cols(rows_x, cols);
122-
auto y_bcast = y.broadcast(bcast_rows);
123-
auto y_snorm_bcast = y_norm.square().eval().broadcast(bcast_rows_cols);
124-
auto norm_prod_bcast = (x_norm * y_norm.eval().broadcast(bcast_rows))
125-
.eval()
126-
.broadcast(bcast_cols);
127-
// compute dx
128104
if (out_grad_x) {
129-
out_grad_x->mutable_data<T>(context.GetPlace());
130-
auto dx = EigenMatrix<T>::Reshape(*out_grad_x, 1);
131-
auto grad = y_bcast / norm_prod_bcast - z_bcast * x / x_snorm_bcast;
132-
dx.device(place) = dz_bcast * grad;
105+
math::CosSimDxFunctor<T> functor(
106+
in_x_norm->data<T>(), in_y_norm->data<T>(), in_x->data<T>(),
107+
in_y->data<T>(), in_z->data<T>(), in_grad_z->data<T>(),
108+
out_grad_x->mutable_data<T>(context.GetPlace()), cols);
109+
platform::ForRange<DeviceContext> for_range(
110+
static_cast<const DeviceContext&>(context.device_context()),
111+
rows_x);
112+
for_range(functor);
133113
}
134-
// compute dy
135114
if (out_grad_y) {
136115
out_grad_y->mutable_data<T>(context.GetPlace());
137-
auto dy = EigenVector<T>::Flatten(*out_grad_y);
138-
auto grad = x / norm_prod_bcast - z_bcast * y_bcast / y_snorm_bcast;
139-
dy.device(place) = (dz_bcast * grad).sum(Eigen::array<int, 1>({{0}}));
116+
math::SetConstant<DeviceContext, T> set_zero;
117+
auto& dev_ctx = context.template device_context<DeviceContext>();
118+
set_zero(dev_ctx, out_grad_y, static_cast<T>(0));
119+
120+
math::CosSimDyFunctor<DeviceContext, T> functor;
121+
functor(dev_ctx, in_x_norm->data<T>(), in_y_norm->data<T>(),
122+
in_x->data<T>(), in_y->data<T>(), in_z->data<T>(),
123+
in_grad_z->data<T>(), static_cast<size_t>(rows_x),
124+
static_cast<size_t>(cols), out_grad_y->data<T>());
140125
}
141126
}
142127
}

paddle/operators/math/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ if(WITH_GPU)
1616
nv_library(maxouting SRCS maxouting.cc maxouting.cu DEPS device_context)
1717
nv_library(unpooling SRCS unpooling.cc unpooling.cu DEPS device_context)
1818
nv_library(gru_compute SRCS gru_compute.cc gru_compute.cu DEPS device_context activation_functions math_function)
19+
nv_library(cos_sim_functor SRCS cos_sim_functor.cc cos_sim_functor.cu DEPS device_context)
1920
else()
2021
cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context framework_proto)
2122
cc_library(selected_rows_functor SRCS selected_rows_functor.cc DEPS selected_rows math_function)
@@ -30,6 +31,7 @@ else()
3031
cc_library(maxouting SRCS maxouting.cc DEPS device_context)
3132
cc_library(unpooling SRCS unpooling.cc DEPS device_context)
3233
cc_library(gru_compute SRCS gru_compute.cc DEPS device_context activation_functions math_function)
34+
cc_library(cos_sim_functor SRCS cos_sim_functor.cc DEPS device_context)
3335
endif()
3436

3537
cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
paddle/operators/math/cos_sim_functor.cc

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/operators/math/cos_sim_functor.h"
16+
17+
namespace paddle {
18+
namespace operators {
19+
namespace math {
20+
21+
template <typename T>
22+
struct CosSimDyFunctor<platform::CPUDeviceContext, T> {
23+
void operator()(const platform::CPUDeviceContext& ctx, const T* x_norm,
24+
const T* y_norm, const T* x, const T* y, const T* z,
25+
const T* dz, const size_t rows, const size_t cols,
26+
T* dy) const {
27+
for (size_t row_id = 0; row_id < rows; ++row_id) {
28+
auto xy_norm_prod = x_norm[row_id] * y_norm[0];
29+
auto dz_data = dz[row_id];
30+
auto z_data = z[row_id];
31+
auto* x_data = x + cols * row_id;
32+
auto reciprocal_xy_norm_prod = 1 / xy_norm_prod;
33+
34+
auto y_norm_square = y_norm[0] * y_norm[0];
35+
auto reciprocal_y_norm_square = 1 / y_norm_square;
36+
for (size_t i = 0; i < cols; ++i) {
37+
dy[i] += dz_data * (x_data[i] * reciprocal_xy_norm_prod -
38+
z_data * y[i] * reciprocal_y_norm_square);
39+
}
40+
}
41+
}
42+
};
43+
44+
template class CosSimDyFunctor<platform::CPUDeviceContext, float>;
45+
template class CosSimDyFunctor<platform::CPUDeviceContext, double>;
46+
} // namespace math
47+
} // namespace operators
48+
} // namespace paddle
paddle/operators/math/cos_sim_functor.cu

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/operators/math/cos_sim_functor.h"
16+
#include "paddle/platform/cuda_helper.h"
17+
18+
namespace paddle {
19+
namespace operators {
20+
namespace math {
21+
22+
// Grid-stride kernel for the cos-sim dy-gradient in the broadcast case
// (single y row vs. every row of x).  Each thread walks whole rows of x;
// contributions from different rows target the same dy[i], so they are
// combined with atomic adds.  The caller must zero dy before launch.
// Expected launch: 1-D grid of 1-D blocks (blocks in grid.x only).
template <typename T>
__global__ void CosSimDyKernel(const T* x_norm, const T* y_norm, const T* x,
                               const T* y, const T* z, const T* dz,
                               const size_t rows, const size_t cols, T* dy) {
  // size_t indices: avoids the signed/unsigned comparison against `rows`
  // and the classic overflow of blockIdx.x * blockDim.x in int.
  const size_t grid_size = static_cast<size_t>(blockDim.x) * gridDim.x;
  const T y_norm_data = y_norm[0];
  // Every row shares the same y, hence the same 1/|y|^2: hoist it.
  const T reciprocal_y_norm_square = 1 / (y_norm_data * y_norm_data);
  for (size_t row_id =
           static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       row_id < rows; row_id += grid_size) {
    const T dz_data = dz[row_id];
    const T z_data = z[row_id];
    const T* x_data = x + cols * row_id;
    const T reciprocal_xy_norm_prod = 1 / (x_norm[row_id] * y_norm_data);
    for (size_t i = 0; i < cols; ++i) {
      // d z / d y_i = x_i / (|x||y|) - z * y_i / |y|^2, scaled by dz.
      T dy_data = dz_data * (x_data[i] * reciprocal_xy_norm_prod -
                             z_data * y[i] * reciprocal_y_norm_square);
      platform::CudaAtomicAdd(dy + i, dy_data);
    }
  }
}
45+
46+
// CUDA specialization: launches CosSimDyKernel on ctx's stream.
// dy (1 x cols) must already be zeroed by the caller; the kernel
// accumulates into it with atomic adds.
template <typename T>
struct CosSimDyFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& ctx, const T* x_norm,
                  const T* y_norm, const T* x, const T* y, const T* z,
                  const T* dz, const size_t rows, const size_t cols,
                  T* dy) const {
    const int block_size = 512;
    dim3 threads(block_size, 1);
    // Blocks must go in grid.x: the kernel strides with gridDim.x and
    // indexes with blockIdx.x.  Putting them in grid.y (grid(1, nblocks))
    // launches nblocks blocks that all see blockIdx.x == 0 and each
    // atomically re-add the same contributions, scaling dy by nblocks.
    dim3 grid((rows + block_size - 1) / block_size, 1);
    CosSimDyKernel<T><<<grid, threads, 0, ctx.stream()>>>(
        x_norm, y_norm, x, y, z, dz, rows, cols, dy);
  }
};

template class CosSimDyFunctor<platform::CUDADeviceContext, float>;
template class CosSimDyFunctor<platform::CUDADeviceContext, double>;
62+
} // namespace math
63+
} // namespace operators
64+
} // namespace paddle

0 commit comments

Comments
 (0)