docs/source/package_reference/main_classes.rst (6 changes: 3 additions & 3 deletions)

@@ -33,9 +33,8 @@ The base class :class:`datasets.Dataset` implements a Dataset backed by an Apache Arrow table
description, download_checksums, download_size, features, homepage,
license, size_in_bytes, supervised_keys, version,
from_csv, from_json, from_text,
-        prepare_for_task,
+        prepare_for_task, align_labels_with_mapping,
to_json,
-        align_labels_with_mapping

.. autofunction:: datasets.concatenate_datasets

@@ -57,7 +56,8 @@ It also has dataset transform methods like map or filter, to process all the splits
flatten_, cast_, remove_columns_, rename_column_,
flatten, cast, remove_columns, rename_column, class_encode_column,
save_to_disk, load_from_disk,
-        from_csv, from_json, from_text, prepare_for_task
+        from_csv, from_json, from_text,
+        prepare_for_task, align_labels_with_mapping


``Features``
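For reference, the `align_labels_with_mapping` entries added to these member lists reorder a dataset's integer labels to match an external `label2id` mapping, matching label names case-insensitively (the new test further down exercises this). A minimal sketch on a single `Dataset`; the data and mapping here are illustrative:

```python
from datasets import ClassLabel, Dataset, Features, Value

# A dataset whose ClassLabel ordering disagrees with a model's label2id mapping.
features = Features(
    {
        "text": Value("string"),
        "label": ClassLabel(names=["entailment", "neutral", "contradiction"]),
    }
)
ds = Dataset.from_dict({"text": ["a", "b", "c"], "label": [0, 1, 2]}, features=features)

# Target mapping, e.g. taken from a pretrained model config's label2id.
label2id = {"CONTRADICTION": 0, "NEUTRAL": 1, "ENTAILMENT": 2}
ds = ds.align_labels_with_mapping(label2id, label_column="label")
print(ds["label"])  # [2, 1, 0]: entailment -> 2, neutral -> 1, contradiction -> 0
```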
src/datasets/dataset_dict.py (34 changes: 22 additions & 12 deletions)

@@ -110,7 +110,7 @@ def flatten_(self, max_depth=16):
for dataset in self.values():
dataset.flatten_(max_depth=max_depth)

-    def flatten(self, max_depth=16):
+    def flatten(self, max_depth=16) -> "DatasetDict":
"""Flatten the Apache Arrow Table of each split (nested features are flatten).
Each column with a struct type is flattened into one column per struct field.
Other columns are left unchanged.
@@ -166,7 +166,7 @@ def cast_(self, features: Features):
new_dataset_dict = {k: dataset.cast(features=features) for k, dataset in self.items()}
self.update(new_dataset_dict)

-    def cast(self, features: Features):
+    def cast(self, features: Features) -> "DatasetDict":
"""
Cast the dataset to a new set of features.
The transformation is applied to all the datasets of the dataset dictionary.
@@ -197,7 +197,7 @@ def remove_columns_(self, column_names: Union[str, List[str]]):
new_dataset_dict = {k: dataset.remove_columns(column_names=column_names) for k, dataset in self.items()}
self.update(new_dataset_dict)

-    def remove_columns(self, column_names: Union[str, List[str]]):
+    def remove_columns(self, column_names: Union[str, List[str]]) -> "DatasetDict":
"""
Remove one or several column(s) from each split in the dataset
and the features associated to the column(s).
@@ -231,7 +231,7 @@ def rename_column_(self, original_column_name: str, new_column_name: str):
}
self.update(new_dataset_dict)

-    def rename_column(self, original_column_name: str, new_column_name: str):
+    def rename_column(self, original_column_name: str, new_column_name: str) -> "DatasetDict":
"""
Rename a column in the dataset and move the features associated to the original column under the new column name.
The transformation is applied to all the datasets of the dataset dictionary.
@@ -252,7 +252,7 @@ def rename_column(self, original_column_name: str, new_column_name: str):
}
)

-    def class_encode_column(self, column: str):
+    def class_encode_column(self, column: str) -> "DatasetDict":
"""Casts the given column as :obj:``datasets.features.ClassLabel`` and updates the tables.

Args:
@@ -361,7 +361,7 @@ def with_format(
columns: Optional[List] = None,
output_all_columns: bool = False,
**format_kwargs,
-    ):
+    ) -> "DatasetDict":
"""Set __getitem__ return format (type and columns). The data formatting is applied on-the-fly.
The format ``type`` (for example "numpy") is used to format batches when using __getitem__.
The format is set for every dataset in the dataset dictionary
@@ -388,7 +388,7 @@ def with_transform(
transform: Optional[Callable],
columns: Optional[List] = None,
output_all_columns: bool = False,
-    ):
+    ) -> "DatasetDict":
"""Set __getitem__ return format using this transform. The transform is applied on-the-fly on batches when __getitem__ is called.
The transform is set for every dataset in the dataset dictionary

@@ -615,7 +615,7 @@ def shuffle(
load_from_cache_file: bool = True,
indices_cache_file_names: Optional[Dict[str, Optional[str]]] = None,
writer_batch_size: Optional[int] = 1000,
-    ):
+    ) -> "DatasetDict":
"""Create a new Dataset where the rows are shuffled.

The transformation is applied to all the datasets of the dataset dictionary.
@@ -742,7 +742,7 @@ def from_csv(
cache_dir: str = None,
keep_in_memory: bool = False,
**kwargs,
-    ):
+    ) -> "DatasetDict":
"""Create DatasetDict from CSV file(s).

Args:
@@ -769,7 +769,7 @@ def from_json(
cache_dir: str = None,
keep_in_memory: bool = False,
**kwargs,
-    ):
+    ) -> "DatasetDict":
"""Create DatasetDict from JSON Lines file(s).

Args:
@@ -796,7 +796,7 @@ def from_text(
cache_dir: str = None,
keep_in_memory: bool = False,
**kwargs,
-    ):
+    ) -> "DatasetDict":
"""Create DatasetDict from text file(s).

Args:
@@ -817,6 +817,16 @@ def from_text(
).read()

@is_documented_by(Dataset.prepare_for_task)
-    def prepare_for_task(self, task: Union[str, TaskTemplate]):
+    def prepare_for_task(self, task: Union[str, TaskTemplate]) -> "DatasetDict":
self._check_values_type()
return DatasetDict({k: dataset.prepare_for_task(task=task) for k, dataset in self.items()})

+    @is_documented_by(Dataset.align_labels_with_mapping)
+    def align_labels_with_mapping(self, label2id: Dict, label_column: str) -> "DatasetDict":
+        self._check_values_type()
+        return DatasetDict(
+            {
+                k: dataset.align_labels_with_mapping(label2id=label2id, label_column=label_column)
+                for k, dataset in self.items()
+            }
+        )
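A minimal usage sketch of the new `DatasetDict` method added above: each split is aligned independently, so splits whose `ClassLabel` orderings disagree end up with consistent integer labels. The split data and label names here are illustrative:

```python
from datasets import ClassLabel, Dataset, DatasetDict, Features, Value

# Two splits that encode the same label names in different orders.
train = Dataset.from_dict(
    {"text": ["a", "b"], "label": [0, 1]},
    features=Features({"text": Value("string"), "label": ClassLabel(names=["neg", "pos"])}),
)
test = Dataset.from_dict(
    {"text": ["c", "d"], "label": [0, 1]},
    features=Features({"text": Value("string"), "label": ClassLabel(names=["pos", "neg"])}),
)
dsets = DatasetDict({"train": train, "test": test})

# One call aligns every split to the same target mapping.
dsets = dsets.align_labels_with_mapping({"NEG": 0, "POS": 1}, label_column="label")
print(dsets["train"]["label"], dsets["test"]["label"])  # [0, 1] [1, 0]
```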
tests/test_dataset_dict.py (41 changes: 40 additions & 1 deletion)

@@ -9,7 +9,7 @@
from datasets import load_from_disk
from datasets.arrow_dataset import Dataset
from datasets.dataset_dict import DatasetDict
-from datasets.features import Features, Sequence, Value
+from datasets.features import ClassLabel, Features, Sequence, Value
from datasets.splits import NamedSplit

from .conftest import s3_test_bucket_name
@@ -435,6 +435,45 @@ def test_load_from_disk(self):
self.assertListEqual(dsets["test"].column_names, ["filename"])
del dsets

+    def test_align_labels_with_mapping(self):
+        train_features = Features(
+            {
+                "input_text": Value("string"),
+                "input_labels": ClassLabel(num_classes=3, names=["entailment", "neutral", "contradiction"]),
+            }
+        )
+        test_features = Features(
+            {
+                "input_text": Value("string"),
+                "input_labels": ClassLabel(num_classes=3, names=["entailment", "contradiction", "neutral"]),
+            }
+        )
+        train_data = {"input_text": ["a", "a", "b", "b", "c", "c"], "input_labels": [0, 0, 1, 1, 2, 2]}
+        test_data = {"input_text": ["a", "a", "c", "c", "b", "b"], "input_labels": [0, 0, 1, 1, 2, 2]}
+        label2id = {"CONTRADICTION": 0, "ENTAILMENT": 2, "NEUTRAL": 1}
+        id2label = {v: k for k, v in label2id.items()}
+        train_expected_labels = [2, 2, 1, 1, 0, 0]
+        test_expected_labels = [2, 2, 0, 0, 1, 1]
+        train_expected_label_names = [id2label[idx] for idx in train_expected_labels]
+        test_expected_label_names = [id2label[idx] for idx in test_expected_labels]
+        dsets = DatasetDict(
+            {
+                "train": Dataset.from_dict(train_data, features=train_features),
+                "test": Dataset.from_dict(test_data, features=test_features),
+            }
+        )
+        dsets = dsets.align_labels_with_mapping(label2id, "input_labels")
+        self.assertListEqual(train_expected_labels, dsets["train"]["input_labels"])
+        self.assertListEqual(test_expected_labels, dsets["test"]["input_labels"])
+        train_aligned_label_names = [
+            dsets["train"].features["input_labels"].int2str(idx) for idx in dsets["train"]["input_labels"]
+        ]
+        test_aligned_label_names = [
+            dsets["test"].features["input_labels"].int2str(idx) for idx in dsets["test"]["input_labels"]
+        ]
+        self.assertListEqual(train_expected_label_names, train_aligned_label_names)
+        self.assertListEqual(test_expected_label_names, test_aligned_label_names)


def _check_csv_datasetdict(dataset_dict, expected_features, splits=("train",)):
assert isinstance(dataset_dict, DatasetDict)