@@ -248,5 +248,49 @@ def test_with_error(self):
248248 del os .environ ['FLAGS_USE_STANDALONE_EXECUTOR' ]
249249
250250
251+ class TestException (unittest .TestCase ):
252+ def setUp (self ):
253+ self .place = paddle .CPUPlace ()
254+
255+ def build_program (self ):
256+ main_program = paddle .static .Program ()
257+ startup_program = paddle .static .Program ()
258+ with paddle .static .program_guard (main_program , startup_program ):
259+ w = paddle .rand ([10 , 20 ])
260+ ids = paddle .static .data (name = "id" , shape = [5 ], dtype = 'int64' )
261+ emb = paddle .nn .functional .embedding (
262+ x = ids , weight = w , sparse = False , name = "embedding" )
263+
264+ return main_program , startup_program , emb
265+
266+ def _run (self , feeds ):
267+ paddle .seed (2020 )
268+
269+ main_program , startup_program , fetch_vars = self .build_program ()
270+
271+ exe = paddle .static .Executor (self .place )
272+ exe .run (startup_program )
273+
274+ for feed in feeds :
275+ out = exe .run (main_program , feed = feed , fetch_list = fetch_vars )
276+
277+ return out
278+
279+ def run_new_executor (self , feed ):
280+ os .environ ['FLAGS_USE_STANDALONE_EXECUTOR' ] = '1'
281+ out = self ._run (feed )
282+ del os .environ ['FLAGS_USE_STANDALONE_EXECUTOR' ]
283+ return out
284+
285+ def test_exception (self ):
286+ feed = [{
287+ 'id' : np .array ([1 , 2 , 3 , 4 , 5 ]).astype (np .int64 )
288+ }, {
289+ 'id' : np .array ([1 , 2 , 3 , 4 , 11 ]).astype (np .int64 )
290+ }]
291+ out = self .run_new_executor (feed )
292+ return out
293+
294+
251295if __name__ == "__main__" :
252296 unittest .main ()
0 commit comments