# This patch touches two modules. Its added and changed code is laid out below
# as formatted Python, grouped by target file. Untouched surrounding code from
# either module is not part of the patch and is only referenced in comments.

# =============================================================================
# t5/evaluation/metrics.py
# =============================================================================

import itertools
import re
import string
from typing import Any, Dict, Mapping, Optional, Sequence, Tuple, Union

from absl import logging
import editdistance
import flax
import jax.numpy as jnp
import numpy as np
import sacrebleu
import scipy.stats
import seqio
import sklearn.metrics
from t5.evaluation import qa_utils
import tensorflow.compat.v2 as tf

from rouge_score import rouge_scorer
from rouge_score import scoring


ModelOutputType = seqio.metrics.ModelOutputType
CollectingMetric = seqio.metrics.CollectingMetric


# ... the module's existing metric functions (`bleu` through `edit_distance`,
# including the `squad` function used below) are unchanged by this patch and
# are therefore not reproduced here ...


@flax.struct.dataclass
class ShardedSquad(seqio.metrics.Metric):
  """Implements SQuAD metrics, maximizing over answers per question.

  The metric is computed per batch by `from_model_output()` and batches are
  combined with `merge()`, which keeps a count-weighted mean of the per-batch
  scores.

  Attributes:
    f1: Mean F1 score over the `count` examples seen so far.
    em: Mean exact-match score over the `count` examples seen so far.
    count: Number of examples that contributed to `f1` and `em`.
    model_output_type: The kind of model output this metric consumes.
  """

  f1: float = 0.0
  em: float = 0.0
  count: int = 0
  model_output_type: ModelOutputType = ModelOutputType.PREDICTION

  @classmethod
  def empty(cls) -> "ShardedSquad":
    """Returns a metric that has not seen any example."""
    return cls(f1=0.0, em=0.0, count=0)

  @classmethod
  def from_model_output(
      cls,
      inputs: Sequence[Mapping[str, Any]],
      model_output: np.ndarray,
      features: Mapping[str, seqio.Feature],
      target_field_name: str = "targets",
      mask: Optional[np.ndarray] = None,
      indices_2d: Optional[np.ndarray] = None) -> "ShardedSquad":
    """Scores one batch of predictions against its reference answers.

    Args:
      inputs: The batch examples. Each example must provide an "answers" entry
        holding the reference answers for that question.
      model_output: One sequence of token ids per example. Each sequence is
        decoded with the target vocabulary.
      features: Output features; `features[target_field_name].vocabulary` is
        used to decode `model_output`.
      target_field_name: Key of the target feature in `features`.
      mask: Optional per-example mask. Examples whose mask entry is falsy
        (e.g. padding) are skipped. All examples are used when omitted.
      indices_2d: Unused; accepted for interface compatibility.

    Returns:
      A `ShardedSquad` holding the scores of the included examples and their
      number, or an empty metric when no example is included.
    """
    del indices_2d  # Not needed: the scores do not depend on example order.

    # Work with a host-side boolean mask so that iterating over it is cheap
    # regardless of the array type the caller passed in.
    if mask is None:
      included = np.ones(len(inputs), dtype=bool)
    else:
      included = np.asarray(mask).astype(bool)

    # Postprocesses the targets here.
    postprocessed_targets = [
        [tf.compat.as_text(answer) for answer in example["answers"]]
        for example, keep in zip(inputs, included)
        if keep
    ]

    # Decodes the predictions here.
    vocab = features[target_field_name].vocabulary
    predictions = [
        vocab.decode(tokens)
        for tokens, keep in zip(model_output, included)
        if keep
    ]

    # A fully masked (or empty) batch contributes nothing. Returning the empty
    # metric avoids scoring zero examples, which has no meaningful mean.
    if not postprocessed_targets and not predictions:
      return cls.empty()

    squad_result = squad(targets=postprocessed_targets, predictions=predictions)
    # The count is the number of examples that were actually scored, so that
    # `merge()` weights this batch by exactly what went into its mean.
    return cls(
        f1=squad_result["f1"], em=squad_result["em"], count=len(predictions))

  def merge(self, other: "ShardedSquad") -> "ShardedSquad":
    """Returns `ShardedSquad` that is the accumulation of `self` and `other`.

    Args:
      other: A `ShardedSquad` whose intermediate values should be accumulated
        onto the values of `self`.

    Returns:
      A new `ShardedSquad` whose `f1` and `em` are the count-weighted means of
      both operands and whose `count` is the sum of their counts. Merging two
      empty metrics yields scores of zero rather than dividing by zero.
    """
    count = self.count + other.count
    # Guard the division: when nothing has been accumulated the numerators are
    # zero as well, so dividing by one yields the expected zero scores.
    denominator = np.where(count == 0, 1, count)
    f1 = (self.f1 * self.count + other.f1 * other.count) / denominator
    em = (self.em * self.count + other.em * other.count) / denominator

    return type(self)(f1=f1, em=em, count=count)

  def compute(self):
    """Returns the accumulated scores as a dict with keys "f1" and "em"."""
    return {"f1": self.f1, "em": self.em}


