def get_count_for_pattern(regex, text):
    """
    Return the integer captured by the one match of a regular expression in a piece of text.

    :param regex: regular expression, applied with re.findall in multi-line mode (re.M);
                  the single match is passed to int(), so the pattern presumably holds
                  exactly one capturing group that matches digits -- verify against callers
    :param text: text to search
    :return: the matched number as an int; 0 if there is no match, and also 0 (after
             printing a warning) if there is more than one match
    """
    matches = re.findall(regex, text, re.M)
    if len(matches) == 1:
        return int(matches[0])

    if len(matches) > 1:
        # Shouldn't happen, but means something went wrong with the regular expressions.
        # Throw warning, as the build might be fine, no need to error on this.
        warn_msg = "Error in counting the number of test failures in the output of the PyTorch test suite.\n"
        warn_msg += "Please check the EasyBuild log to verify the number of failures (if any) was acceptable."
        print_warning(warn_msg)

    # No match, or an ambiguous one: always hand back an int. The result is added to a
    # running total by the callers, and falling through with None would raise TypeError.
    return 0
affect your intended usage of PyTorch.", "In case of doubt, reach out to the EasyBuild community (via GitHub, Slack, or mailing list).", ]) + # Print to console, the user should really be aware that we are accepting failing tests here... print_warning(msg) + # Also log this warning in the file log + self.log.warning(msg) + if failed_test_cnt > max_failed_tests: raise EasyBuildError("Too many failed tests (%d), maximum allowed is %d", failed_test_cnt, max_failed_tests)