
Commit f13dcfb

Add AsExtra for transpose, lstm, gru (#35317)
* Add AsExtra for transpose
* Add AsExtra for lstm op
* Add AsExtra for gru
1 parent b333dac commit f13dcfb
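Note for readers outside the Paddle codebase: as used in this series of changes, `AsExtra()` tags an attribute or output in an operator's proto definition as an "extra" member, i.e. one that is not part of the op's essential signature (typically training-only intermediates or backend-specific switches). A minimal sketch of the pattern this commit applies, built around a hypothetical FooOpMaker rather than any operator touched here:

// Hypothetical OpMaker illustrating the pattern used throughout this commit:
// training-only intermediates are tagged .AsIntermediate().AsExtra(), and the
// new is_test attribute is registered as an extra attribute defaulting to false.
class FooOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("Input", "(LoDTensor) the input sequence.");
    AddOutput("Out", "(LoDTensor) the primary output.");
    AddOutput("BatchGate",
              "(LoDTensor) intermediate kept only for the backward pass.")
        .AsIntermediate()
        .AsExtra();
    AddAttr<bool>("is_test", "True if in test phase.")
        .SetDefault(false)
        .AsExtra();
  }
};

The same `.AsIntermediate().AsExtra()` chain recurs in every OpMaker in the diffs below.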

File tree

7 files changed, +153 -43 lines changed

paddle/fluid/operators/gru_op.cc

Lines changed: 47 additions & 19 deletions

@@ -33,13 +33,15 @@ class GRUOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override {
     OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "GRU");
     OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "GRU");
-    OP_INOUT_CHECK(ctx->HasOutput("BatchGate"), "Output", "BatchGate", "GRU");
-    OP_INOUT_CHECK(ctx->HasOutput("BatchResetHiddenPrev"), "Output",
-                   "BatchResetHiddenPrev", "GRU");
-    OP_INOUT_CHECK(ctx->HasOutput("BatchHidden"), "Output", "BatchHidden",
-                   "GRU");
     OP_INOUT_CHECK(ctx->HasOutput("Hidden"), "Output", "Hidden", "GRU");
-
+    bool is_test = ctx->Attrs().Get<bool>("is_test");
+    if (!is_test) {
+      OP_INOUT_CHECK(ctx->HasOutput("BatchGate"), "Output", "BatchGate", "GRU");
+      OP_INOUT_CHECK(ctx->HasOutput("BatchResetHiddenPrev"), "Output",
+                     "BatchResetHiddenPrev", "GRU");
+      OP_INOUT_CHECK(ctx->HasOutput("BatchHidden"), "Output", "BatchHidden",
+                     "GRU");
+    }
     auto input_dims = ctx->GetInputDim("Input");
     auto weight_dims = ctx->GetInputDim("Weight");
     int input_size = input_dims[1];

@@ -84,9 +86,11 @@ class GRUOp : public framework::OperatorWithKernel {
                           "[%d, %d] (Bias) vs [1, %d] (frame_size * 3).",
                           bias_height, bias_width, frame_size * 3));
     }
-    ctx->SetOutputDim("BatchGate", input_dims);
-    ctx->SetOutputDim("BatchResetHiddenPrev", {input_dims[0], frame_size});
-    ctx->SetOutputDim("BatchHidden", {input_dims[0], frame_size});
+    if (!is_test) {
+      ctx->SetOutputDim("BatchGate", input_dims);
+      ctx->SetOutputDim("BatchResetHiddenPrev", {input_dims[0], frame_size});
+      ctx->SetOutputDim("BatchHidden", {input_dims[0], frame_size});
+    }
     ctx->SetOutputDim("Hidden", {input_dims[0], frame_size});
     ctx->ShareLoD("Input", "Hidden");
   }

@@ -124,19 +128,22 @@ class GRUOpMaker : public framework::OpProtoAndCheckerMaker {
               "organized in batches. The LoD size is 2. The first LoD contains "
               "the batch offsets and the second LoD contains the indexes in "
               "the raw sequence data.")
-        .AsIntermediate();
+        .AsIntermediate()
+        .AsExtra();
     AddOutput(
         "BatchResetHiddenPrev",
         "(LoDTensor) The reset hidden state LoDTensor organized in batches. "
         "This LoDTensor is a matrix with shape (T X D) and has the same LoD "
         "with `BatchGate`.")
-        .AsIntermediate();
+        .AsIntermediate()
+        .AsExtra();
     AddOutput(
         "BatchHidden",
         "(LoDTensor) The hidden state LoDTensor organized in batches. "
         "This LoDTensor is a matrix with shape (T X D) and has the same LoD "
         "with `BatchGate`.")
-        .AsIntermediate();
+        .AsIntermediate()
+        .AsExtra();
     AddOutput(
         "Hidden",
         "(LoDTensor) the hidden state LoDTensor organized in sequences. "

@@ -155,6 +162,9 @@ class GRUOpMaker : public framework::OpProtoAndCheckerMaker {
                   "(bool, default: False) "
                   "whether to compute reversed GRU.")
         .SetDefault(false);
+    AddAttr<bool>("is_test", "True if in test phase.")
+        .SetDefault(false)
+        .AsExtra();
     AddAttr<bool>("origin_mode",
                   "bool"
                   "use origin mode in article https://arxiv.org/abs/1412.3555")

@@ -269,24 +279,42 @@ class GRUCPUKernel : public framework::OpKernel<T> {
  public:
   void BatchCompute(const framework::ExecutionContext& context) const {
     using DeviceContext = paddle::platform::CPUDeviceContext;
+    using LodTensorPtr = LoDTensor*;
+    bool is_test = context.Attr<bool>("is_test");
+
     bool origin_mode = context.Attr<bool>("origin_mode");
     auto* input = context.Input<LoDTensor>("Input");
     auto* h0 = context.Input<Tensor>("H0");
     auto* weight = context.Input<Tensor>("Weight");
     const T* weight_data = weight->data<T>();
     auto* bias = context.Input<Tensor>("Bias");
-    auto* batch_gate = context.Output<LoDTensor>("BatchGate");
-    batch_gate->mutable_data<T>(context.GetPlace());
-    auto* batch_reset_hidden_prev =
-        context.Output<LoDTensor>("BatchResetHiddenPrev");
-    batch_reset_hidden_prev->mutable_data<T>(context.GetPlace());
-    auto* batch_hidden = context.Output<LoDTensor>("BatchHidden");
-    batch_hidden->mutable_data<T>(context.GetPlace());
     auto* hidden = context.Output<LoDTensor>("Hidden");
     hidden->mutable_data<T>(context.GetPlace());

+    auto input_dims = input->dims();
     auto hidden_dims = hidden->dims();

+    LodTensorPtr batch_gate, batch_reset_hidden_prev, batch_hidden;
+    LoDTensor batch_gate_tmp, batch_reset_hidden_prev_tmp, batch_hidden_tmp;
+    if (is_test) {
+      batch_gate = &batch_gate_tmp;
+      batch_gate->Resize(input_dims);
+
+      batch_reset_hidden_prev = &batch_reset_hidden_prev_tmp;
+      batch_reset_hidden_prev->Resize(hidden_dims);
+
+      batch_hidden = &batch_hidden_tmp;
+      batch_hidden->Resize(hidden_dims);
+    } else {
+      batch_gate = context.Output<LoDTensor>("BatchGate");
+      batch_hidden = context.Output<LoDTensor>("BatchHidden");
+      batch_reset_hidden_prev =
+          context.Output<LoDTensor>("BatchResetHiddenPrev");
+    }
+    batch_gate->mutable_data<T>(context.GetPlace());
+    batch_reset_hidden_prev->mutable_data<T>(context.GetPlace());
+    batch_hidden->mutable_data<T>(context.GetPlace());
+
     bool is_reverse = context.Attr<bool>("is_reverse");
     math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
     auto& dev_ctx = context.template device_context<DeviceContext>();
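A note on the kernel change above: when `is_test` is true, the three batch intermediates are no longer declared as graph outputs, so the kernel redirects its working pointers to stack-local LoDTensor temporaries that live only for the duration of the call. A condensed sketch of that dispatch for a single tensor (hypothetical names; it assumes only the ExecutionContext/LoDTensor calls already visible in this diff):

// Sketch, not part of the commit: route writes either to a local scratch
// tensor (test mode, where no "BatchGate" variable exists) or to the real
// graph output (training mode, where the backward pass reads it).
LoDTensor scratch;
LoDTensor* batch_gate = nullptr;
if (is_test) {
  batch_gate = &scratch;               // freed when the kernel returns
  batch_gate->Resize(input->dims());   // same shape the output would have
} else {
  batch_gate = context.Output<LoDTensor>("BatchGate");
}
batch_gate->mutable_data<T>(context.GetPlace());  // allocate in either case

The same pattern is repeated for BatchResetHiddenPrev and BatchHidden here, and again in the CUDA and LSTM kernels below.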

paddle/fluid/operators/gru_op.cu.cc

Lines changed: 25 additions & 7 deletions

@@ -28,24 +28,42 @@ template <typename DeviceContext, typename T>
 class GRUKernel : public framework::OpKernel<T> {
  public:
   void BatchCompute(const framework::ExecutionContext& context) const {
+    using LodTensorPtr = LoDTensor*;
+
+    bool is_test = context.Attr<bool>("is_test");
     bool origin_mode = context.Attr<bool>("origin_mode");
     auto* input = context.Input<LoDTensor>("Input");
     auto* h0 = context.Input<Tensor>("H0");
     auto* weight = context.Input<Tensor>("Weight");
     const T* weight_data = weight->data<T>();
     auto* bias = context.Input<Tensor>("Bias");
-    auto* batch_gate = context.Output<LoDTensor>("BatchGate");
-    batch_gate->mutable_data<T>(context.GetPlace());
-    auto* batch_reset_hidden_prev =
-        context.Output<LoDTensor>("BatchResetHiddenPrev");
-    batch_reset_hidden_prev->mutable_data<T>(context.GetPlace());
-    auto* batch_hidden = context.Output<LoDTensor>("BatchHidden");
-    batch_hidden->mutable_data<T>(context.GetPlace());
     auto* hidden = context.Output<LoDTensor>("Hidden");
     hidden->mutable_data<T>(context.GetPlace());

+    auto input_dims = input->dims();
     auto hidden_dims = hidden->dims();

+    LodTensorPtr batch_gate, batch_reset_hidden_prev, batch_hidden;
+    LoDTensor batch_gate_tmp, batch_reset_hidden_prev_tmp, batch_hidden_tmp;
+    if (is_test) {
+      batch_gate = &batch_gate_tmp;
+      batch_gate->Resize(input_dims);
+
+      batch_reset_hidden_prev = &batch_reset_hidden_prev_tmp;
+      batch_reset_hidden_prev->Resize(hidden_dims);
+
+      batch_hidden = &batch_hidden_tmp;
+      batch_hidden->Resize(hidden_dims);
+    } else {
+      batch_gate = context.Output<LoDTensor>("BatchGate");
+      batch_hidden = context.Output<LoDTensor>("BatchHidden");
+      batch_reset_hidden_prev =
+          context.Output<LoDTensor>("BatchResetHiddenPrev");
+    }
+    batch_gate->mutable_data<T>(context.GetPlace());
+    batch_reset_hidden_prev->mutable_data<T>(context.GetPlace());
+    batch_hidden->mutable_data<T>(context.GetPlace());
+
     bool is_reverse = context.Attr<bool>("is_reverse");
     math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
     auto& dev_ctx = context.template device_context<DeviceContext>();

paddle/fluid/operators/lstm_op.cc

Lines changed: 19 additions & 7 deletions

@@ -30,10 +30,15 @@ class LSTMOp : public framework::OperatorWithKernel {

     OP_INOUT_CHECK(ctx->HasOutput("Hidden"), "Output", "Hidden", "LSTM");
     OP_INOUT_CHECK(ctx->HasOutput("Cell"), "Output", "Cell", "LSTM");
-    OP_INOUT_CHECK(ctx->HasOutput("BatchGate"), "Output", "BatchGate", "LSTM");
-    OP_INOUT_CHECK(ctx->HasOutput("BatchCellPreAct"), "Output",
-                   "BatchCellPreAct", "LSTM");

+    bool is_test = ctx->Attrs().Get<bool>("is_test");
+
+    if (!is_test) {
+      OP_INOUT_CHECK(ctx->HasOutput("BatchGate"), "Output", "BatchGate",
+                     "LSTM");
+      OP_INOUT_CHECK(ctx->HasOutput("BatchCellPreAct"), "Output",
+                     "BatchCellPreAct", "LSTM");
+    }
     auto in_dims = ctx->GetInputDim("Input");
     PADDLE_ENFORCE_EQ(
         in_dims.size(), 2,

@@ -103,8 +108,10 @@ class LSTMOp : public framework::OperatorWithKernel {
     framework::DDim out_dims({in_dims[0], frame_size});
     ctx->SetOutputDim("Hidden", out_dims);
     ctx->SetOutputDim("Cell", out_dims);
-    ctx->SetOutputDim("BatchGate", in_dims);
-    ctx->SetOutputDim("BatchCellPreAct", out_dims);
+    if (!is_test) {
+      ctx->SetOutputDim("BatchGate", in_dims);
+      ctx->SetOutputDim("BatchCellPreAct", out_dims);
+    }
     ctx->ShareLoD("Input", "Hidden");
     ctx->ShareLoD("Input", "Cell");
   }

@@ -164,11 +171,13 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
               "LoD is the batch offsets and the second LoD contains the "
               "indexes, which denote the position of reorganized sequence "
               "in the raw input.")
-        .AsIntermediate();
+        .AsIntermediate()
+        .AsExtra();
     AddOutput("BatchCellPreAct",
               "(LoDTensor) This LoDTensor is obtained in the forward and used "
               "in the backward.")
-        .AsIntermediate();
+        .AsIntermediate()
+        .AsExtra();
     AddAttr<bool>("use_peepholes",
                   "(bool, default: True) "
                   "whether to enable diagonal/peephole connections.")

@@ -177,6 +186,9 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
                   "(bool, default: False) "
                   "whether to compute reversed LSTM.")
         .SetDefault(false);
+    AddAttr<bool>("is_test", "True if in test phase.")
+        .SetDefault(false)
+        .AsExtra();
     AddAttr<std::string>(
         "gate_activation",
         "(string, default: sigmoid)"

paddle/fluid/operators/lstm_op.h

Lines changed: 17 additions & 3 deletions

@@ -40,14 +40,23 @@ template <typename DeviceContext, typename T>
 class LSTMKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    bool is_test = ctx.Attr<bool>("is_test");
+
     auto* input = ctx.Input<LoDTensor>("Input");
     auto* weight = ctx.Input<Tensor>("Weight");
     auto* bias = ctx.Input<Tensor>("Bias");

     auto* hidden_t0 = ctx.Input<Tensor>("H0");
     auto* cell_t0 = ctx.Input<Tensor>("C0");

-    auto* batch_gate = ctx.Output<LoDTensor>("BatchGate");
+    LoDTensor* batch_gate = nullptr;
+    LoDTensor batch_gate_temp;
+    if (is_test) {
+      batch_gate = &batch_gate_temp;
+      batch_gate->Resize(input->dims());
+    } else {
+      batch_gate = ctx.Output<LoDTensor>("BatchGate");
+    }
     batch_gate->mutable_data<T>(ctx.GetPlace());
     auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
     hidden_out->mutable_data<T>(ctx.GetPlace());

@@ -99,8 +108,13 @@ class LSTMKernel : public framework::OpKernel<T> {
     }

     // Use the local variable as here.
-    LoDTensor batch_hidden, batch_cell;
-    auto* batch_cell_pre_act = ctx.Output<LoDTensor>("BatchCellPreAct");
+    LoDTensor batch_hidden, batch_cell, batch_cell_pre_act_temp;
+    LoDTensor* batch_cell_pre_act;
+    if (is_test) {
+      batch_cell_pre_act = &batch_cell_pre_act_temp;
+    } else {
+      batch_cell_pre_act = ctx.Output<LoDTensor>("BatchCellPreAct");
+    }
     batch_hidden.mutable_data<T>(dims, ctx.GetPlace());
     batch_cell.mutable_data<T>(dims, ctx.GetPlace());
     batch_cell_pre_act->mutable_data<T>(dims, ctx.GetPlace());

paddle/fluid/operators/transpose_op.cc

Lines changed: 7 additions & 3 deletions

@@ -119,14 +119,16 @@ class TransposeOpMaker : public framework::OpProtoAndCheckerMaker {
                   "tensor's axes according to the values given.");
     AddAttr<bool>("use_mkldnn",
                   "(bool, default false) Only used in mkldnn kernel")
-        .SetDefault(false);
+        .SetDefault(false)
+        .AsExtra();
     AddAttr<std::string>(
         "data_format",
         "(string, default NCHW) Only used in "
         "An optional string from: \"NHWC\", \"NCHW\". "
         "Defaults to \"NHWC\". Specify the data format of the output data, "
         "the input will be transformed automatically. ")
-        .SetDefault("AnyLayout");
+        .SetDefault("AnyLayout")
+        .AsExtra();
     AddAttr<bool>(
         "use_quantizer",
         "(bool, default false) "

@@ -262,7 +264,9 @@ class Transpose2OpMaker : public TransposeOpMaker {
  public:
   void Make() override {
     TransposeOpMaker::Make();
-    AddOutput("XShape", "(Tensor)The output tensor.").AsIntermediate();
+    AddOutput("XShape", "(Tensor)The output tensor.")
+        .AsIntermediate()
+        .AsExtra();
   }
 };

python/paddle/fluid/tests/unittests/test_gru_op.py

Lines changed: 22 additions & 2 deletions

@@ -18,7 +18,7 @@
 import numpy as np
 import math
 import functools
-from op_test import OpTest
+from op_test import OpTest, skip_check_grad_ci
 from paddle.fluid.tests.unittests.test_lstm_op import ACTIVATION
 from paddle import fluid
 from paddle.fluid import Program, program_guard

@@ -106,6 +106,9 @@ class TestGRUOp(OpTest):
     def set_confs(self):
         pass

+    def set_is_test(self):
+        self.is_test = False
+
     def setUp(self):
         self.op_type = "gru"
         self.lod = [[2, 4, 3]]

@@ -118,6 +121,7 @@ def setUp(self):
         self.dtype = 'float64'
         self.origin_mode = False
         self.set_confs()
+        self.set_is_test()

         T = sum(self.lod[0])
         N = len(self.lod[0])

@@ -153,7 +157,8 @@ def setUp(self):
             'activation': self.act_state,
             'gate_activation': self.act_gate,
             'is_reverse': self.is_reverse,
-            'origin_mode': self.origin_mode
+            'origin_mode': self.origin_mode,
+            'is_test': self.is_test
         }

     def test_check_output(self):

@@ -229,6 +234,21 @@ def set_confs(self):
         self.origin_mode = True


+class TestGRUOpInference(TestGRUOp):
+    def set_is_test(self):
+        self.is_test = True
+
+    def test_check_output(self):
+        new_outputs = {}
+        new_outputs['Hidden'] = self.outputs['Hidden']
+        self.outputs = new_outputs
+        super(TestGRUOpInference, self).test_check_output()
+
+    # avoid checking gradient
+    def test_check_grad(self):
+        pass
+
+
 class TestGruOpError(unittest.TestCase):
     def test_errors(self):
         with program_guard(Program(), Program()):
