Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/package_reference/main_classes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ The base class :class:`datasets.Dataset` implements a Dataset backed by an Apach
from_csv, from_json, from_text,
prepare_for_task,
to_json,
align_labels_with_mapping

.. autofunction:: datasets.concatenate_datasets

Expand Down
36 changes: 36 additions & 0 deletions src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3097,6 +3097,42 @@ def add_item(self, item: dict, new_fingerprint: str):
fingerprint=new_fingerprint,
)

def align_labels_with_mapping(self, label2id: Dict, label_column: str) -> "Dataset":
    """Align the dataset's label ID and label name mapping to match an input :obj:`label2id` mapping.
    This is useful when you want to ensure that a model's predicted labels are aligned with the dataset.
    The alignment is done using the lowercase label names.

    Args:
        label2id (:obj:`dict`):
            The label name to ID mapping to align the dataset with.
        label_column (:obj:`str`):
            The column name of labels to align on. The column must be a :class:`ClassLabel` feature.

    Returns:
        :class:`Dataset`: The mapped dataset, where ``label_column`` holds the IDs taken from
        :obj:`label2id` and its feature is a :class:`ClassLabel` whose names are the keys of
        :obj:`label2id` ordered by ID. Missing labels (``None``) are left as ``None``.

    Raises:
        ValueError: If ``label_column`` is not a :class:`ClassLabel` feature.
        KeyError: If a (lowercased) label name of the dataset has no entry in :obj:`label2id`;
            the lookup happens inside the function passed to :meth:`map`.

    Example:
        .. code-block:: python

            # dataset with mapping {'entailment': 0, 'neutral': 1, 'contradiction': 2}
            ds = load_dataset("glue", "mnli", split="train")
            # mapping to align with
            label2id = {'CONTRADICTION': 0, 'NEUTRAL': 1, 'ENTAILMENT': 2}
            ds_aligned = ds.align_labels_with_mapping(label2id, "label")
    """
    features = self.features.copy()
    label_feature = features[label_column]
    # Fail early with an explicit message instead of an AttributeError on `.int2str`
    if not isinstance(label_feature, ClassLabel):
        raise ValueError(
            f"Aligning labels with a mapping is only supported for {ClassLabel.__name__} columns, "
            f"and column {label_column} is {type(label_feature).__name__}."
        )
    int2str_function = label_feature.int2str
    # Sort the input label names by ID value to ensure the label names are aligned
    label_names = sorted(label2id, key=label2id.get)
    features[label_column] = ClassLabel(num_classes=len(label_names), names=label_names)
    # Some label mappings use uppercase label names so we lowercase them during alignment
    lowercase_label2id = {name.lower(): label2id[name] for name in label_names}

    def process_label_ids(batch):
        # Missing labels cannot be converted to a name, so they are passed through unchanged
        batch[label_column] = [
            None if label_id is None else lowercase_label2id[int2str_function(label_id).lower()]
            for label_id in batch[label_column]
        ]
        return batch

    return self.map(process_label_ids, features=features, batched=True)


def concatenate_datasets(
dsets: List[Dataset],
Expand Down
19 changes: 19 additions & 0 deletions tests/test_arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2202,6 +2202,25 @@ def test_task_templates_empty_after_preparation(self, in_memory):
with dset.prepare_for_task(task="text-classification") as dset:
self.assertIsNone(dset.info.task_templates)

def test_align_labels_with_mapping(self, in_memory):
    """Check that aligning with an uppercase, re-ordered mapping remaps both label IDs and label names."""
    label_column = "input_labels"
    source_names = ["entailment", "neutral", "contradiction"]
    features = Features(
        {
            "input_text": Value("string"),
            label_column: ClassLabel(num_classes=3, names=source_names),
        }
    )
    data = {"input_text": ["a", "a", "b", "b", "c", "c"], label_column: [0, 0, 1, 1, 2, 2]}
    # Target mapping: same label names in uppercase, with entailment/contradiction IDs swapped
    label2id = {"CONTRADICTION": 0, "ENTAILMENT": 2, "NEUTRAL": 1}
    id2label = {label_id: name for name, label_id in label2id.items()}
    expected_labels = [2, 2, 1, 1, 0, 0]
    expected_label_names = [id2label[label_id] for label_id in expected_labels]
    with tempfile.TemporaryDirectory() as tmp_dir, Dataset.from_dict(data, features=features) as base_dset:
        with self._to(in_memory, tmp_dir, base_dset) as dset:
            with dset.align_labels_with_mapping(label2id, label_column) as aligned_dset:
                self.assertListEqual(expected_labels, aligned_dset[label_column])
                aligned_feature = aligned_dset.features[label_column]
                aligned_label_names = [aligned_feature.int2str(label_id) for label_id in aligned_dset[label_column]]
                self.assertListEqual(expected_label_names, aligned_label_names)


class MiscellaneousDatasetTest(TestCase):
def test_from_pandas(self):
Expand Down