@flax.struct.dataclass
class PassthroughSquad(CollectingMetric):
  """Implements SQuAD metrics, maximizing over answers per question.

  Unlike `ShardedSquad`, this metric only collects the model outputs and
  scores all of them at once in `actual_compute()`.
  """

  model_output_type: ModelOutputType = ModelOutputType.PREDICTION

  def actual_compute(self, task_dataset_as_numpy, task_output_features,
                     target_field_name: str = "targets"):
    """Scores all collected model outputs against the task dataset.

    Args:
      task_dataset_as_numpy: The evaluation examples. Each example must provide
        an "answers" entry holding the reference answers for that question.
      task_output_features: Output features;
        `task_output_features[target_field_name].vocabulary` is used to decode
        the collected model outputs.
      target_field_name: Key of the target feature in `task_output_features`.

    Returns:
      A tuple `(squad_scores, None)` where `squad_scores` is the result of
      `squad()` over all examples.

    Raises:
      ValueError: If the number of unmasked model outputs differs from the
        number of examples in `task_dataset_as_numpy`.
    """
    # Postprocesses the targets here.
    postprocessed_targets = [
        [tf.compat.as_text(answer) for answer in example["answers"]]
        for example in task_dataset_as_numpy
    ]

    # We process the model outputs here by the steps below.
    # Step 1: removes padded examples using mask.
    is_real_example = self.values["mask"] == 1
    indices_2d = self.values["indices_2d"][is_real_example]
    model_output = self.values["model_output"][is_real_example]
    if len(postprocessed_targets) != len(indices_2d):
      raise ValueError(
          "Number of targets and of unmasked model outputs must match; got "
          f"{len(postprocessed_targets)} targets and {len(indices_2d)} "
          "model outputs.")

    # Step 2: sorts the model outputs by 2d-indices, namely (shard_id,
    # index_within_shard) to align with targets. `np.lexsort` treats its last
    # key as the primary one, hence the shard id comes last.
    permutation = np.lexsort((indices_2d[:, 1], indices_2d[:, 0]))

    # Decodes the predictions here, in target order.
    target_vocab = task_output_features[target_field_name].vocabulary
    predictions = [
        target_vocab.decode(model_output[position]) for position in permutation
    ]

    return squad(postprocessed_targets, predictions), None


# =============================================================================
# t5/evaluation/metrics_test.py
# =============================================================================
# Module docstring of the test file: "Tests for t5.evaluation.metrics."

from unittest import mock

from absl.testing import absltest
import numpy as np
import seqio
import sklearn.metrics
from t5.evaluation import metrics
from t5.evaluation import test_utils


# ... the test module's existing test cases (up to and including
# `test_edit_distance`) are unchanged by this patch and are therefore not
# reproduced here ...


def mock_decode(self, ids):
  """Decodes `ids` into space-joined words, skipping the padding id 0.

  Used to replace `MockVocabulary.decode`; it inverts the vocabulary's
  `_encode_dict` (word -> id) to look up each id.
  """
  decode_dict = {v: k for k, v in self._encode_dict.items()}
  words = [decode_dict[token] for token in ids if token != 0]
  return " ".join(words)


