Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/package_reference/main_classes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ The base class :class:`datasets.Dataset` implements a Dataset backed by an Apach
from_csv, from_json, from_text,
prepare_for_task,
to_json,
align_labels_with_mapping

.. autofunction:: datasets.concatenate_datasets

Expand Down
35 changes: 35 additions & 0 deletions src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3097,6 +3097,41 @@ def add_item(self, item: dict, new_fingerprint: str):
fingerprint=new_fingerprint,
)

def align_labels_with_mapping(self, label2id: Dict, label_column: str) -> "Dataset":
    """Align the dataset's label ID and label name mappings to match an input :obj:`label2id` mapping.
    This is useful when you want to ensure that a model's predicted labels are aligned with the evaluation dataset.

    Label names are compared case-insensitively: both the dataset's label names and the keys of
    :obj:`label2id` are lowercased before lookup. Null labels (``None``) are left untouched.

    Args:
        label2id (:obj:`dict`):
            The label name to ID mapping to align the dataset with. The IDs must be exactly
            ``0, 1, ..., len(label2id) - 1`` because the new :class:`ClassLabel` feature is built
            from the label names ordered by ID.
        label_column (:obj:`str`):
            The column name of labels to align on. Must be a :class:`ClassLabel` column.

    Returns:
        :class:`Dataset`: the dataset returned by :meth:`map`, whose ``label_column`` holds the
        re-mapped IDs and whose feature is a :class:`ClassLabel` named after :obj:`label2id`.

    Raises:
        ValueError: if ``label_column`` is not a column of the dataset, is not a
            :class:`ClassLabel`, or if the IDs in :obj:`label2id` are not ``0 .. n - 1``.
        KeyError: during mapping, if a label name present in the data has no (lowercased)
            counterpart in :obj:`label2id`.

    Example:
        .. code-block:: python

            # dataset with mapping {'entailment': 0, 'neutral': 1, 'contradiction': 2}
            ds = load_dataset("glue", "mnli", split="train")
            # mapping to align with
            label2id = {'CONTRADICTION': 0, 'NEUTRAL': 1, 'ENTAILMENT': 2}
            ds_aligned = ds.align_labels_with_mapping(label2id, "label")
    """
    features = self.features.copy()

    # Fail early with an actionable message instead of an opaque KeyError/AttributeError below.
    if label_column not in features:
        raise ValueError(f"Column ({label_column}) not in dataset columns ({list(features)}).")
    label_feature = features[label_column]
    if not isinstance(label_feature, ClassLabel):
        raise ValueError(
            f"Aligning labels with a mapping is only supported for {ClassLabel.__name__} columns, "
            f"but column ({label_column}) is {type(label_feature).__name__}."
        )
    int2str_function = label_feature.int2str

    # Sort input mapping by ID value to ensure the label names are aligned
    sorted_items = sorted(label2id.items(), key=lambda item: item[1])
    label_names = [label_name for label_name, _ in sorted_items]
    label_ids = [label_id for _, label_id in sorted_items]
    # The names list below is positional (index == ID), so gaps or offsets in the IDs would make
    # the resulting feature decode the written IDs to the wrong names.
    if label_ids != list(range(len(label_ids))):
        raise ValueError(
            f"The IDs in label2id must be the consecutive integers 0..{len(label_ids) - 1}, "
            f"but got {label_ids}."
        )
    features[label_column] = ClassLabel(num_classes=len(label_names), names=label_names)

    # Some label mappings use uppercase label names so we lowercase them during alignment
    lowercase_label2id = {label_name.lower(): label_id for label_name, label_id in sorted_items}

    def process_label_ids(batch):
        # Null labels have no name to look up, so they are passed through unchanged.
        batch[label_column] = [
            None if label_id is None else lowercase_label2id[int2str_function(label_id).lower()]
            for label_id in batch[label_column]
        ]
        return batch

    return self.map(process_label_ids, features=features, batched=True)


def concatenate_datasets(
dsets: List[Dataset],
Expand Down
19 changes: 19 additions & 0 deletions tests/test_arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2202,6 +2202,25 @@ def test_task_templates_empty_after_preparation(self, in_memory):
with dset.prepare_for_task(task="text-classification") as dset:
self.assertIsNone(dset.info.task_templates)

def test_align_labels_with_mapping(self, in_memory):
    """Check that aligning a ClassLabel column to an uppercase, re-ordered mapping rewrites both IDs and names."""
    label_column = "input_labels"
    # Source dataset encodes entailment=0, neutral=1, contradiction=2.
    source_features = Features(
        {
            "input_text": Value("string"),
            label_column: ClassLabel(num_classes=3, names=["entailment", "neutral", "contradiction"]),
        }
    )
    rows = {"input_text": ["a", "a", "b", "b", "c", "c"], label_column: [0, 0, 1, 1, 2, 2]}

    # Target mapping swaps the IDs of entailment and contradiction and uses uppercase names.
    label2id = {"CONTRADICTION": 0, "ENTAILMENT": 2, "NEUTRAL": 1}
    id2label = {label_id: label_name for label_name, label_id in label2id.items()}
    expected_labels = [2, 2, 1, 1, 0, 0]
    expected_label_names = [id2label[label_id] for label_id in expected_labels]

    with tempfile.TemporaryDirectory() as tmp_dir:
        with Dataset.from_dict(rows, features=source_features) as original:
            with self._to(in_memory, tmp_dir, original) as loaded:
                with loaded.align_labels_with_mapping(label2id, label_column) as aligned:
                    aligned_ids = aligned[label_column]
                    self.assertListEqual(expected_labels, aligned_ids)
                    decode = aligned.features[label_column].int2str
                    aligned_label_names = [decode(label_id) for label_id in aligned_ids]
                    self.assertListEqual(expected_label_names, aligned_label_names)


class MiscellaneousDatasetTest(TestCase):
def test_from_pandas(self):
Expand Down