Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 62 additions & 71 deletions tests/ignite/contrib/metrics/test_average_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,91 +63,82 @@ def test_check_shape():
ap._check_shape((torch.rand(4, 3), torch.rand(4, 3, 1)))


def test_binary_and_multilabel_inputs():
@pytest.fixture(params=[item for item in range(8)])
def test_data_binary_and_multilabel(request):
return [
# Binary input data of shape (N,) or (N, 1)
(torch.randint(0, 2, size=(50,)).long(), torch.randint(0, 2, size=(50,)).long(), 1),
(torch.randint(0, 2, size=(50, 1)).long(), torch.randint(0, 2, size=(50, 1)).long(), 1),
# updated batches
(torch.randint(0, 2, size=(50,)).long(), torch.randint(0, 2, size=(50,)).long(), 16),
(torch.randint(0, 2, size=(50, 1)).long(), torch.randint(0, 2, size=(50, 1)).long(), 16),
# Binary input data of shape (N, L)
(torch.randint(0, 2, size=(50, 4)).long(), torch.randint(0, 2, size=(50, 4)).long(), 1),
(torch.randint(0, 2, size=(50, 7)).long(), torch.randint(0, 2, size=(50, 7)).long(), 1),
# updated batches
(torch.randint(0, 2, size=(50, 4)).long(), torch.randint(0, 2, size=(50, 4)).long(), 16),
(torch.randint(0, 2, size=(50, 7)).long(), torch.randint(0, 2, size=(50, 7)).long(), 16),
][request.param]


@pytest.mark.parametrize("n_times", range(5))
def test_binary_and_multilabel_inputs(n_times, test_data_binary_and_multilabel):
y_pred, y, batch_size = test_data_binary_and_multilabel
ap = AveragePrecision()
ap.reset()
if batch_size > 1:
n_iters = y.shape[0] // batch_size + 1
for i in range(n_iters):
idx = i * batch_size
ap.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
else:
ap.update((y_pred, y))

def _test(y_pred, y, batch_size):
ap.reset()
if batch_size > 1:
n_iters = y.shape[0] // batch_size + 1
for i in range(n_iters):
idx = i * batch_size
ap.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
else:
ap.update((y_pred, y))

np_y = y.numpy()
np_y_pred = y_pred.numpy()

res = ap.compute()
assert isinstance(res, float)
assert average_precision_score(np_y, np_y_pred) == pytest.approx(res)

def get_test_cases():
test_cases = [
# Binary input data of shape (N,) or (N, 1)
(torch.randint(0, 2, size=(50,)).long(), torch.randint(0, 2, size=(50,)).long(), 1),
(torch.randint(0, 2, size=(50, 1)).long(), torch.randint(0, 2, size=(50, 1)).long(), 1),
# updated batches
(torch.randint(0, 2, size=(50,)).long(), torch.randint(0, 2, size=(50,)).long(), 16),
(torch.randint(0, 2, size=(50, 1)).long(), torch.randint(0, 2, size=(50, 1)).long(), 16),
# Binary input data of shape (N, L)
(torch.randint(0, 2, size=(50, 4)).long(), torch.randint(0, 2, size=(50, 4)).long(), 1),
(torch.randint(0, 2, size=(50, 7)).long(), torch.randint(0, 2, size=(50, 7)).long(), 1),
# updated batches
(torch.randint(0, 2, size=(50, 4)).long(), torch.randint(0, 2, size=(50, 4)).long(), 16),
(torch.randint(0, 2, size=(50, 7)).long(), torch.randint(0, 2, size=(50, 7)).long(), 16),
]
np_y = y.numpy()
np_y_pred = y_pred.numpy()

return test_cases
res = ap.compute()
assert isinstance(res, float)
assert average_precision_score(np_y, np_y_pred) == pytest.approx(res)

for _ in range(5):
# check multiple random inputs as random exact occurencies are rare
test_cases = get_test_cases()
for y_pred, y, batch_size in test_cases:
_test(y_pred, y, batch_size)

@pytest.fixture(params=[item for item in range(4)])
def test_data_integration_binary_and_multilabel(request):
return [
# Binary input data of shape (N,) or (N, 1)
(torch.randint(0, 2, size=(100,)).long(), torch.randint(0, 2, size=(100,)).long(), 10),
(torch.randint(0, 2, size=(100, 1)).long(), torch.randint(0, 2, size=(100, 1)).long(), 10),
# Binary input data of shape (N, L)
(torch.randint(0, 2, size=(100, 3)).long(), torch.randint(0, 2, size=(100, 3)).long(), 10),
(torch.randint(0, 2, size=(100, 4)).long(), torch.randint(0, 2, size=(100, 4)).long(), 10),
][request.param]