class PassthroughSquadTest(test_utils.BaseMetricsTest):

  def test_same(self):
    ref = "this is a string"
    inputs = [{"answers": ["", ref]}, {"answers": [ref, ref]}]

    with mock.patch.object(
        seqio.test_utils.MockVocabulary, "decode", new=mock_decode):
      vocabulary = seqio.test_utils.MockVocabulary(
          {
              "this": 2,
              "is": 3,
              "a": 4,
              "string": 5
          }, vocab_size=10)

      model_output = np.array([[2, 3, 4, 5], [2, 3, 4, 5]])
      features = {"targets": seqio.Feature(vocabulary)}
      metric = metrics.PassthroughSquad.from_model_output(
          inputs, model_output, features)
      self.assertDictClose(metric.actual_compute(inputs, features)[0],
                           {"em": 100, "f1": 100})

  def test_different(self):
    ref = "this is a string"
    inputs = [{"answers": [ref, ref]}, {"answers": [ref, ref]}]

    with mock.patch.object(
        seqio.test_utils.MockVocabulary, "decode", new=mock_decode):
      vocabulary = seqio.test_utils.MockVocabulary(
          {
              "this": 2,
              "is": 3,
              "a": 4,
              "string": 5,
              "": 6
          }, vocab_size=10)

      model_output = np.array([[6], [6]])
      features = {"targets": seqio.Feature(vocabulary)}
      metric = metrics.PassthroughSquad.from_model_output(
          inputs, model_output, features)
      self.assertDictClose(metric.actual_compute(inputs, features)[0],
                           {"em": 0, "f1": 0})

  def test_big(self):
    inputs = [
        {"answers": ["big moose", "hippo"]},
        {"answers": ["correct1"]},
        {"answers": ["correct2.1", "correct2.2"]},
        {"answers": ["a", "b"]},
    ]

    with mock.patch.object(
        seqio.test_utils.MockVocabulary, "decode", new=mock_decode):
      vocabulary = seqio.test_utils.MockVocabulary(
          {
              "‘a": 2,
              "big": 3,
              "Moose!‘": 4,
              "wrong": 5,
              "correct2.2": 6,
              "c": 7
          }, vocab_size=10)

      model_output = np.array([[2, 3, 4], [5, 0, 0], [6, 0, 0], [7, 0, 0]])
      features = {"targets": seqio.Feature(vocabulary)}
      metric = metrics.PassthroughSquad.from_model_output(
          inputs, model_output, features)
      self.assertDictClose(metric.actual_compute(inputs, features)[0],
                           {"em": 25., "f1": 35.}, places=2)

  def test_small(self):
    inputs = [{"answers": ["abc abd", "$$$$"]}]

    with mock.patch.object(
        seqio.test_utils.MockVocabulary, "decode", new=mock_decode):
      vocabulary = seqio.test_utils.MockVocabulary({"abd": 2}, vocab_size=10)

      model_output = np.array([[2]])
      features = {"targets": seqio.Feature(vocabulary)}
      metric = metrics.PassthroughSquad.from_model_output(
          inputs, model_output, features)
      self.assertDictClose(metric.actual_compute(inputs, features)[0],
                           {"f1": 100 * 2.0 / 3.0, "em": 0.})


