Skip to content

Commit 9195c3b

Browse files
authored
Merge pull request #16280 from luotao1/cos_sim_infershape
refine cos_sim infershape
2 parents 6382b62 + c05af91 commit 9195c3b

File tree

2 files changed

+16
-5
lines changed

2 files changed

+16
-5
lines changed

paddle/fluid/operators/cos_sim_op.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,9 @@ class CosSimOpMaker : public framework::OpProtoAndCheckerMaker {
7474
"Norm of the second input, reduced along the 1st "
7575
"dimension.")
7676
.AsIntermediate();
77+
AddAttr<bool>(framework::kAllKernelsMustComputeRuntimeShape,
78+
"Skip calling InferShape() function in the runtime.")
79+
.SetDefault(true);
7780

7881
AddComment(R"DOC(
7982
**Cosine Similarity Operator**

paddle/fluid/operators/cos_sim_op.h

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,17 +28,21 @@ class CosSimKernel : public framework::OpKernel<T> {
2828
public:
2929
void Compute(const framework::ExecutionContext& context) const override {
3030
// get Tensor
31-
auto* in_x = context.Input<Tensor>("X");
31+
auto* in_x = context.Input<framework::LoDTensor>("X");
3232
auto* in_y = context.Input<Tensor>("Y");
33-
auto* out_z = context.Output<Tensor>("Out");
33+
auto* out_z = context.Output<framework::LoDTensor>("Out");
3434
auto* out_x_norm = context.Output<Tensor>("XNorm");
3535
auto* out_y_norm = context.Output<Tensor>("YNorm");
36-
out_z->mutable_data<T>(context.GetPlace());
37-
out_x_norm->mutable_data<T>(context.GetPlace());
38-
out_y_norm->mutable_data<T>(context.GetPlace());
3936

4037
int rows_x = in_x->dims()[0];
4138
int rows_y = in_y->dims()[0];
39+
out_z->Resize({rows_x, 1});
40+
out_x_norm->Resize({rows_x, 1});
41+
out_y_norm->Resize({rows_y, 1});
42+
out_z->mutable_data<T>(context.GetPlace());
43+
out_x_norm->mutable_data<T>(context.GetPlace());
44+
out_y_norm->mutable_data<T>(context.GetPlace());
45+
out_z->set_lod(in_x->lod());
4246

4347
int cols = framework::product(in_x->dims()) / rows_x;
4448

@@ -81,6 +85,7 @@ class CosSimGradKernel : public framework::OpKernel<T> {
8185

8286
if (rows_x == rows_y) {
8387
if (out_grad_x) {
88+
out_grad_x->Resize(in_x->dims());
8489
math::CosSimGradFunctor<T> functor(
8590
in_x_norm->data<T>(), in_y_norm->data<T>(), in_x->data<T>(),
8691
in_y->data<T>(), in_z->data<T>(), in_grad_z->data<T>(),
@@ -91,6 +96,7 @@ class CosSimGradKernel : public framework::OpKernel<T> {
9196
for_range(functor);
9297
}
9398
if (out_grad_y) {
99+
out_grad_y->Resize(in_y->dims());
94100
math::CosSimGradFunctor<T> functor(
95101
in_y_norm->data<T>(), in_x_norm->data<T>(), in_y->data<T>(),
96102
in_x->data<T>(), in_z->data<T>(), in_grad_z->data<T>(),
@@ -102,6 +108,7 @@ class CosSimGradKernel : public framework::OpKernel<T> {
102108
}
103109
} else {
104110
if (out_grad_x) {
111+
out_grad_x->Resize(in_x->dims());
105112
math::CosSimDxFunctor<T> functor(
106113
in_x_norm->data<T>(), in_y_norm->data<T>(), in_x->data<T>(),
107114
in_y->data<T>(), in_z->data<T>(), in_grad_z->data<T>(),
@@ -112,6 +119,7 @@ class CosSimGradKernel : public framework::OpKernel<T> {
112119
for_range(functor);
113120
}
114121
if (out_grad_y) {
122+
out_grad_y->Resize(in_y->dims());
115123
out_grad_y->mutable_data<T>(context.GetPlace());
116124
math::SetConstant<DeviceContext, T> set_zero;
117125
auto& dev_ctx = context.template device_context<DeviceContext>();

0 commit comments

Comments
 (0)