diff --git a/docs/source/package_reference/main_classes.rst b/docs/source/package_reference/main_classes.rst index 458aa2cc857..12df09c08bb 100644 --- a/docs/source/package_reference/main_classes.rst +++ b/docs/source/package_reference/main_classes.rst @@ -35,6 +35,7 @@ The base class :class:`datasets.Dataset` implements a Dataset backed by an Apach from_csv, from_json, from_text, prepare_for_task, to_json, + align_labels_with_mapping .. autofunction:: datasets.concatenate_datasets diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index dfad4b883bc..060baf0447c 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -3097,6 +3097,42 @@ def add_item(self, item: dict, new_fingerprint: str): fingerprint=new_fingerprint, ) + def align_labels_with_mapping(self, label2id: Dict, label_column: str) -> "Dataset": + """Align the dataset's label ID and label name mapping to match an input :obj:`label2id` mapping. + This is useful when you want to ensure that a model's predicted labels are aligned with the dataset. + The alignment in done using the lowercase label names. + + Args: + label2id (:obj:`dict`): + The label name to ID mapping to align the dataset with. + label_column (:obj:`str`): + The column name of labels to align on. + + Example: + .. code-block:: python + + # dataset with mapping {'entailment': 0, 'neutral': 1, 'contradiction': 2} + ds = load_dataset("glue", "mnli", split="train") + # mapping to align with + label2id = {'CONTRADICTION': 0, 'NEUTRAL': 1, 'ENTAILMENT': 2} + ds_aligned = ds.align_labels_with_mapping(label2id, "label") + """ + features = self.features.copy() + int2str_function = features[label_column].int2str + # Sort input mapping by ID value to ensure the label names are aligned + label2id = dict(sorted(label2id.items(), key=lambda item: item[1])) + label_names = list(label2id.keys()) + features[label_column] = ClassLabel(num_classes=len(label_names), names=label_names) + # Some label mappings use uppercase label names so we lowercase them during alignment + label2id = {k.lower(): v for k, v in label2id.items()} + + def process_label_ids(batch): + dset_label_names = [int2str_function(label_id).lower() for label_id in batch[label_column]] + batch[label_column] = [label2id[label_name] for label_name in dset_label_names] + return batch + + return self.map(process_label_ids, features=features, batched=True) + def concatenate_datasets( dsets: List[Dataset], diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py index 9b36a5c2da2..78ad8681a7f 100644 --- a/tests/test_arrow_dataset.py +++ b/tests/test_arrow_dataset.py @@ -2202,6 +2202,25 @@ def test_task_templates_empty_after_preparation(self, in_memory): with dset.prepare_for_task(task="text-classification") as dset: self.assertIsNone(dset.info.task_templates) + def test_align_labels_with_mapping(self, in_memory): + features = Features( + { + "input_text": Value("string"), + "input_labels": ClassLabel(num_classes=3, names=["entailment", "neutral", "contradiction"]), + } + ) + data = {"input_text": ["a", "a", "b", "b", "c", "c"], "input_labels": [0, 0, 1, 1, 2, 2]} + label2id = {"CONTRADICTION": 0, "ENTAILMENT": 2, "NEUTRAL": 1} + id2label = {v: k for k, v in label2id.items()} + expected_labels = [2, 2, 1, 1, 0, 0] + expected_label_names = [id2label[idx] for idx in expected_labels] + with tempfile.TemporaryDirectory() as tmp_dir, Dataset.from_dict(data, features=features) as dset: + with self._to(in_memory, tmp_dir, dset) as dset: + with dset.align_labels_with_mapping(label2id, "input_labels") as dset: + self.assertListEqual(expected_labels, dset["input_labels"]) + aligned_label_names = [dset.features["input_labels"].int2str(idx) for idx in dset["input_labels"]] + self.assertListEqual(expected_label_names, aligned_label_names) + class MiscellaneousDatasetTest(TestCase): def test_from_pandas(self):