@@ -28,17 +28,21 @@ class CosSimKernel : public framework::OpKernel<T> {
2828 public:
2929 void Compute (const framework::ExecutionContext& context) const override {
3030 // get Tensor
31- auto * in_x = context.Input <Tensor >(" X" );
31+ auto * in_x = context.Input <framework::LoDTensor >(" X" );
3232 auto * in_y = context.Input <Tensor>(" Y" );
33- auto * out_z = context.Output <Tensor >(" Out" );
33+ auto * out_z = context.Output <framework::LoDTensor >(" Out" );
3434 auto * out_x_norm = context.Output <Tensor>(" XNorm" );
3535 auto * out_y_norm = context.Output <Tensor>(" YNorm" );
36- out_z->mutable_data <T>(context.GetPlace ());
37- out_x_norm->mutable_data <T>(context.GetPlace ());
38- out_y_norm->mutable_data <T>(context.GetPlace ());
3936
4037 int rows_x = in_x->dims ()[0 ];
4138 int rows_y = in_y->dims ()[0 ];
39+ out_z->Resize ({rows_x, 1 });
40+ out_x_norm->Resize ({rows_x, 1 });
41+ out_y_norm->Resize ({rows_y, 1 });
42+ out_z->mutable_data <T>(context.GetPlace ());
43+ out_x_norm->mutable_data <T>(context.GetPlace ());
44+ out_y_norm->mutable_data <T>(context.GetPlace ());
45+ out_z->set_lod (in_x->lod ());
4246
4347 int cols = framework::product (in_x->dims ()) / rows_x;
4448
@@ -81,6 +85,7 @@ class CosSimGradKernel : public framework::OpKernel<T> {
8185
8286 if (rows_x == rows_y) {
8387 if (out_grad_x) {
88+ out_grad_x->Resize (in_x->dims ());
8489 math::CosSimGradFunctor<T> functor (
8590 in_x_norm->data <T>(), in_y_norm->data <T>(), in_x->data <T>(),
8691 in_y->data <T>(), in_z->data <T>(), in_grad_z->data <T>(),
@@ -91,6 +96,7 @@ class CosSimGradKernel : public framework::OpKernel<T> {
9196 for_range (functor);
9297 }
9398 if (out_grad_y) {
99+ out_grad_y->Resize (in_y->dims ());
94100 math::CosSimGradFunctor<T> functor (
95101 in_y_norm->data <T>(), in_x_norm->data <T>(), in_y->data <T>(),
96102 in_x->data <T>(), in_z->data <T>(), in_grad_z->data <T>(),
@@ -102,6 +108,7 @@ class CosSimGradKernel : public framework::OpKernel<T> {
102108 }
103109 } else {
104110 if (out_grad_x) {
111+ out_grad_x->Resize (in_x->dims ());
105112 math::CosSimDxFunctor<T> functor (
106113 in_x_norm->data <T>(), in_y_norm->data <T>(), in_x->data <T>(),
107114 in_y->data <T>(), in_z->data <T>(), in_grad_z->data <T>(),
@@ -112,6 +119,7 @@ class CosSimGradKernel : public framework::OpKernel<T> {
112119 for_range (functor);
113120 }
114121 if (out_grad_y) {
122+ out_grad_y->Resize (in_y->dims ());
115123 out_grad_y->mutable_data <T>(context.GetPlace ());
116124 math::SetConstant<DeviceContext, T> set_zero;
117125 auto & dev_ctx = context.template device_context <DeviceContext>();
0 commit comments