@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
1313limitations under the License. */
1414
1515#include " paddle/fluid/operators/fused/fused_seqpool_cvm_op.h"
16-
16+ # include < string >
1717namespace paddle {
1818namespace operators {
1919
@@ -30,12 +30,12 @@ class FusedSeqpoolCVMOp : public framework::OperatorWithKernel {
3030 PADDLE_ENFORCE_EQ (
3131 cvm_dims.size (), 2UL ,
3232 platform::errors::InvalidArgument (" Input(CVM)'s rank should be 2." ));
33- PADDLE_ENFORCE_EQ (
34- cvm_dims[1 ], 2UL ,
35- platform::errors::InvalidArgument (" The 2nd dimension of "
36- " Input(CVM) should be 2." ));
33+ PADDLE_ENFORCE_EQ (cvm_dims[1 ], 2UL , platform::errors::InvalidArgument (
34+ " The 2nd dimension of "
35+ " Input(CVM) should be 2." ));
3736
3837 auto ins_dims = ctx->GetInputsDim (" X" );
38+ const int cvm_offset = ctx->Attrs ().Get <int >(" cvm_offset" );
3939 const size_t num_inputs = ins_dims.size ();
4040 std::vector<framework::DDim> outs_dims;
4141 outs_dims.resize (num_inputs);
@@ -69,7 +69,7 @@ class FusedSeqpoolCVMOp : public framework::OperatorWithKernel {
6969 if (ctx->Attrs ().Get <bool >(" use_cvm" )) {
7070 out_dim = {-1 , dims[rank - 1 ]};
7171 } else {
72- out_dim = {-1 , dims[rank - 1 ] - 2 };
72+ out_dim = {-1 , dims[rank - 1 ] - cvm_offset };
7373 }
7474 outs_dims[i] = framework::make_ddim (out_dim);
7575 }
@@ -111,6 +111,7 @@ class FusedSeqpoolCVMOpMaker : public framework::OpProtoAndCheckerMaker {
111111 AddAttr<float >(" show_coeff" , " (float, default 0.2)" ).SetDefault (0.2 );
112112 AddAttr<float >(" clk_coeff" , " (float, default 1)" ).SetDefault (1 );
113113 AddAttr<float >(" threshold" , " (float, default 0.96)" ).SetDefault (0.96 );
114+ AddAttr<int >(" cvm_offset" , " (int, default 2)" ).SetDefault (2 );
114115
115116 AddComment (R"DOC(
116117Fuse multiple pairs of Sequence Pool and CVM Operator.
@@ -127,6 +128,7 @@ class FusedSeqpoolCVMGradOp : public framework::OperatorWithKernel {
127128 auto og_dims = ctx->GetInputsDim (framework::GradVarName (" Out" ));
128129 auto x_dims = ctx->GetInputsDim (" X" );
129130 auto cvm_dims = ctx->GetInputDim (" CVM" );
131+ const int cvm_offset = ctx->Attrs ().Get <int >(" cvm_offset" );
130132
131133 PADDLE_ENFORCE_EQ (
132134 cvm_dims.size (), 2 ,
@@ -151,7 +153,7 @@ class FusedSeqpoolCVMGradOp : public framework::OperatorWithKernel {
151153 } else {
152154 PADDLE_ENFORCE_EQ (
153155 og_dims[i][og_dims[i].size () - 1 ],
154- x_dims[i][og_dims[i].size () - 1 ] - 2 ,
156+ x_dims[i][og_dims[i].size () - 1 ] - cvm_offset ,
155157 platform::errors::InvalidArgument (
156158 " The dimension mismatch between Input(OUT@GRAD) and "
157159 " Input(X). Received Input(OUT@GRAD): input rank %u, "
0 commit comments