@@ -12,6 +12,7 @@ def __init__(self, api_config, **kwargs):
1212 super ().__init__ (api_config )
1313 self .test_amp = kwargs .get ("test_amp" , False )
1414 self .custom_device_type = self ._get_first_custom_device_type ()
15+ self .generate_failed_tests = kwargs .get ("generate_failed_tests" , False )
1516 if self .check_custom_device_available ():
1617 self .custom_device_id = 0
1718 if self .check_xpu_available ():
@@ -260,6 +261,28 @@ def test(self):
260261 if cpu_output is None :
261262 print ("[cpu execution failed]" , self .api_config .config , flush = True )
262263 write_to_log ("paddle_error" , self .api_config .config )
264+ # CPU 前向/反向执行失败时,如果开启了生成失败用例,则生成可复现单测
265+ if self .generate_failed_tests :
266+ try :
267+ from .test_file_generator import generate_reproducible_test_file
268+
269+ error_info = {
270+ "error_type" : "paddle_error" ,
271+ "stage" : "forward" ,
272+ "need_backward" : self .need_check_grad (),
273+ }
274+ test_file_path = generate_reproducible_test_file (
275+ self .api_config ,
276+ error_info ,
277+ test_amp = self .test_amp ,
278+ target_device = "cpu" ,
279+ device_id = 0 ,
280+ test_instance = self ,
281+ )
282+ if test_file_path :
283+ print (f"[Generated test file] { test_file_path } " , flush = True )
284+ except Exception as e :
285+ print (f"[Error generating test file] { e } " , flush = True )
263286 return
264287
265288 # 6. Run API on target device (including forward and backward)
@@ -271,6 +294,28 @@ def test(self):
271294 flush = True ,
272295 )
273296 write_to_log ("paddle_error" , self .api_config .config )
297+ # 目标设备前向/反向执行失败,同样生成失败用例
298+ if self .generate_failed_tests :
299+ try :
300+ from .test_file_generator import generate_reproducible_test_file
301+
302+ error_info = {
303+ "error_type" : "paddle_error" ,
304+ "stage" : "forward" ,
305+ "need_backward" : self .need_check_grad (),
306+ }
307+ test_file_path = generate_reproducible_test_file (
308+ self .api_config ,
309+ error_info ,
310+ test_amp = self .test_amp ,
311+ target_device = target_device ,
312+ device_id = device_id ,
313+ test_instance = self ,
314+ )
315+ if test_file_path :
316+ print (f"[Generated test file] { test_file_path } " , flush = True )
317+ except Exception as e :
318+ print (f"[Error generating test file] { e } " , flush = True )
274319 return
275320
276321 # 7. Compare forward results
@@ -310,3 +355,46 @@ def test(self):
310355 else :
311356 print ("[Fail]" , self .api_config .config , flush = True )
312357 write_to_log ("accuracy_error" , self .api_config .config )
358+ # 生成可复现的单测文件
359+ if self .generate_failed_tests :
360+ try :
361+ from .test_file_generator import generate_reproducible_test_file
362+
363+ # 确定目标设备
364+ if self .check_xpu_available ():
365+ target_device = "xpu"
366+ device_id = self .xpu_device_id
367+ elif self .check_custom_device_available ():
368+ target_device = self .custom_device_type
369+ device_id = self .custom_device_id
370+ else :
371+ target_device = "cpu"
372+ device_id = 0
373+
374+ # 确定失败阶段
375+ stage = "unknown"
376+ if not forward_pass :
377+ stage = "forward"
378+ elif not backward_pass :
379+ stage = "backward"
380+
381+ error_info = {
382+ "error_type" : "accuracy_error" ,
383+ "stage" : stage ,
384+ "need_backward" : self .need_check_grad (),
385+ }
386+
387+ # 生成测试文件
388+ test_file_path = generate_reproducible_test_file (
389+ self .api_config ,
390+ error_info ,
391+ test_amp = self .test_amp ,
392+ target_device = target_device ,
393+ device_id = device_id ,
394+ test_instance = self ,
395+ )
396+
397+ if test_file_path :
398+ print (f"[Generated test file] { test_file_path } " , flush = True )
399+ except Exception as e :
400+ print (f"[Error generating test file] { e } " , flush = True )
0 commit comments