@@ -24,7 +24,6 @@ def __init__(self, in_channels, hidden_size):
2424 self .lstm = nn .LSTM (
2525 in_channels , hidden_size , direction = 'bidirectional' , num_layers = 2 )
2626
27- @paddle .jit .to_static
2827 def forward (self , x ):
2928 x , _ = self .lstm (x )
3029 return x
@@ -39,6 +38,7 @@ def run_lstm(self, to_static):
3938 paddle .static .default_startup_program ().random_seed = 1001
4039
4140 net = Net (12 , 2 )
41+ net = paddle .jit .to_static (net )
4242 x = paddle .zeros ((2 , 10 , 12 ))
4343 y = net (paddle .to_tensor (x ))
4444 return y .numpy ()
@@ -54,16 +54,17 @@ def test_lstm_to_static(self):
5454 def test_save_in_eval (self ):
5555 paddle .jit .ProgramTranslator ().enable (True )
5656 net = Net (12 , 2 )
57+ x = paddle .randn ((2 , 10 , 12 ))
58+ dygraph_out = net (x )
5759 # switch eval mode firstly
5860 net .eval ()
61+
5962 net = paddle .jit .to_static (
6063 net , input_spec = [paddle .static .InputSpec (shape = [- 1 , 10 , 12 ])])
6164 paddle .jit .save (net , 'simple_lstm' )
6265 # load saved model
6366 load_net = paddle .jit .load ('simple_lstm' )
6467
65- x = paddle .randn ((2 , 10 , 12 ))
66- dygraph_out = net (x )
6768 static_out = load_net (x )
6869 self .assertTrue (
6970 np .allclose (dygraph_out .numpy (), static_out .numpy ()),
@@ -78,5 +79,41 @@ def test_save_in_eval(self):
7879 train_out ))
7980
8081
82+ class LinearNet (nn .Layer ):
83+ def __init__ (self ):
84+ super (LinearNet , self ).__init__ ()
85+ self .fc = nn .Linear (10 , 12 )
86+ self .dropout = nn .Dropout (0.5 )
87+
88+ @paddle .jit .to_static
89+ def forward (self , x ):
90+ y = self .fc (x )
91+ y = self .dropout (y )
92+ return y
93+
94+
95+ class TestSaveInEvalMode (unittest .TestCase ):
96+ def test_save_in_eval (self ):
97+ paddle .jit .ProgramTranslator ().enable (True )
98+ net = LinearNet ()
99+ # switch eval mode firstly
100+ net .eval ()
101+ # save directly
102+ net = paddle .jit .to_static (
103+ net , input_spec = [paddle .static .InputSpec (shape = [- 1 , 10 ])])
104+ paddle .jit .save (net , 'linear_net' )
105+ # load saved model
106+ load_net = paddle .jit .load ('linear_net' )
107+
108+ x = paddle .randn ((2 , 10 ))
109+ eval_out = net (x )
110+
111+ infer_out = load_net (x )
112+ self .assertTrue (
113+ np .allclose (eval_out .numpy (), infer_out .numpy ()),
114+ msg = 'eval_out is {}\n infer_out is \n {}' .format (eval_out ,
115+ infer_out ))
116+
117+
81118if __name__ == "__main__" :
82119 unittest .main ()
0 commit comments