class ShardedSquadTest(test_utils.BaseMetricsTest):

  def test_same(self):
    ref = "this is a string"
    inputs = [{"answers": ["", ref]}, {"answers": [ref, ref]}]

    with mock.patch.object(
        seqio.test_utils.MockVocabulary, "decode", new=mock_decode):
      vocabulary = seqio.test_utils.MockVocabulary(
          {
              "this": 2,
              "is": 3,
              "a": 4,
              "string": 5
          }, vocab_size=10)

      model_output = np.array([[2, 3, 4, 5], [2, 3, 4, 5]])
      features = {"targets": seqio.Feature(vocabulary)}
      metric = metrics.ShardedSquad.from_model_output(
          inputs, model_output, features)
      self.assertDictClose(metric.compute(), {"em": 100, "f1": 100})

  def test_different(self):
    ref = "this is a string"
    inputs = [{"answers": [ref, ref]}, {"answers": [ref, ref]}]

    with mock.patch.object(
        seqio.test_utils.MockVocabulary, "decode", new=mock_decode):
      vocabulary = seqio.test_utils.MockVocabulary(
          {
              "this": 2,
              "is": 3,
              "a": 4,
              "string": 5,
              "": 6
          }, vocab_size=10)

      model_output = np.array([[6], [6]])
      features = {"targets": seqio.Feature(vocabulary)}
      metric = metrics.ShardedSquad.from_model_output(
          inputs, model_output, features)
      self.assertDictClose(metric.compute(), {"em": 0, "f1": 0})

  def test_big(self):
    inputs = [
        {"answers": ["big moose", "hippo"]},
        {"answers": ["correct1"]},
        {"answers": ["correct2.1", "correct2.2"]},
        {"answers": ["a", "b"]},
    ]

    with mock.patch.object(
        seqio.test_utils.MockVocabulary, "decode", new=mock_decode):
      vocabulary = seqio.test_utils.MockVocabulary(
          {
              "‘a": 2,
              "big": 3,
              "Moose!‘": 4,
              "wrong": 5,
              "correct2.2": 6,
              "c": 7
          }, vocab_size=10)

      model_output = np.array([[2, 3, 4], [5, 0, 0], [6, 0, 0], [7, 0, 0]])
      features = {"targets": seqio.Feature(vocabulary)}
      metric = metrics.ShardedSquad.from_model_output(
          inputs, model_output, features)
      self.assertDictClose(metric.compute(), {"em": 25., "f1": 35.}, places=2)

  def test_small(self):
    inputs = [{"answers": ["abc abd", "$$$$"]}]

    with mock.patch.object(
        seqio.test_utils.MockVocabulary, "decode", new=mock_decode):
      vocabulary = seqio.test_utils.MockVocabulary({"abd": 2}, vocab_size=10)

      model_output = np.array([[2]])
      features = {"targets": seqio.Feature(vocabulary)}
      metric = metrics.ShardedSquad.from_model_output(
          inputs, model_output, features)
      self.assertDictClose(metric.compute(), {"f1": 100 * 2.0 / 3.0, "em": 0.})

  def test_batch_update(self):
    inputs1 = [
        {"answers": ["big moose", "hippo"]},
        {"answers": ["correct1"]}
    ]
    inputs2 = [
        {"answers": ["correct2.1", "correct2.2"]},
        {"answers": ["a", "b"]},
    ]

    with mock.patch.object(
        seqio.test_utils.MockVocabulary, "decode", new=mock_decode):
      vocabulary = seqio.test_utils.MockVocabulary(
          {
              "‘a": 2,
              "big": 3,
              "Moose!‘": 4,
              "wrong": 5,
              "correct2.2": 6,
              "c": 7
          }, vocab_size=10)

      model_output1 = np.array([[2, 3, 4], [5, 0, 0]])
      model_output2 = np.array([[6], [7]])
      features = {"targets": seqio.Feature(vocabulary)}
      metric1 = metrics.ShardedSquad.from_model_output(
          inputs1, model_output1, features)
      metric2 = metrics.ShardedSquad.from_model_output(
          inputs2, model_output2, features)
      metric = metric1.merge(metric2)
      self.assertDictClose(metric.compute(), {"em": 25., "f1": 35.}, places=2)

  def test_merge_with_empty(self):
    # Merging metrics that have seen no example must not divide by zero.
    empty = metrics.ShardedSquad.empty()
    self.assertDictClose(empty.merge(empty).compute(), {"em": 0, "f1": 0})

    # An empty metric is the identity element of `merge`.
    seen = metrics.ShardedSquad(f1=50., em=25., count=4)
    self.assertDictClose(empty.merge(seen).compute(), {"em": 25., "f1": 50.})
    self.assertDictClose(seen.merge(empty).compute(), {"em": 25., "f1": 50.})

  def test_fully_masked_batch(self):
    inputs = [{"answers": ["abc"]}]

    with mock.patch.object(
        seqio.test_utils.MockVocabulary, "decode", new=mock_decode):
      vocabulary = seqio.test_utils.MockVocabulary({"abc": 2}, vocab_size=10)

      model_output = np.array([[2]])
      features = {"targets": seqio.Feature(vocabulary)}
      metric = metrics.ShardedSquad.from_model_output(
          inputs, model_output, features, mask=np.zeros((1,)))
      self.assertEqual(metric.count, 0)
      self.assertDictClose(metric.compute(), {"em": 0, "f1": 0})


if __name__ == "__main__":
  absltest.main()