diff --git a/metrics/accuracy/accuracy.py b/metrics/accuracy/accuracy.py index c24edae16de..c050338e331 100644 --- a/metrics/accuracy/accuracy.py +++ b/metrics/accuracy/accuracy.py @@ -69,6 +69,11 @@ def _info(self): inputs_description=_KWARGS_DESCRIPTION, features=datasets.Features( { + "predictions": datasets.Sequence(datasets.Value("int32")), + "references": datasets.Sequence(datasets.Value("int32")), + } + if self.config_name == "multilabel" + else { "predictions": datasets.Value("int32"), "references": datasets.Value("int32"), } diff --git a/metrics/f1/f1.py b/metrics/f1/f1.py index eae87799780..d95f8c6553d 100644 --- a/metrics/f1/f1.py +++ b/metrics/f1/f1.py @@ -85,6 +85,11 @@ def _info(self): inputs_description=_KWARGS_DESCRIPTION, features=datasets.Features( { + "predictions": datasets.Sequence(datasets.Value("int32")), + "references": datasets.Sequence(datasets.Value("int32")), + } + if self.config_name == "multilabel" + else { "predictions": datasets.Value("int32"), "references": datasets.Value("int32"), } diff --git a/metrics/precision/precision.py b/metrics/precision/precision.py index a3b0ddb54ef..a13ead33ce3 100644 --- a/metrics/precision/precision.py +++ b/metrics/precision/precision.py @@ -87,6 +87,11 @@ def _info(self): inputs_description=_KWARGS_DESCRIPTION, features=datasets.Features( { + "predictions": datasets.Sequence(datasets.Value("int32")), + "references": datasets.Sequence(datasets.Value("int32")), + } + if self.config_name == "multilabel" + else { "predictions": datasets.Value("int32"), "references": datasets.Value("int32"), } diff --git a/metrics/recall/recall.py b/metrics/recall/recall.py index fed703537aa..30c77a95464 100644 --- a/metrics/recall/recall.py +++ b/metrics/recall/recall.py @@ -87,6 +87,11 @@ def _info(self): inputs_description=_KWARGS_DESCRIPTION, features=datasets.Features( { + "predictions": datasets.Sequence(datasets.Value("int32")), + "references": datasets.Sequence(datasets.Value("int32")), + } + if self.config_name == "multilabel" + else { "predictions": datasets.Value("int32"), "references": datasets.Value("int32"), } diff --git a/tests/test_metric.py b/tests/test_metric.py index d48f57ba348..819f35ccc25 100644 --- a/tests/test_metric.py +++ b/tests/test_metric.py @@ -5,7 +5,9 @@ from multiprocessing import Pool from unittest import TestCase -from datasets.features import Features, Value +import pytest + +from datasets.features import Features, Sequence, Value from datasets.metric import Metric, MetricInfo from .utils import require_tf, require_torch @@ -469,3 +471,44 @@ def test_input_tf(self): metric.add(prediction=pred, reference=ref) self.assertDictEqual(expected_results, metric.compute()) del metric + + +class MetricWithMultiLabel(Metric): + def _info(self): + return MetricInfo( + description="dummy metric for tests", + citation="insert citation here", + features=Features( + {"predictions": Sequence(Value("int64")), "references": Sequence(Value("int64"))} + if self.config_name == "multilabel" + else {"predictions": Value("int64"), "references": Value("int64")} + ), + ) + + def _compute(self, predictions=None, references=None): + return ( + { + "accuracy": sum(i == j for i, j in zip(predictions, references)) / len(predictions), + } + if predictions + else {} + ) + + +@pytest.mark.parametrize( + "config_name, predictions, references, expected", + [ + (None, [1, 2, 3, 4], [1, 2, 4, 3], 0.5), # Multiclass: Value("int64") + ( + "multilabel", + [[1, 0], [1, 0], [1, 0], [1, 0]], + [[1, 0], [0, 1], [1, 1], [0, 0]], + 0.25, + ), # Multilabel: Sequence(Value("int64")) + ], +) +def test_metric_with_multilabel(config_name, predictions, references, expected, tmp_path): + cache_dir = tmp_path / "cache" + metric = MetricWithMultiLabel(config_name, cache_dir=cache_dir) + results = metric.compute(predictions=predictions, references=references) + assert results["accuracy"] == expected