Skip to content

Commit 598f684

Browse files
committed
add AsExtra for gru
1 parent d3266dc commit 598f684

File tree

4 files changed

+95
-32
lines changed

4 files changed

+95
-32
lines changed

paddle/fluid/operators/gru_op.cc

Lines changed: 47 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,15 @@ class GRUOp : public framework::OperatorWithKernel {
3333
void InferShape(framework::InferShapeContext* ctx) const override {
3434
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "GRU");
3535
OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "GRU");
36-
OP_INOUT_CHECK(ctx->HasOutput("BatchGate"), "Output", "BatchGate", "GRU");
37-
OP_INOUT_CHECK(ctx->HasOutput("BatchResetHiddenPrev"), "Output",
38-
"BatchResetHiddenPrev", "GRU");
39-
OP_INOUT_CHECK(ctx->HasOutput("BatchHidden"), "Output", "BatchHidden",
40-
"GRU");
4136
OP_INOUT_CHECK(ctx->HasOutput("Hidden"), "Output", "Hidden", "GRU");
42-
37+
bool is_test = ctx->Attrs().Get<bool>("is_test");
38+
if (!is_test) {
39+
OP_INOUT_CHECK(ctx->HasOutput("BatchGate"), "Output", "BatchGate", "GRU");
40+
OP_INOUT_CHECK(ctx->HasOutput("BatchResetHiddenPrev"), "Output",
41+
"BatchResetHiddenPrev", "GRU");
42+
OP_INOUT_CHECK(ctx->HasOutput("BatchHidden"), "Output", "BatchHidden",
43+
"GRU");
44+
}
4345
auto input_dims = ctx->GetInputDim("Input");
4446
auto weight_dims = ctx->GetInputDim("Weight");
4547
int input_size = input_dims[1];
@@ -84,9 +86,11 @@ class GRUOp : public framework::OperatorWithKernel {
8486
"[%d, %d] (Bias) vs [1, %d] (frame_size * 3).",
8587
bias_height, bias_width, frame_size * 3));
8688
}
87-
ctx->SetOutputDim("BatchGate", input_dims);
88-
ctx->SetOutputDim("BatchResetHiddenPrev", {input_dims[0], frame_size});
89-
ctx->SetOutputDim("BatchHidden", {input_dims[0], frame_size});
89+
if (!is_test) {
90+
ctx->SetOutputDim("BatchGate", input_dims);
91+
ctx->SetOutputDim("BatchResetHiddenPrev", {input_dims[0], frame_size});
92+
ctx->SetOutputDim("BatchHidden", {input_dims[0], frame_size});
93+
}
9094
ctx->SetOutputDim("Hidden", {input_dims[0], frame_size});
9195
ctx->ShareLoD("Input", "Hidden");
9296
}
@@ -124,19 +128,22 @@ class GRUOpMaker : public framework::OpProtoAndCheckerMaker {
124128
"organized in batches. The LoD size is 2. The first LoD contains "
125129
"the batch offsets and the second LoD contains the indexes in "
126130
"the raw sequence data.")
127-
.AsIntermediate();
131+
.AsIntermediate()
132+
.AsExtra();
128133
AddOutput(
129134
"BatchResetHiddenPrev",
130135
"(LoDTensor) The reset hidden state LoDTensor organized in batches. "
131136
"This LoDTensor is a matrix with shape (T X D) and has the same LoD "
132137
"with `BatchGate`.")
133-
.AsIntermediate();
138+
.AsIntermediate()
139+
.AsExtra();
134140
AddOutput(
135141
"BatchHidden",
136142
"(LoDTensor) The hidden state LoDTensor organized in batches. "
137143
"This LoDTensor is a matrix with shape (T X D) and has the same LoD "
138144
"with `BatchGate`.")
139-
.AsIntermediate();
145+
.AsIntermediate()
146+
.AsExtra();
140147
AddOutput(
141148
"Hidden",
142149
"(LoDTensor) the hidden state LoDTensor organized in sequences. "
@@ -155,6 +162,9 @@ class GRUOpMaker : public framework::OpProtoAndCheckerMaker {
155162
"(bool, default: False) "
156163
"whether to compute reversed GRU.")
157164
.SetDefault(false);
165+
AddAttr<bool>("is_test", "True if in test phase.")
166+
.SetDefault(false)
167+
.AsExtra();
158168
AddAttr<bool>("origin_mode",
159169
"bool"
160170
"use origin mode in article https://arxiv.org/abs/1412.3555")
@@ -269,24 +279,42 @@ class GRUCPUKernel : public framework::OpKernel<T> {
269279
public:
270280
void BatchCompute(const framework::ExecutionContext& context) const {
271281
using DeviceContext = paddle::platform::CPUDeviceContext;
282+
using LodTensorPtr = LoDTensor*;
283+
bool is_test = context.Attr<bool>("is_test");
284+
272285
bool origin_mode = context.Attr<bool>("origin_mode");
273286
auto* input = context.Input<LoDTensor>("Input");
274287
auto* h0 = context.Input<Tensor>("H0");
275288
auto* weight = context.Input<Tensor>("Weight");
276289
const T* weight_data = weight->data<T>();
277290
auto* bias = context.Input<Tensor>("Bias");
278-
auto* batch_gate = context.Output<LoDTensor>("BatchGate");
279-
batch_gate->mutable_data<T>(context.GetPlace());
280-
auto* batch_reset_hidden_prev =
281-
context.Output<LoDTensor>("BatchResetHiddenPrev");
282-
batch_reset_hidden_prev->mutable_data<T>(context.GetPlace());
283-
auto* batch_hidden = context.Output<LoDTensor>("BatchHidden");
284-
batch_hidden->mutable_data<T>(context.GetPlace());
285291
auto* hidden = context.Output<LoDTensor>("Hidden");
286292
hidden->mutable_data<T>(context.GetPlace());
287293

294+
auto input_dims = input->dims();
288295
auto hidden_dims = hidden->dims();
289296

297+
LodTensorPtr batch_gate, batch_reset_hidden_prev, batch_hidden;
298+
LoDTensor batch_gate_tmp, batch_reset_hidden_prev_tmp, batch_hidden_tmp;
299+
if (is_test) {
300+
batch_gate = &batch_gate_tmp;
301+
batch_gate->Resize(input_dims);
302+
303+
batch_reset_hidden_prev = &batch_reset_hidden_prev_tmp;
304+
batch_reset_hidden_prev->Resize(hidden_dims);
305+
306+
batch_hidden = &batch_hidden_tmp;
307+
batch_hidden->Resize(hidden_dims);
308+
} else {
309+
batch_gate = context.Output<LoDTensor>("BatchGate");
310+
batch_hidden = context.Output<LoDTensor>("BatchHidden");
311+
batch_reset_hidden_prev =
312+
context.Output<LoDTensor>("BatchResetHiddenPrev");
313+
}
314+
batch_gate->mutable_data<T>(context.GetPlace());
315+
batch_reset_hidden_prev->mutable_data<T>(context.GetPlace());
316+
batch_hidden->mutable_data<T>(context.GetPlace());
317+
290318
bool is_reverse = context.Attr<bool>("is_reverse");
291319
math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
292320
auto& dev_ctx = context.template device_context<DeviceContext>();

paddle/fluid/operators/gru_op.cu.cc

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,24 +28,42 @@ template <typename DeviceContext, typename T>
2828
class GRUKernel : public framework::OpKernel<T> {
2929
public:
3030
void BatchCompute(const framework::ExecutionContext& context) const {
31+
using LodTensorPtr = LoDTensor*;
32+
33+
bool is_test = context.Attr<bool>("is_test");
3134
bool origin_mode = context.Attr<bool>("origin_mode");
3235
auto* input = context.Input<LoDTensor>("Input");
3336
auto* h0 = context.Input<Tensor>("H0");
3437
auto* weight = context.Input<Tensor>("Weight");
3538
const T* weight_data = weight->data<T>();
3639
auto* bias = context.Input<Tensor>("Bias");
37-
auto* batch_gate = context.Output<LoDTensor>("BatchGate");
38-
batch_gate->mutable_data<T>(context.GetPlace());
39-
auto* batch_reset_hidden_prev =
40-
context.Output<LoDTensor>("BatchResetHiddenPrev");
41-
batch_reset_hidden_prev->mutable_data<T>(context.GetPlace());
42-
auto* batch_hidden = context.Output<LoDTensor>("BatchHidden");
43-
batch_hidden->mutable_data<T>(context.GetPlace());
4440
auto* hidden = context.Output<LoDTensor>("Hidden");
4541
hidden->mutable_data<T>(context.GetPlace());
4642

43+
auto input_dims = input->dims();
4744
auto hidden_dims = hidden->dims();
4845

46+
LodTensorPtr batch_gate, batch_reset_hidden_prev, batch_hidden;
47+
LoDTensor batch_gate_tmp, batch_reset_hidden_prev_tmp, batch_hidden_tmp;
48+
if (is_test) {
49+
batch_gate = &batch_gate_tmp;
50+
batch_gate->Resize(input_dims);
51+
52+
batch_reset_hidden_prev = &batch_reset_hidden_prev_tmp;
53+
batch_reset_hidden_prev->Resize(hidden_dims);
54+
55+
batch_hidden = &batch_hidden_tmp;
56+
batch_hidden->Resize(hidden_dims);
57+
} else {
58+
batch_gate = context.Output<LoDTensor>("BatchGate");
59+
batch_hidden = context.Output<LoDTensor>("BatchHidden");
60+
batch_reset_hidden_prev =
61+
context.Output<LoDTensor>("BatchResetHiddenPrev");
62+
}
63+
batch_gate->mutable_data<T>(context.GetPlace());
64+
batch_reset_hidden_prev->mutable_data<T>(context.GetPlace());
65+
batch_hidden->mutable_data<T>(context.GetPlace());
66+
4967
bool is_reverse = context.Attr<bool>("is_reverse");
5068
math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
5169
auto& dev_ctx = context.template device_context<DeviceContext>();

python/paddle/fluid/tests/unittests/test_gru_op.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import numpy as np
1919
import math
2020
import functools
21-
from op_test import OpTest
21+
from op_test import OpTest, skip_check_grad_ci
2222
from paddle.fluid.tests.unittests.test_lstm_op import ACTIVATION
2323
from paddle import fluid
2424
from paddle.fluid import Program, program_guard
@@ -106,6 +106,9 @@ class TestGRUOp(OpTest):
106106
def set_confs(self):
107107
pass
108108

109+
def set_is_test(self):
110+
self.is_test = False
111+
109112
def setUp(self):
110113
self.op_type = "gru"
111114
self.lod = [[2, 4, 3]]
@@ -118,6 +121,7 @@ def setUp(self):
118121
self.dtype = 'float64'
119122
self.origin_mode = False
120123
self.set_confs()
124+
self.set_is_test()
121125

122126
T = sum(self.lod[0])
123127
N = len(self.lod[0])
@@ -153,7 +157,8 @@ def setUp(self):
153157
'activation': self.act_state,
154158
'gate_activation': self.act_gate,
155159
'is_reverse': self.is_reverse,
156-
'origin_mode': self.origin_mode
160+
'origin_mode': self.origin_mode,
161+
'is_test': self.is_test
157162
}
158163

159164
def test_check_output(self):
@@ -229,6 +234,21 @@ def set_confs(self):
229234
self.origin_mode = True
230235

231236

237+
class TestGRUOpInference(TestGRUOp):
238+
def set_is_test(self):
239+
self.is_test = True
240+
241+
def test_check_output(self):
242+
new_outputs = {}
243+
new_outputs['Hidden'] = self.outputs['Hidden']
244+
self.outputs = new_outputs
245+
super(TestGRUOpInference, self).test_check_output()
246+
247+
# avoid checking gradient
248+
def test_check_grad(self):
249+
pass
250+
251+
232252
class TestGruOpError(unittest.TestCase):
233253
def test_errors(self):
234254
with program_guard(Program(), Program()):

python/paddle/fluid/tests/unittests/test_lstm_op.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -307,14 +307,11 @@ def set_lod(self):
307307
self.lod = [[2, 0, 4]]
308308

309309

310-
@skip_check_grad_ci(
311-
reason="This unittest is used to check whether the op run correctly "
312-
"in inference time, no need to calculate gradient.")
313310
class TestLstmOpInference(TestLstmOp):
314311
def set_is_test(self):
315312
self.is_test = True
316313

317-
# avoid checking gradient
314+
# avoid checking gradient
318315
def test_check_grad(self):
319316
pass
320317

0 commit comments

Comments
 (0)