1313# See the License for the specific language governing permissions and
1414# limitations under the License.
1515
16- import os
17- import tempfile
1816import unittest
1917
2018import numpy as np
2119
2220import paddle
23- from paddle import base , nn
21+ from paddle import nn
2422
2523
2624class SimpleFCLayer (nn .Layer ):
@@ -47,226 +45,5 @@ def forward(self, x):
4745 return [fc , [None , 2 ]]
4846
4947
50- class TestTracedLayerErrMsg (unittest .TestCase ):
51- def setUp (self ):
52- self .batch_size = 4
53- self .feature_size = 3
54- self .fc_size = 2
55- self .layer = self ._train_simple_net ()
56- self .type_str = 'class'
57- self .temp_dir = tempfile .TemporaryDirectory ()
58-
59- def tearDown (self ):
60- self .temp_dir .cleanup ()
61-
62- def test_trace_err (self ):
63- if base .framework .in_dygraph_mode ():
64- return
65- with base .dygraph .guard ():
66- in_x = paddle .to_tensor (
67- np .random .random ((self .batch_size , self .feature_size )).astype (
68- 'float32'
69- )
70- )
71-
72- with self .assertRaises (AssertionError ) as e :
73- dygraph_out , traced_layer = base .dygraph .TracedLayer .trace (
74- None , [in_x ]
75- )
76- self .assertEqual (
77- "The type of 'layer' in paddle.jit.TracedLayer.trace must be paddle.nn.Layer, but received <{} 'NoneType'>." .format (
78- self .type_str
79- ),
80- str (e .exception ),
81- )
82- with self .assertRaises (TypeError ) as e :
83- dygraph_out , traced_layer = base .dygraph .TracedLayer .trace (
84- self .layer , 3
85- )
86- self .assertEqual (
87- "The type of 'each element of inputs' in paddle.jit.TracedLayer.trace must be base.Variable, but received <{} 'int'>." .format (
88- self .type_str
89- ),
90- str (e .exception ),
91- )
92- with self .assertRaises (TypeError ) as e :
93- dygraph_out , traced_layer = base .dygraph .TracedLayer .trace (
94- self .layer , [True , 1 ]
95- )
96- self .assertEqual (
97- "The type of 'each element of inputs' in paddle.jit.TracedLayer.trace must be base.Variable, but received <{} 'bool'>." .format (
98- self .type_str
99- ),
100- str (e .exception ),
101- )
102-
103- dygraph_out , traced_layer = base .dygraph .TracedLayer .trace (
104- self .layer , [in_x ]
105- )
106-
107- def test_set_strategy_err (self ):
108- if base .framework .in_dygraph_mode ():
109- return
110- with base .dygraph .guard ():
111- in_x = paddle .to_tensor (
112- np .random .random ((self .batch_size , self .feature_size )).astype (
113- 'float32'
114- )
115- )
116- dygraph_out , traced_layer = base .dygraph .TracedLayer .trace (
117- self .layer , [in_x ]
118- )
119-
120- with self .assertRaises (AssertionError ) as e :
121- traced_layer .set_strategy (1 , base .ExecutionStrategy ())
122- self .assertEqual (
123- "The type of 'build_strategy' in paddle.jit.TracedLayer.set_strategy must be base.BuildStrategy, but received <{} 'int'>." .format (
124- self .type_str
125- ),
126- str (e .exception ),
127- )
128-
129- with self .assertRaises (AssertionError ) as e :
130- traced_layer .set_strategy (base .BuildStrategy (), False )
131- self .assertEqual (
132- "The type of 'exec_strategy' in paddle.jit.TracedLayer.set_strategy must be base.ExecutionStrategy, but received <{} 'bool'>." .format (
133- self .type_str
134- ),
135- str (e .exception ),
136- )
137-
138- traced_layer .set_strategy (build_strategy = base .BuildStrategy ())
139- traced_layer .set_strategy (exec_strategy = base .ExecutionStrategy ())
140- traced_layer .set_strategy (
141- base .BuildStrategy (), base .ExecutionStrategy ()
142- )
143-
144- def test_save_inference_model_err (self ):
145- if base .framework .in_dygraph_mode ():
146- return
147- with base .dygraph .guard ():
148- in_x = paddle .to_tensor (
149- np .random .random ((self .batch_size , self .feature_size )).astype (
150- 'float32'
151- )
152- )
153- dygraph_out , traced_layer = base .dygraph .TracedLayer .trace (
154- self .layer , [in_x ]
155- )
156-
157- path = os .path .join (self .temp_dir .name , './traced_layer_err_msg' )
158- with self .assertRaises (TypeError ) as e :
159- traced_layer .save_inference_model ([0 ])
160- self .assertEqual (
161- "The type of 'path' in paddle.jit.TracedLayer.save_inference_model must be <{} 'str'>, but received <{} 'list'>. " .format (
162- self .type_str , self .type_str
163- ),
164- str (e .exception ),
165- )
166- with self .assertRaises (TypeError ) as e :
167- traced_layer .save_inference_model (path , [0 ], [None ])
168- self .assertEqual (
169- "The type of 'each element of fetch' in paddle.jit.TracedLayer.save_inference_model must be <{} 'int'>, but received <{} 'NoneType'>. " .format (
170- self .type_str , self .type_str
171- ),
172- str (e .exception ),
173- )
174- with self .assertRaises (TypeError ) as e :
175- traced_layer .save_inference_model (path , [0 ], False )
176- self .assertEqual (
177- "The type of 'fetch' in paddle.jit.TracedLayer.save_inference_model must be (<{} 'NoneType'>, <{} 'list'>), but received <{} 'bool'>. " .format (
178- self .type_str , self .type_str , self .type_str
179- ),
180- str (e .exception ),
181- )
182- with self .assertRaises (TypeError ) as e :
183- traced_layer .save_inference_model (path , [None ], [0 ])
184- self .assertEqual (
185- "The type of 'each element of feed' in paddle.jit.TracedLayer.save_inference_model must be <{} 'int'>, but received <{} 'NoneType'>. " .format (
186- self .type_str , self .type_str
187- ),
188- str (e .exception ),
189- )
190- with self .assertRaises (TypeError ) as e :
191- traced_layer .save_inference_model (path , True , [0 ])
192- self .assertEqual (
193- "The type of 'feed' in paddle.jit.TracedLayer.save_inference_model must be (<{} 'NoneType'>, <{} 'list'>), but received <{} 'bool'>. " .format (
194- self .type_str , self .type_str , self .type_str
195- ),
196- str (e .exception ),
197- )
198- with self .assertRaises (ValueError ) as e :
199- traced_layer .save_inference_model ("" )
200- self .assertEqual (
201- "The input path MUST be format of dirname/file_prefix [dirname\\ file_prefix in Windows system], "
202- "but received file_prefix is empty string." ,
203- str (e .exception ),
204- )
205-
206- traced_layer .save_inference_model (path )
207-
208- def _train_simple_net (self ):
209- layer = None
210- with base .dygraph .guard ():
211- layer = SimpleFCLayer (
212- self .feature_size , self .batch_size , self .fc_size
213- )
214- optimizer = paddle .optimizer .SGD (
215- learning_rate = 1e-3 , parameters = layer .parameters ()
216- )
217-
218- for i in range (5 ):
219- in_x = paddle .to_tensor (
220- np .random .random (
221- (self .batch_size , self .feature_size )
222- ).astype ('float32' )
223- )
224- dygraph_out = layer (in_x )
225- loss = paddle .mean (dygraph_out )
226- loss .backward ()
227- optimizer .minimize (loss )
228- return layer
229-
230-
231- class TestOutVarWithNoneErrMsg (unittest .TestCase ):
232- def test_linear_net_with_none (self ):
233- if base .framework .in_dygraph_mode ():
234- return
235- model = LinearNetWithNone (100 , 16 )
236- in_x = paddle .to_tensor (np .random .random ((4 , 100 )).astype ('float32' ))
237- with self .assertRaises (TypeError ):
238- dygraph_out , traced_layer = base .dygraph .TracedLayer .trace (
239- model , [in_x ]
240- )
241-
242-
243- class TestTracedLayerSaveInferenceModel (unittest .TestCase ):
244- """test save_inference_model will automatically create non-exist dir"""
245-
246- def setUp (self ):
247- self .temp_dir = tempfile .TemporaryDirectory ()
248- self .save_path = os .path .join (self .temp_dir .name , "./nonexist_dir/fc" )
249- import shutil
250-
251- if os .path .exists (os .path .dirname (self .save_path )):
252- shutil .rmtree (os .path .dirname (self .save_path ))
253-
254- def tearDown (self ):
255- self .temp_dir .cleanup ()
256-
257- def test_mkdir_when_input_path_non_exist (self ):
258- if base .framework .in_dygraph_mode ():
259- return
260- fc_layer = SimpleFCLayer (3 , 4 , 2 )
261- input_var = paddle .to_tensor (np .random .random ([4 , 3 ]).astype ('float32' ))
262- with base .dygraph .guard ():
263- dygraph_out , traced_layer = base .dygraph .TracedLayer .trace (
264- fc_layer , inputs = [input_var ]
265- )
266- self .assertFalse (os .path .exists (os .path .dirname (self .save_path )))
267- traced_layer .save_inference_model (self .save_path )
268- self .assertTrue (os .path .exists (os .path .dirname (self .save_path )))
269-
270-
27148if __name__ == '__main__' :
27249 unittest .main ()
0 commit comments