def test_integration_binary_and_mulitlabel_inputs():
def _test(y_pred, y, batch_size):
def update_fn(engine, batch):
idx = (engine.state.iteration - 1) * batch_size
y_true_batch = np_y[idx : idx + batch_size]
y_pred_batch = np_y_pred[idx : idx + batch_size]
return torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch)

engine = Engine(update_fn)
@pytest.mark.parametrize("n_times", range(5))
def test_integration_binary_and_mulitlabel_inputs(n_times, test_data_integration_binary_and_multilabel):
y_pred, y, batch_size = test_data_integration_binary_and_multilabel

ap_metric = AveragePrecision()
ap_metric.attach(engine, "ap")
def update_fn(engine, batch):
idx = (engine.state.iteration - 1) * batch_size
y_true_batch = np_y[idx : idx + batch_size]
y_pred_batch = np_y_pred[idx : idx + batch_size]
return torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch)

np_y = y.numpy()
np_y_pred = y_pred.numpy()
engine = Engine(update_fn)

np_ap = average_precision_score(np_y, np_y_pred)
ap_metric = AveragePrecision()
ap_metric.attach(engine, "ap")

data = list(range(y_pred.shape[0] // batch_size))
ap = engine.run(data, max_epochs=1).metrics["ap"]
np_y = y.numpy()
np_y_pred = y_pred.numpy()

assert isinstance(ap, float)
assert np_ap == pytest.approx(ap)
np_ap = average_precision_score(np_y, np_y_pred)

def get_test_cases():
test_cases = [
# Binary input data of shape (N,) or (N, 1)
(torch.randint(0, 2, size=(100,)).long(), torch.randint(0, 2, size=(100,)).long(), 10),
(torch.randint(0, 2, size=(100, 1)).long(), torch.randint(0, 2, size=(100, 1)).long(), 10),
# Binary input data of shape (N, L)
(torch.randint(0, 2, size=(100, 3)).long(), torch.randint(0, 2, size=(100, 3)).long(), 10),
(torch.randint(0, 2, size=(100, 4)).long(), torch.randint(0, 2, size=(100, 4)).long(), 10),
]
return test_cases
data = list(range(y_pred.shape[0] // batch_size))
ap = engine.run(data, max_epochs=1).metrics["ap"]

for _ in range(5):
# check multiple random inputs as random exact occurencies are rare
test_cases = get_test_cases()
for y_pred, y, batch_size in test_cases:
_test(y_pred, y, batch_size)
assert isinstance(ap, float)
assert np_ap == pytest.approx(ap)


def _test_distrib_binary_and_multilabel_inputs(device):
Expand Down
118 changes: 55 additions & 63 deletions tests/ignite/contrib/metrics/test_cohen_kappa.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,43 +71,38 @@ def test_cohen_kappa_wrong_weights_type():
ck = CohenKappa(weights="dd")


@pytest.mark.parametrize("weights", [None, "linear", "quadratic"])
def test_binary_input(weights):
ck = CohenKappa(weights)

def _test(y_pred, y, batch_size):
ck.reset()
if batch_size > 1:
n_iters = y.shape[0] // batch_size + 1
for i in range(n_iters):
idx = i * batch_size
ck.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
else:
ck.update((y_pred, y))
@pytest.fixture(params=range(4))
def test_data_binary(request):
return [
# Binary input data of shape (N,) or (N, 1)
(torch.randint(0, 2, size=(10,)).long(), torch.randint(0, 2, size=(10,)).long(), 1),
(torch.randint(0, 2, size=(10, 1)).long(), torch.randint(0, 2, size=(10, 1)).long(), 1),
# updated batches
(torch.randint(0, 2, size=(50,)).long(), torch.randint(0, 2, size=(50,)).long(), 16),
(torch.randint(0, 2, size=(50, 1)).long(), torch.randint(0, 2, size=(50, 1)).long(), 16),
][request.param]

np_y = y.numpy()
np_y_pred = y_pred.numpy()

res = ck.compute()
assert isinstance(res, float)
assert cohen_kappa_score(np_y, np_y_pred, weights=weights) == pytest.approx(res)
@pytest.mark.parametrize("n_times", range(5))
@pytest.mark.parametrize("weights", [None, "linear", "quadratic"])
def test_binary_input(n_times, weights, test_data_binary):
y_pred, y, batch_size = test_data_binary
ck = CohenKappa(weights)
ck.reset()
if batch_size > 1:
n_iters = y.shape[0] // batch_size + 1
for i in range(n_iters):
idx = i * batch_size
ck.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
else:
ck.update((y_pred, y))

def get_test_cases():
test_cases = [
# Binary input data of shape (N,) or (N, 1)
(torch.randint(0, 2, size=(10,)).long(), torch.randint(0, 2, size=(10,)).long(), 1),
(torch.randint(0, 2, size=(10, 1)).long(), torch.randint(0, 2, size=(10, 1)).long(), 1),
# updated batches
(torch.randint(0, 2, size=(50,)).long(), torch.randint(0, 2, size=(50,)).long(), 16),
(torch.randint(0, 2, size=(50, 1)).long(), torch.randint(0, 2, size=(50, 1)).long(), 16),
]
return test_cases
np_y = y.numpy()
np_y_pred = y_pred.numpy()

for _ in range(5):
# check multiple random inputs as random exact occurencies are rare
test_cases = get_test_cases()
for y_pred, y, batch_size in test_cases:
_test(y_pred, y, batch_size)
res = ck.compute()
assert isinstance(res, float)
assert cohen_kappa_score(np_y, np_y_pred, weights=weights) == pytest.approx(res)


def test_multilabel_inputs():
Expand All @@ -129,44 +124,41 @@ def test_multilabel_inputs():
ck.compute()


@pytest.mark.parametrize("weights", [None, "linear", "quadratic"])
def test_integration_binary_input(weights):
def _test(y_pred, y, batch_size):
def update_fn(engine, batch):
idx = (engine.state.iteration - 1) * batch_size
y_true_batch = np_y[idx : idx + batch_size]
y_pred_batch = np_y_pred[idx : idx + batch_size]
return torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch)
@pytest.fixture(params=range(2))
def test_data_integration_binary(request):
return [
# Binary input data of shape (N,) or (N, 1)
(torch.randint(0, 2, size=(50,)).long(), torch.randint(0, 2, size=(50,)).long(), 10),
(torch.randint(0, 2, size=(50, 1)).long(), torch.randint(0, 2, size=(50, 1)).long(), 10),
][request.param]

engine = Engine(update_fn)

ck_metric = CohenKappa(weights=weights)
ck_metric.attach(engine, "ck")
@pytest.mark.parametrize("n_times", range(5))
@pytest.mark.parametrize("weights", [None, "linear", "quadratic"])
def test_integration_binary_input(n_times, weights, test_data_integration_binary):
y_pred, y, batch_size = test_data_integration_binary

np_y = y.numpy()
np_y_pred = y_pred.numpy()
def update_fn(engine, batch):
idx = (engine.state.iteration - 1) * batch_size
y_true_batch = np_y[idx : idx + batch_size]
y_pred_batch = np_y_pred[idx : idx + batch_size]
return torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch)

np_ck = cohen_kappa_score(np_y, np_y_pred, weights=weights)
engine = Engine(update_fn)

data = list(range(y_pred.shape[0] // batch_size))
ck = engine.run(data, max_epochs=1).metrics["ck"]
ck_metric = CohenKappa(weights=weights)
ck_metric.attach(engine, "ck")

assert isinstance(ck, float)
assert np_ck == pytest.approx(ck)
np_y = y.numpy()
np_y_pred = y_pred.numpy()

def get_test_cases():
test_cases = [
# Binary input data of shape (N,) or (N, 1)
(torch.randint(0, 2, size=(50,)).long(), torch.randint(0, 2, size=(50,)).long(), 10),
(torch.randint(0, 2, size=(50, 1)).long(), torch.randint(0, 2, size=(50, 1)).long(), 10),
]
return test_cases
np_ck = cohen_kappa_score(np_y, np_y_pred, weights=weights)

for _ in range(5):
# check multiple random inputs as random exact occurencies are rare
test_cases = get_test_cases()
for y_pred, y, batch_size in test_cases:
_test(y_pred, y, batch_size)
data = list(range(y_pred.shape[0] // batch_size))
ck = engine.run(data, max_epochs=1).metrics["ck"]

assert isinstance(ck, float)
assert np_ck == pytest.approx(ck)


def _test_distrib_binary_input(device):
Expand Down
Loading