@@ -33,13 +33,15 @@ class GRUOp : public framework::OperatorWithKernel {
3333 void InferShape (framework::InferShapeContext* ctx) const override {
3434 OP_INOUT_CHECK (ctx->HasInput (" Input" ), " Input" , " Input" , " GRU" );
3535 OP_INOUT_CHECK (ctx->HasInput (" Weight" ), " Input" , " Weight" , " GRU" );
36- OP_INOUT_CHECK (ctx->HasOutput (" BatchGate" ), " Output" , " BatchGate" , " GRU" );
37- OP_INOUT_CHECK (ctx->HasOutput (" BatchResetHiddenPrev" ), " Output" ,
38- " BatchResetHiddenPrev" , " GRU" );
39- OP_INOUT_CHECK (ctx->HasOutput (" BatchHidden" ), " Output" , " BatchHidden" ,
40- " GRU" );
4136 OP_INOUT_CHECK (ctx->HasOutput (" Hidden" ), " Output" , " Hidden" , " GRU" );
42-
37+ bool is_test = ctx->Attrs ().Get <bool >(" is_test" );
38+ if (!is_test) {
39+ OP_INOUT_CHECK (ctx->HasOutput (" BatchGate" ), " Output" , " BatchGate" , " GRU" );
40+ OP_INOUT_CHECK (ctx->HasOutput (" BatchResetHiddenPrev" ), " Output" ,
41+ " BatchResetHiddenPrev" , " GRU" );
42+ OP_INOUT_CHECK (ctx->HasOutput (" BatchHidden" ), " Output" , " BatchHidden" ,
43+ " GRU" );
44+ }
4345 auto input_dims = ctx->GetInputDim (" Input" );
4446 auto weight_dims = ctx->GetInputDim (" Weight" );
4547 int input_size = input_dims[1 ];
@@ -84,9 +86,11 @@ class GRUOp : public framework::OperatorWithKernel {
8486 " [%d, %d] (Bias) vs [1, %d] (frame_size * 3)." ,
8587 bias_height, bias_width, frame_size * 3 ));
8688 }
87- ctx->SetOutputDim (" BatchGate" , input_dims);
88- ctx->SetOutputDim (" BatchResetHiddenPrev" , {input_dims[0 ], frame_size});
89- ctx->SetOutputDim (" BatchHidden" , {input_dims[0 ], frame_size});
89+ if (!is_test) {
90+ ctx->SetOutputDim (" BatchGate" , input_dims);
91+ ctx->SetOutputDim (" BatchResetHiddenPrev" , {input_dims[0 ], frame_size});
92+ ctx->SetOutputDim (" BatchHidden" , {input_dims[0 ], frame_size});
93+ }
9094 ctx->SetOutputDim (" Hidden" , {input_dims[0 ], frame_size});
9195 ctx->ShareLoD (" Input" , " Hidden" );
9296 }
@@ -124,19 +128,22 @@ class GRUOpMaker : public framework::OpProtoAndCheckerMaker {
124128 " organized in batches. The LoD size is 2. The first LoD contains "
125129 " the batch offsets and the second LoD contains the indexes in "
126130 " the raw sequence data." )
127- .AsIntermediate ();
131+ .AsIntermediate ()
132+ .AsExtra ();
128133 AddOutput (
129134 " BatchResetHiddenPrev" ,
130135 " (LoDTensor) The reset hidden state LoDTensor organized in batches. "
131136 " This LoDTensor is a matrix with shape (T X D) and has the same LoD "
132137 " with `BatchGate`." )
133- .AsIntermediate ();
138+ .AsIntermediate ()
139+ .AsExtra ();
134140 AddOutput (
135141 " BatchHidden" ,
136142 " (LoDTensor) The hidden state LoDTensor organized in batches. "
137143 " This LoDTensor is a matrix with shape (T X D) and has the same LoD "
138144 " with `BatchGate`." )
139- .AsIntermediate ();
145+ .AsIntermediate ()
146+ .AsExtra ();
140147 AddOutput (
141148 " Hidden" ,
142149 " (LoDTensor) the hidden state LoDTensor organized in sequences. "
@@ -155,6 +162,9 @@ class GRUOpMaker : public framework::OpProtoAndCheckerMaker {
155162 " (bool, default: False) "
156163 " whether to compute reversed GRU." )
157164 .SetDefault (false );
165+ AddAttr<bool >(" is_test" , " True if in test phase." )
166+ .SetDefault (false )
167+ .AsExtra ();
158168 AddAttr<bool >(" origin_mode" ,
159169 " bool"
160170 " use origin mode in article https://arxiv.org/abs/1412.3555" )
@@ -269,24 +279,42 @@ class GRUCPUKernel : public framework::OpKernel<T> {
269279 public:
270280 void BatchCompute (const framework::ExecutionContext& context) const {
271281 using DeviceContext = paddle::platform::CPUDeviceContext;
282+ using LodTensorPtr = LoDTensor*;
283+ bool is_test = context.Attr <bool >(" is_test" );
284+
272285 bool origin_mode = context.Attr <bool >(" origin_mode" );
273286 auto * input = context.Input <LoDTensor>(" Input" );
274287 auto * h0 = context.Input <Tensor>(" H0" );
275288 auto * weight = context.Input <Tensor>(" Weight" );
276289 const T* weight_data = weight->data <T>();
277290 auto * bias = context.Input <Tensor>(" Bias" );
278- auto * batch_gate = context.Output <LoDTensor>(" BatchGate" );
279- batch_gate->mutable_data <T>(context.GetPlace ());
280- auto * batch_reset_hidden_prev =
281- context.Output <LoDTensor>(" BatchResetHiddenPrev" );
282- batch_reset_hidden_prev->mutable_data <T>(context.GetPlace ());
283- auto * batch_hidden = context.Output <LoDTensor>(" BatchHidden" );
284- batch_hidden->mutable_data <T>(context.GetPlace ());
285291 auto * hidden = context.Output <LoDTensor>(" Hidden" );
286292 hidden->mutable_data <T>(context.GetPlace ());
287293
294+ auto input_dims = input->dims ();
288295 auto hidden_dims = hidden->dims ();
289296
297+ LodTensorPtr batch_gate, batch_reset_hidden_prev, batch_hidden;
298+ LoDTensor batch_gate_tmp, batch_reset_hidden_prev_tmp, batch_hidden_tmp;
299+ if (is_test) {
300+ batch_gate = &batch_gate_tmp;
301+ batch_gate->Resize (input_dims);
302+
303+ batch_reset_hidden_prev = &batch_reset_hidden_prev_tmp;
304+ batch_reset_hidden_prev->Resize (hidden_dims);
305+
306+ batch_hidden = &batch_hidden_tmp;
307+ batch_hidden->Resize (hidden_dims);
308+ } else {
309+ batch_gate = context.Output <LoDTensor>(" BatchGate" );
310+ batch_hidden = context.Output <LoDTensor>(" BatchHidden" );
311+ batch_reset_hidden_prev =
312+ context.Output <LoDTensor>(" BatchResetHiddenPrev" );
313+ }
314+ batch_gate->mutable_data <T>(context.GetPlace ());
315+ batch_reset_hidden_prev->mutable_data <T>(context.GetPlace ());
316+ batch_hidden->mutable_data <T>(context.GetPlace ());
317+
290318 bool is_reverse = context.Attr <bool >(" is_reverse" );
291319 math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
292320 auto & dev_ctx = context.template device_context <DeviceContext>();
0 commit comments