Skip to content
5 changes: 5 additions & 0 deletions metrics/accuracy/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ def _info(self):
inputs_description=_KWARGS_DESCRIPTION,
features=datasets.Features(
{
"predictions": datasets.Sequence(datasets.Value("int32")),
"references": datasets.Sequence(datasets.Value("int32")),
}
if self.config_name == "multilabel"
else {
"predictions": datasets.Value("int32"),
"references": datasets.Value("int32"),
}
Expand Down
5 changes: 5 additions & 0 deletions metrics/f1/f1.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,11 @@ def _info(self):
inputs_description=_KWARGS_DESCRIPTION,
features=datasets.Features(
{
"predictions": datasets.Sequence(datasets.Value("int32")),
"references": datasets.Sequence(datasets.Value("int32")),
}
if self.config_name == "multilabel"
else {
"predictions": datasets.Value("int32"),
"references": datasets.Value("int32"),
}
Expand Down
5 changes: 5 additions & 0 deletions metrics/precision/precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,11 @@ def _info(self):
inputs_description=_KWARGS_DESCRIPTION,
features=datasets.Features(
{
"predictions": datasets.Sequence(datasets.Value("int32")),
"references": datasets.Sequence(datasets.Value("int32")),
}
if self.config_name == "multilabel"
else {
"predictions": datasets.Value("int32"),
"references": datasets.Value("int32"),
}
Expand Down
5 changes: 5 additions & 0 deletions metrics/recall/recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,11 @@ def _info(self):
inputs_description=_KWARGS_DESCRIPTION,
features=datasets.Features(
{
"predictions": datasets.Sequence(datasets.Value("int32")),
"references": datasets.Sequence(datasets.Value("int32")),
}
if self.config_name == "multilabel"
else {
"predictions": datasets.Value("int32"),
"references": datasets.Value("int32"),
}
Expand Down
45 changes: 44 additions & 1 deletion tests/test_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
from multiprocessing import Pool
from unittest import TestCase

from datasets.features import Features, Value
import pytest

from datasets.features import Features, Sequence, Value
from datasets.metric import Metric, MetricInfo

from .utils import require_tf, require_torch
Expand Down Expand Up @@ -469,3 +471,44 @@ def test_input_tf(self):
metric.add(prediction=pred, reference=ref)
self.assertDictEqual(expected_results, metric.compute())
del metric


class MetricWithMultiLabel(Metric):
def _info(self):
return MetricInfo(
description="dummy metric for tests",
citation="insert citation here",
features=Features(
{"predictions": Sequence(Value("int64")), "references": Sequence(Value("int64"))}
if self.config_name == "multilabel"
else {"predictions": Value("int64"), "references": Value("int64")}
),
)

def _compute(self, predictions=None, references=None):
return (
{
"accuracy": sum(i == j for i, j in zip(predictions, references)) / len(predictions),
}
if predictions
else {}
)


@pytest.mark.parametrize(
"config_name, predictions, references, expected",
[
(None, [1, 2, 3, 4], [1, 2, 4, 3], 0.5), # Multiclass: Value("int64")
(
"multilabel",
[[1, 0], [1, 0], [1, 0], [1, 0]],
[[1, 0], [0, 1], [1, 1], [0, 0]],
0.25,
), # Multilabel: Sequence(Value("int64"))
],
)
def test_metric_with_multilabel(config_name, predictions, references, expected, tmp_path):
cache_dir = tmp_path / "cache"
metric = MetricWithMultiLabel(config_name, cache_dir=cache_dir)
results = metric.compute(predictions=predictions, references=references)
assert results["accuracy"] == expected