Skip to content

Commit a9a731e

Browse files
Aurelius84guoshengCS
authored andcommitted
Fix test_lstm unittest failed and Add more unittest (#28029)
* fix test_lstm unittest failed * add more unittest * modify cmakelist * fix judgement
1 parent 2584ff7 commit a9a731e

2 files changed

Lines changed: 41 additions & 4 deletions

File tree

python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -627,7 +627,7 @@ def _extract_indeed_params_buffers(class_instance):
627627
"""
628628
params = list(get_parameters(class_instance).values())
629629
buffers = list(get_buffers(class_instance).values())
630-
buffers = [buffer for buffer in buffers if buffer.shape != []]
630+
buffers = [buffer for buffer in buffers if len(buffer.shape) != 0]
631631

632632
return params + buffers
633633

python/paddle/fluid/tests/unittests/dygraph_to_static/test_lstm.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ def __init__(self, in_channels, hidden_size):
2424
self.lstm = nn.LSTM(
2525
in_channels, hidden_size, direction='bidirectional', num_layers=2)
2626

27-
@paddle.jit.to_static
2827
def forward(self, x):
2928
x, _ = self.lstm(x)
3029
return x
@@ -39,6 +38,7 @@ def run_lstm(self, to_static):
3938
paddle.static.default_startup_program().random_seed = 1001
4039

4140
net = Net(12, 2)
41+
net = paddle.jit.to_static(net)
4242
x = paddle.zeros((2, 10, 12))
4343
y = net(paddle.to_tensor(x))
4444
return y.numpy()
@@ -54,16 +54,17 @@ def test_lstm_to_static(self):
5454
def test_save_in_eval(self):
5555
paddle.jit.ProgramTranslator().enable(True)
5656
net = Net(12, 2)
57+
x = paddle.randn((2, 10, 12))
58+
dygraph_out = net(x)
5759
# switch eval mode firstly
5860
net.eval()
61+
5962
net = paddle.jit.to_static(
6063
net, input_spec=[paddle.static.InputSpec(shape=[-1, 10, 12])])
6164
paddle.jit.save(net, 'simple_lstm')
6265
# load saved model
6366
load_net = paddle.jit.load('simple_lstm')
6467

65-
x = paddle.randn((2, 10, 12))
66-
dygraph_out = net(x)
6768
static_out = load_net(x)
6869
self.assertTrue(
6970
np.allclose(dygraph_out.numpy(), static_out.numpy()),
@@ -78,5 +79,41 @@ def test_save_in_eval(self):
7879
train_out))
7980

8081

82+
class LinearNet(nn.Layer):
83+
def __init__(self):
84+
super(LinearNet, self).__init__()
85+
self.fc = nn.Linear(10, 12)
86+
self.dropout = nn.Dropout(0.5)
87+
88+
@paddle.jit.to_static
89+
def forward(self, x):
90+
y = self.fc(x)
91+
y = self.dropout(y)
92+
return y
93+
94+
95+
class TestSaveInEvalMode(unittest.TestCase):
96+
def test_save_in_eval(self):
97+
paddle.jit.ProgramTranslator().enable(True)
98+
net = LinearNet()
99+
# switch eval mode firstly
100+
net.eval()
101+
# save directly
102+
net = paddle.jit.to_static(
103+
net, input_spec=[paddle.static.InputSpec(shape=[-1, 10])])
104+
paddle.jit.save(net, 'linear_net')
105+
# load saved model
106+
load_net = paddle.jit.load('linear_net')
107+
108+
x = paddle.randn((2, 10))
109+
eval_out = net(x)
110+
111+
infer_out = load_net(x)
112+
self.assertTrue(
113+
np.allclose(eval_out.numpy(), infer_out.numpy()),
114+
msg='eval_out is {}\n infer_out is \n{}'.format(eval_out,
115+
infer_out))
116+
117+
81118
if __name__ == "__main__":
82119
unittest.main()

0 commit comments

Comments
 (0)