From 0b594894b42c1463a049cd1e43771e3c5fcacc6a Mon Sep 17 00:00:00 2001 From: Ojasv Kamal Date: Fri, 22 Jul 2022 11:09:59 +0530 Subject: [PATCH] Parametrized tests for test_mean_squared_error.py --- .../ignite/metrics/test_mean_squared_error.py | 67 +++++++++---------- 1 file changed, 31 insertions(+), 36 deletions(-) diff --git a/tests/ignite/metrics/test_mean_squared_error.py b/tests/ignite/metrics/test_mean_squared_error.py index a1df3fb3a5cf..434b4082669b 100644 --- a/tests/ignite/metrics/test_mean_squared_error.py +++ b/tests/ignite/metrics/test_mean_squared_error.py @@ -17,45 +17,40 @@ def test_zero_sample(): mse.compute() -def test_compute(): +@pytest.fixture(params=[item for item in range(4)]) +def test_case(request): + return [ + (torch.randint(0, 10, size=(100, 1)), torch.randint(0, 10, size=(100, 1)), 1), + (torch.randint(-20, 20, size=(100, 5)), torch.randint(-20, 20, size=(100, 5)), 1), + # updated batches + (torch.randint(0, 10, size=(100, 1)), torch.randint(0, 10, size=(100, 1)), 16), + (torch.randint(-20, 20, size=(100, 5)), torch.randint(-20, 20, size=(100, 5)), 16), + ][request.param] + + +@pytest.mark.parametrize("n_times", range(5)) +def test_compute(n_times, test_case): mse = MeanSquaredError() - def _test(y_pred, y, batch_size): - mse.reset() - if batch_size > 1: - n_iters = y.shape[0] // batch_size + 1 - for i in range(n_iters): - idx = i * batch_size - mse.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size])) - else: - mse.update((y_pred, y)) - - np_y = y.numpy() - np_y_pred = y_pred.numpy() - - np_res = np.power((np_y - np_y_pred), 2.0).sum() / np_y.shape[0] - - assert isinstance(mse.compute(), float) - assert mse.compute() == np_res - - def get_test_cases(): - - test_cases = [ - (torch.randint(0, 10, size=(100, 1)), torch.randint(0, 10, size=(100, 1)), 1), - (torch.randint(-20, 20, size=(100, 5)), torch.randint(-20, 20, size=(100, 5)), 1), - # updated batches - (torch.randint(0, 10, size=(100, 1)), torch.randint(0, 10, size=(100, 1)), 16), - (torch.randint(-20, 20, size=(100, 5)), torch.randint(-20, 20, size=(100, 5)), 16), - ] - - return test_cases - - for _ in range(5): - # check multiple random inputs as random exact occurencies are rare - test_cases = get_test_cases() - for y_pred, y, batch_size in test_cases: - _test(y_pred, y, batch_size) + y_pred, y, batch_size = test_case + + mse.reset() + if batch_size > 1: + n_iters = y.shape[0] // batch_size + 1 + for i in range(n_iters): + idx = i * batch_size + mse.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size])) + else: + mse.update((y_pred, y)) + + np_y = y.numpy() + np_y_pred = y_pred.numpy() + + np_res = np.power((np_y - np_y_pred), 2.0).sum() / np_y.shape[0] + + assert isinstance(mse.compute(), float) + assert mse.compute() == np_res def _test_distrib_integration(device, tol=1e-6):