Skip to content
74 changes: 64 additions & 10 deletions easybuild/easyblocks/p/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
from easybuild.tools.filetools import symlink, apply_regex_substitutions
from easybuild.tools.modules import get_software_root, get_software_version
from easybuild.tools.systemtools import POWER, get_cpu_architecture
from easybuild.tools.utilities import nub


class EB_PyTorch(PythonPackage):
Expand Down Expand Up @@ -264,17 +263,68 @@ def test_step(self):
(tests_out, tests_ec) = super(EB_PyTorch, self).test_step(return_output_ec=True)

ran_tests_hits = re.findall(r"^Ran (?P<test_cnt>[0-9]+) tests in", tests_out, re.M)
test_cnt = sum(int(hit) for hit in ran_tests_hits)

failed_tests = nub(re.findall(r"^(?P<test_name>.*) failed!(?: Received signal: \w+)?\s*$", tests_out, re.M))
failed_test_cnt = len(failed_tests)

if failed_test_cnt:
test_cnt = 0
for hit in ran_tests_hits:
test_cnt += int(hit)

# Get matches to create clear summary report, greps for patterns like:
# FAILED (errors=10, skipped=190, expected failures=6)
# test_fx failed!
regex = r"^Ran (?P<test_cnt>[0-9]+) tests.*$\n\nFAILED \((?P<failure_summary>.*)\)$\n(?:^(?:(?!failed!).)*$\n)*(?P<failed_test_suite_name>.*) failed!$" # noqa: E501
summary_matches = re.findall(regex, tests_out, re.M)

# Get matches to create clear summary report, greps for patterns like:
# ===================== 2 failed, 128 passed, 2 skipped, 2 warnings in 3.43s =====================
regex = r"^=+ (?P<failure_summary>.*) in [0-9]+\.*[0-9]*[a-zA-Z]* =+$\n(?P<failed_test_suite_name>.*) failed!$"
summary_matches_pattern2 = re.findall(regex, tests_out, re.M)

# Count failures and errors
def get_count_for_pattern(regex, text):
    """Return the integer value captured by a regular expression in the given text.

    :param regex: regular expression, expected to contain a single capturing group that matches digits
    :param text: text to search; the pattern is applied with re.M
    :return: the captured value as an int if there is exactly one match;
             0 if there is no match, or if there are multiple matches (a warning is printed in that case)
    """
    matches = re.findall(regex, text, re.M)
    if len(matches) == 1:
        return int(matches[0])
    if len(matches) > 1:
        # Shouldn't happen, but means something went wrong with the regular expressions.
        # Throw warning, as the build might be fine, no need to error on this.
        warn_msg = "Error in counting the number of test failures in the output of the PyTorch test suite.\n"
        warn_msg += "Please check the EasyBuild log to verify the number of failures (if any) was acceptable."
        print_warning(warn_msg)
    # Always return an int: callers add the result to a running total, so falling through
    # with an implicit None (as happened for multiple matches) would raise a TypeError there.
    return 0

failure_cnt = 0
error_cnt = 0
# Loop over first pattern to count failures/errors:
for summary in summary_matches:
failures = get_count_for_pattern(r"^.*(?<!expected )failures=(?P<failures>[0-9]+).*$", summary[1])
failure_cnt += failures
errs = get_count_for_pattern(r"^.*errors=(?P<errors>[0-9]+).*$", summary[1])
error_cnt += errs

# Loop over the second pattern to count failures/errors
for summary in summary_matches_pattern2:
failures = get_count_for_pattern(r"^.*(?P<failures>[0-9]+) failed.*$", summary[0])
failure_cnt += failures
errs = get_count_for_pattern(r"^.*(?P<errors>[0-9]+) error.*$", summary[0])
error_cnt += errs

# Calculate total number of unsuccessful tests
failed_test_cnt = failure_cnt + error_cnt

if failed_test_cnt > 0:
max_failed_tests = self.cfg['max_failed_tests']

test_or_tests = 'tests' if failed_test_cnt > 1 else 'test'
msg = "%d %s (out of %d) failed:\n" % (failed_test_cnt, test_or_tests, test_cnt)
msg += '\n'.join('* %s' % t for t in sorted(failed_tests))
failure_or_failures = 'failures' if failure_cnt > 1 else 'failure'
error_or_errors = 'errors' if error_cnt > 1 else 'error'
msg = "%d test %s, %d test %s (out of %d):\n" % (
failure_cnt, failure_or_failures, error_cnt, error_or_errors, test_cnt
)
for summary in summary_matches_pattern2:
msg += "{test_suite} ({failure_summary})\n".format(test_suite=summary[1], failure_summary=summary[0])
for summary in summary_matches:
msg += "{test_suite} ({total} total tests, {failure_summary})\n".format(
test_suite=summary[2], total=summary[0], failure_summary=summary[1]
)

if max_failed_tests == 0:
raise EasyBuildError(msg)
Expand All @@ -287,8 +337,12 @@ def test_step(self):
"are known to be flaky, or do not affect your intended usage of PyTorch.",
"In case of doubt, reach out to the EasyBuild community (via GitHub, Slack, or mailing list).",
])
# Print to console, the user should really be aware that we are accepting failing tests here...
print_warning(msg)

# Also log this warning in the file log
self.log.warning(msg)

if failed_test_cnt > max_failed_tests:
raise EasyBuildError("Too many failed tests (%d), maximum allowed is %d",
failed_test_cnt, max_failed_tests)
Expand Down