diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index 0ea5f0ba42d..f49afb9a929 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -39,6 +39,8 @@
 from multiprocess import Pool, RLock
 from tqdm.auto import tqdm

+from datasets.tasks.text_classification import TextClassification
+
 from . import config
 from .arrow_reader import ArrowReader
 from .arrow_writer import ArrowWriter, OptimizedTypedSequence
@@ -1391,7 +1393,7 @@ def with_transform(
     def prepare_for_task(self, task: Union[str, TaskTemplate]) -> "Dataset":
         """Prepare a dataset for the given task by casting the dataset's :class:`Features` to standardized column names and types as detailed in :py:mod:`datasets.tasks`.

-        Casts :attr:`datasets.DatasetInfo.features` according to a task-specific schema.
+        Casts :attr:`datasets.DatasetInfo.features` according to a task-specific schema. Intended for single-use only, so all task templates are removed from :attr:`datasets.DatasetInfo.task_templates` after casting.

         Args:
             task (:obj:`Union[str, TaskTemplate]`): The task to prepare the dataset for during training and evaluation. If :obj:`str`, supported tasks include:
@@ -1419,10 +1421,18 @@ def prepare_for_task(self, task: Union[str, TaskTemplate]) -> "Dataset":
             raise ValueError(
                 f"Expected a `str` or `datasets.tasks.TaskTemplate` object but got task {task} with type {type(task)}."
             )
+        if isinstance(template, TextClassification) and self.info.features is not None:
+            dataset_labels = tuple(sorted(self.info.features[template.label_column].names))
+            if template.labels is None or template.labels != dataset_labels:
+                raise ValueError(
+                    f"Incompatible labels between the dataset and task template! Expected labels {dataset_labels} but got {template.labels}. Please ensure that `datasets.tasks.TextClassification.labels` matches the features of the dataset."
+                )
         column_mapping = template.column_mapping
         columns_to_drop = [column for column in self.column_names if column not in column_mapping]
         dataset = self.remove_columns(columns_to_drop)
         dataset = dataset.rename_columns(column_mapping)
+        # We found a template so now flush `DatasetInfo` to skip the template update in `DatasetInfo.__post_init__`
+        dataset.info.task_templates = None
         dataset = dataset.cast(features=template.features)
         return dataset
diff --git a/src/datasets/dataset_dict.py b/src/datasets/dataset_dict.py
index 3c97aa67e88..4dd57ca2da2 100644
--- a/src/datasets/dataset_dict.py
+++ b/src/datasets/dataset_dict.py
@@ -9,6 +9,8 @@
 import fsspec
 import numpy as np

+from datasets.utils.doc_utils import is_documented_by
+
 from .arrow_dataset import Dataset
 from .features import Features
 from .filesystems import extract_path_from_uri, is_remote_filesystem
@@ -793,18 +795,7 @@ def from_text(
             path_or_paths, features=features, cache_dir=cache_dir, keep_in_memory=keep_in_memory, **kwargs
         ).read()

+    @is_documented_by(Dataset.prepare_for_task)
     def prepare_for_task(self, task: Union[str, TaskTemplate]):
-        """Prepare a dataset for the given task by casting the dataset's :class:`Features` to standardized column names and types as detailed in :py:mod:`datasets.tasks`.
-
-        Casts :attr:`datasets.DatasetInfo.features` according to a task-specific schema.
-
-        Args:
-            task (:obj:`Union[str, TaskTemplate]`): The task to prepare the dataset for during training and evaluation. If :obj:`str`, supported tasks include:
-
-                - :obj:`"text-classification"`
-                - :obj:`"question-answering"`
-
-            If :obj:`TaskTemplate`, must be one of the task templates in :py:mod:`datasets.tasks`.
-        """
         self._check_values_type()
         return DatasetDict({k: dataset.prepare_for_task(task=task) for k, dataset in self.items()})
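
# ---------------------------------------------------------------------------
# Note (illustrative sketch, not part of the patch): the snippet below shows what
# the new label check and the template "flush" in `Dataset.prepare_for_task` do in
# practice. The column names, label names and data are made up for the example;
# the behaviour assumed here is the one introduced by the hunks above.
from datasets import ClassLabel, Dataset, Features, Value
from datasets.tasks import TextClassification

features = Features({"review": Value("string"), "sentiment": ClassLabel(names=["neg", "pos"])})
dset = Dataset.from_dict({"review": ["great movie"], "sentiment": [1]}, features=features)
# The template's labels must match the dataset's ClassLabel names, otherwise a ValueError is raised
template = TextClassification(text_column="review", label_column="sentiment", labels=["neg", "pos"])
prepared = dset.prepare_for_task(template)
print(prepared.column_names)         # ['text', 'labels'] after renaming and casting
print(prepared.info.task_templates)  # None: task templates are removed after casting
# ---------------------------------------------------------------------------
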
diff --git a/src/datasets/info.py b/src/datasets/info.py
index 6b65f294716..0d1931d8bb1 100644
--- a/src/datasets/info.py
+++ b/src/datasets/info.py
@@ -36,6 +36,8 @@
 from dataclasses import asdict, dataclass, field
 from typing import List, Optional, Union

+from datasets.tasks.text_classification import TextClassification
+
 from . import config
 from .features import Features, Value
 from .splits import SplitDict
@@ -154,6 +156,7 @@ def __post_init__(self):
             else:
                 self.supervised_keys = SupervisedKeysData(**self.supervised_keys)

+        # Parse and make a list of templates
         if self.task_templates is not None:
             if isinstance(self.task_templates, (list, tuple)):
                 templates = [
@@ -167,6 +170,17 @@ def __post_init__(self):
                 template = task_template_from_dict(self.task_templates)
                 self.task_templates = [template] if template is not None else []

+        # Insert labels and mappings for text classification
+        if self.task_templates is not None:
+            self.task_templates = list(self.task_templates)
+            if self.features is not None:
+                for idx, template in enumerate(self.task_templates):
+                    if isinstance(template, TextClassification):
+                        labels = self.features[template.label_column].names
+                        self.task_templates[idx] = TextClassification(
+                            text_column=template.text_column, label_column=template.label_column, labels=labels
+                        )
+
     def _license_path(self, dataset_info_dir):
         return os.path.join(dataset_info_dir, config.LICENSE_FILENAME)
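
# ---------------------------------------------------------------------------
# Note (illustrative sketch, not part of the patch): with the `DatasetInfo.__post_init__`
# change above, label names for a `TextClassification` template are filled in from the
# dataset's `ClassLabel` feature, so they no longer need to be repeated in the template.
# The feature names below are arbitrary.
from datasets import ClassLabel, DatasetInfo, Features, Value
from datasets.tasks import TextClassification

info = DatasetInfo(
    features=Features({"text": Value("string"), "labels": ClassLabel(names=["neg", "pos"])}),
    task_templates=[TextClassification(text_column="text", label_column="labels")],
)
print(info.task_templates[0].labels)  # ('neg', 'pos'), copied from the ClassLabel feature
# ---------------------------------------------------------------------------
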
diff --git a/src/datasets/tasks/__init__.py b/src/datasets/tasks/__init__.py
index 1591fd318fa..da86ad103d3 100644
--- a/src/datasets/tasks/__init__.py
+++ b/src/datasets/tasks/__init__.py
@@ -1,5 +1,6 @@
 from typing import Optional

+from ..utils.logging import get_logger
 from .base import TaskTemplate
 from .question_answering import QuestionAnswering
 from .text_classification import TextClassification
@@ -7,6 +8,8 @@

 __all__ = ["TaskTemplate", "QuestionAnswering", "TextClassification"]

+logger = get_logger(__name__)
+
 NAME2TEMPLATE = {QuestionAnswering.task: QuestionAnswering, TextClassification.task: TextClassification}


@@ -15,8 +18,7 @@ def task_template_from_dict(task_template_dict: dict) -> Optional[TaskTemplate]:
     """Create one of the supported task templates in :py:mod:`datasets.tasks` from a dictionary."""
     task_name = task_template_dict.get("task")
     if task_name is None:
+        logger.warning(f"Couldn't find template for task '{task_name}'. Available templates: {list(NAME2TEMPLATE)}")
         return None
     template = NAME2TEMPLATE.get(task_name)
-    if template is None:
-        return None
     return template.from_dict(task_template_dict)
diff --git a/src/datasets/tasks/base.py b/src/datasets/tasks/base.py
index 13d3c8ab279..f75d2e4d228 100644
--- a/src/datasets/tasks/base.py
+++ b/src/datasets/tasks/base.py
@@ -11,7 +11,8 @@

 @dataclass(frozen=True)
 class TaskTemplate(abc.ABC):
-    task: ClassVar[str]
+    # `task` is not a ClassVar since we want it to be part of the `asdict` output for JSON serialization
+    task: str
     input_schema: ClassVar[Features]
     label_schema: ClassVar[Features]

diff --git a/src/datasets/tasks/question_answering.py b/src/datasets/tasks/question_answering.py
index 145be969633..b62fdd64d03 100644
--- a/src/datasets/tasks/question_answering.py
+++ b/src/datasets/tasks/question_answering.py
@@ -7,7 +7,8 @@

 @dataclass(frozen=True)
 class QuestionAnswering(TaskTemplate):
-    task: ClassVar[str] = "question-answering"
+    # `task` is not a ClassVar since we want it to be part of the `asdict` output for JSON serialization
+    task: str = "question-answering"
     input_schema: ClassVar[Features] = Features({"question": Value("string"), "context": Value("string")})
     label_schema: ClassVar[Features] = Features(
         {
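
# ---------------------------------------------------------------------------
# Note (illustrative sketch, not part of the patch): turning `task` into a regular
# dataclass field (instead of a ClassVar) means it survives `dataclasses.asdict`,
# which is what the JSON serialization of task templates in `DatasetInfo` relies on.
# A quick check, assuming the field layout above:
from dataclasses import asdict

from datasets.tasks import QuestionAnswering

template = QuestionAnswering(question_column="question", context_column="context", answers_column="answers")
print(asdict(template))
# {'task': 'question-answering', 'question_column': 'question', 'context_column': 'context', 'answers_column': 'answers'}
# With `task` as a ClassVar, `asdict` would drop the "task" key entirely.
# ---------------------------------------------------------------------------
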
diff --git a/src/datasets/tasks/text_classification.py b/src/datasets/tasks/text_classification.py
index 140e7526ec5..e9002024768 100644
--- a/src/datasets/tasks/text_classification.py
+++ b/src/datasets/tasks/text_classification.py
@@ -1,40 +1,28 @@
 from dataclasses import dataclass
-from typing import ClassVar, Dict, List
+from typing import ClassVar, Dict, Optional, Tuple

 from ..features import ClassLabel, Features, Value
 from .base import TaskTemplate


-class FeaturesWithLazyClassLabel:
-    def __init__(self, features, label_column="labels"):
-        assert label_column in features, f"Key '{label_column}' missing in features {features}"
-        self._features = features
-        self._label_column = label_column
-
-    def __get__(self, obj, objtype=None):
-        if obj is None:
-            return self._features
-
-        assert hasattr(obj, self._label_column), f"Object has no attribute '{self._label_column}'"
-        features = self._features.copy()
-        features["labels"] = ClassLabel(names=getattr(obj, self._label_column))
-        return features
-
-
 @dataclass(frozen=True)
 class TextClassification(TaskTemplate):
-    task: ClassVar[str] = "text-classification"
+    # `task` is not a ClassVar since we want it to be part of the `asdict` output for JSON serialization
+    task: str = "text-classification"
     input_schema: ClassVar[Features] = Features({"text": Value("string")})
     # TODO(lewtun): Find a more elegant approach without descriptors.
-    label_schema: ClassVar[Features] = FeaturesWithLazyClassLabel(Features({"labels": ClassLabel}))
-    labels: List[str]
+    label_schema: ClassVar[Features] = Features({"labels": ClassLabel})
     text_column: str = "text"
     label_column: str = "labels"
+    labels: Optional[Tuple[str]] = None

     def __post_init__(self):
-        assert len(self.labels) == len(set(self.labels)), "Labels must be unique"
-        # Cast labels to tuple to allow hashing
-        self.__dict__["labels"] = tuple(sorted(self.labels))
+        if self.labels:
+            assert len(self.labels) == len(set(self.labels)), "Labels must be unique"
+            # Cast labels to tuple to allow hashing
+            self.__dict__["labels"] = tuple(sorted(self.labels))
+            self.__dict__["label_schema"] = self.label_schema.copy()
+            self.label_schema["labels"] = ClassLabel(names=self.labels)

     @property
     def column_mapping(self) -> Dict[str, str]:
@@ -42,11 +30,3 @@ def column_mapping(self) -> Dict[str, str]:
             self.text_column: "text",
             self.label_column: "labels",
         }
-
-    @property
-    def label2id(self):
-        return {label: idx for idx, label in enumerate(self.labels)}
-
-    @property
-    def id2label(self):
-        return {idx: label for idx, label in enumerate(self.labels)}
diff --git a/src/datasets/utils/doc_utils.py b/src/datasets/utils/doc_utils.py
new file mode 100644
index 00000000000..6ef8bcb4e70
--- /dev/null
+++ b/src/datasets/utils/doc_utils.py
@@ -0,0 +1,15 @@
+from typing import Callable
+
+
+def is_documented_by(function_with_docstring: Callable):
+    """Decorator to share docstrings across common functions.
+
+    Args:
+        function_with_docstring (`Callable`): The function whose docstring is copied to the decorated function.
+    """
+
+    def wrapper(target_function):
+        target_function.__doc__ = function_with_docstring.__doc__
+        return target_function
+
+    return wrapper
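
# ---------------------------------------------------------------------------
# Note (illustrative sketch, not part of the patch): `labels` is now optional on
# `TextClassification`. When given, `__post_init__` checks uniqueness, sorts the names
# into a tuple and builds a per-instance `label_schema`; when omitted, the names are
# expected to be filled in later from the dataset features.
from datasets.tasks import TextClassification

with_labels = TextClassification(text_column="text", label_column="labels", labels=["pos", "neg"])
print(with_labels.labels)                              # ('neg', 'pos') -- sorted and hashable
print(list(with_labels.label_schema["labels"].names))  # ['neg', 'pos'] -- per-instance ClassLabel

without_labels = TextClassification(text_column="text", label_column="labels")
print(without_labels.labels)  # None -- to be populated by `DatasetInfo.__post_init__`
# ---------------------------------------------------------------------------
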
diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py
index 0bd295dc9d3..24a2bf783d1 100644
--- a/tests/test_arrow_dataset.py
+++ b/tests/test_arrow_dataset.py
@@ -702,8 +702,13 @@ def test_concatenate_with_no_task_templates(self, in_memory):
                         self.assertEqual(dset_concat.info.task_templates, None)

     def test_concatenate_with_equal_task_templates(self, in_memory):
-        task_template = TextClassification(text_column="text", label_column="labels", labels=["pos", "neg"])
-        info = DatasetInfo(task_templates=task_template)
+        labels = ["neg", "pos"]
+        task_template = TextClassification(text_column="text", label_column="labels", labels=labels)
+        info = DatasetInfo(
+            features=Features({"text": Value("string"), "labels": ClassLabel(names=labels)}),
+            # Label names are added in `DatasetInfo.__post_init__` so not included here
+            task_templates=TextClassification(text_column="text", label_column="labels"),
+        )
         data = {"text": ["i love transformers!"], "labels": [1]}
         with tempfile.TemporaryDirectory() as tmp_dir:
             with Dataset.from_dict(data, info=info) as dset1, Dataset.from_dict(
@@ -716,10 +721,42 @@ def test_concatenate_with_equal_task_templates(self, in_memory):
                         self.assertListEqual(dset_concat.info.task_templates, [task_template])

     def test_concatenate_with_mixed_task_templates_in_common(self, in_memory):
-        tc_template = TextClassification(text_column="text", label_column="labels", labels=["pos", "neg"])
+        tc_template = TextClassification(text_column="text", label_column="labels")
         qa_template = QuestionAnswering(question_column="question", context_column="context", answers_column="answers")
-        info1 = DatasetInfo(task_templates=[qa_template])
-        info2 = DatasetInfo(task_templates=[qa_template, tc_template])
+        info1 = DatasetInfo(
+            task_templates=[qa_template],
+            features=Features(
+                {
+                    "text": Value("string"),
+                    "labels": ClassLabel(names=["pos", "neg"]),
+                    "context": Value("string"),
+                    "question": Value("string"),
+                    "answers": Sequence(
+                        {
+                            "text": Value("string"),
+                            "answer_start": Value("int32"),
+                        }
+                    ),
+                }
+            ),
+        )
+        info2 = DatasetInfo(
+            task_templates=[qa_template, tc_template],
+            features=Features(
+                {
+                    "text": Value("string"),
+                    "labels": ClassLabel(names=["pos", "neg"]),
+                    "context": Value("string"),
+                    "question": Value("string"),
+                    "answers": Sequence(
+                        {
+                            "text": Value("string"),
+                            "answer_start": Value("int32"),
+                        }
+                    ),
+                }
+            ),
+        )
         data = {
             "text": ["i love transformers!"],
             "labels": [1],
@@ -738,15 +775,67 @@ def test_concatenate_with_mixed_task_templates_in_common(self, in_memory):
                         self.assertListEqual(dset_concat.info.task_templates, [qa_template])

     def test_concatenate_with_no_mixed_task_templates_in_common(self, in_memory):
-        tc_template1 = TextClassification(text_column="text", label_column="labels", labels=["pos", "neg"])
-        tc_template2 = TextClassification(text_column="text", label_column="labels", labels=["pos", "neg", "neutral"])
+        tc_template1 = TextClassification(text_column="text", label_column="labels")
+        tc_template2 = TextClassification(text_column="text", label_column="sentiment")
         qa_template = QuestionAnswering(question_column="question", context_column="context", answers_column="answers")
-        info1 = DatasetInfo(task_templates=[tc_template1])
-        info2 = DatasetInfo(task_templates=[tc_template2])
-        info3 = DatasetInfo(task_templates=[qa_template])
+        info1 = DatasetInfo(
+            features=Features(
+                {
+                    "text": Value("string"),
+                    "labels": ClassLabel(names=["pos", "neg"]),
+                    "sentiment": ClassLabel(names=["pos", "neg", "neutral"]),
+                    "context": Value("string"),
+                    "question": Value("string"),
+                    "answers": Sequence(
+                        {
+                            "text": Value("string"),
+                            "answer_start": Value("int32"),
+                        }
+                    ),
+                }
+            ),
+            task_templates=[tc_template1],
+        )
+        info2 = DatasetInfo(
+            features=Features(
+                {
+                    "text": Value("string"),
+                    "labels": ClassLabel(names=["pos", "neg"]),
+                    "sentiment": ClassLabel(names=["pos", "neg", "neutral"]),
+                    "context": Value("string"),
+                    "question": Value("string"),
+                    "answers": Sequence(
+                        {
+                            "text": Value("string"),
+                            "answer_start": Value("int32"),
+                        }
+                    ),
+                }
+            ),
+            task_templates=[tc_template2],
+        )
+        info3 = DatasetInfo(
+            features=Features(
+                {
+                    "text": Value("string"),
+                    "labels": ClassLabel(names=["pos", "neg"]),
+                    "sentiment": ClassLabel(names=["pos", "neg", "neutral"]),
+                    "context": Value("string"),
+                    "question": Value("string"),
+                    "answers": Sequence(
+                        {
+                            "text": Value("string"),
+                            "answer_start": Value("int32"),
+                        }
+                    ),
+                }
+            ),
+            task_templates=[qa_template],
+        )
         data = {
             "text": ["i love transformers!"],
             "labels": [1],
+            "sentiment": [0],
             "context": ["huggingface is going to the moon!"],
             "question": ["where is huggingface going?"],
             "answers": [{"text": ["to the moon!"], "answer_start": [2]}],
         }
@@ -1929,34 +2018,50 @@ def test_task_text_classification(self, in_memory):
         features_before_cast = Features(
             {
                 "input_text": Value("string"),
-                "input_labels": Value("int32"),
+                "input_labels": ClassLabel(names=labels),
             }
         )
-        # Labels are cast to tuple during TextClassification init, so we do the same here
+        # Labels are cast to tuple during `TextClassification.__post_init__`, so we do the same here
         features_after_cast = Features(
             {
                 "text": Value("string"),
                 "labels": ClassLabel(names=tuple(labels)),
             }
         )
-        task = TextClassification(text_column="input_text", label_column="input_labels", labels=labels)
-        info = DatasetInfo(
+        # Label names are added in `DatasetInfo.__post_init__` so not needed here
+        task_without_labels = TextClassification(text_column="input_text", label_column="input_labels")
+        info1 = DatasetInfo(
             features=features_before_cast,
-            task_templates=task,
+            task_templates=task_without_labels,
+        )
+        # Label names are required when passing a TextClassification template directly to `Dataset.prepare_for_task`
+        # However, they can also be used to define `DatasetInfo` so we include a test for this too
+        task_with_labels = TextClassification(text_column="input_text", label_column="input_labels", labels=labels)
+        info2 = DatasetInfo(
+            features=features_before_cast,
+            task_templates=task_with_labels,
         )
         data = {"input_text": ["i love transformers!"], "input_labels": [1]}
-        # Test we can load from task name
-        with tempfile.TemporaryDirectory() as tmp_dir, Dataset.from_dict(data, info=info) as dset:
+        # Test we can load from task name when label names not included in template (default behaviour)
+        with tempfile.TemporaryDirectory() as tmp_dir, Dataset.from_dict(data, info=info1) as dset:
             with self._to(in_memory, tmp_dir, dset) as dset:
                 self.assertSetEqual(set(["input_text", "input_labels"]), set(dset.column_names))
                 self.assertDictEqual(features_before_cast, dset.features)
                 with dset.prepare_for_task(task="text-classification") as dset:
                     self.assertSetEqual(set(["labels", "text"]), set(dset.column_names))
                     self.assertDictEqual(features_after_cast, dset.features)
-        # Test we can load from TaskTemplate
-        info.task_templates = None
-        with tempfile.TemporaryDirectory() as tmp_dir, Dataset.from_dict(data, info=info) as dset:
-            with dset.prepare_for_task(task=task) as dset:
+        # Test we can load from task name when label names included in template
+        with tempfile.TemporaryDirectory() as tmp_dir, Dataset.from_dict(data, info=info2) as dset:
+            with self._to(in_memory, tmp_dir, dset) as dset:
+                self.assertSetEqual(set(["input_text", "input_labels"]), set(dset.column_names))
+                self.assertDictEqual(features_before_cast, dset.features)
+                with dset.prepare_for_task(task="text-classification") as dset:
+                    self.assertSetEqual(set(["labels", "text"]), set(dset.column_names))
+                    self.assertDictEqual(features_after_cast, dset.features)
+        # Test we can load from TextClassification template
+        info1.task_templates = None
+        with tempfile.TemporaryDirectory() as tmp_dir, Dataset.from_dict(data, info=info1) as dset:
+            with dset.prepare_for_task(task=task_with_labels) as dset:
                 self.assertSetEqual(set(["labels", "text"]), set(dset.column_names))
                 self.assertDictEqual(features_after_cast, dset.features)
@@ -2008,7 +2113,7 @@ def test_task_question_answering(self, in_memory):
                         set(dset.flatten().column_names),
                     )
                     self.assertDictEqual(features_after_cast, dset.features)
-        # Test we can load from TaskTemplate
+        # Test we can load from QuestionAnswering template
         info.task_templates = None
         with tempfile.TemporaryDirectory() as tmp_dir, Dataset.from_dict(data, info=info) as dset:
             with dset.prepare_for_task(task=task) as dset:
@@ -2030,10 +2135,10 @@ def test_task_with_incompatible_templates(self, in_memory):
         features = Features(
             {
                 "input_text": Value("string"),
-                "input_labels": Value("int32"),
+                "input_labels": ClassLabel(names=labels),
             }
         )
-        task = TextClassification(text_column="input_text", label_column="input_labels", labels=labels)
+        task = TextClassification(text_column="input_text", label_column="input_labels")
         info = DatasetInfo(
             features=features,
             task_templates=task,
@@ -2041,14 +2146,40 @@ def test_task_with_incompatible_templates(self, in_memory):
         data = {"input_text": ["i love transformers!"], "input_labels": [1]}
         with tempfile.TemporaryDirectory() as tmp_dir, Dataset.from_dict(data, info=info) as dset:
             with self._to(in_memory, tmp_dir, dset) as dset:
-                with self.assertRaises(ValueError):
-                    # Invalid task name
-                    dset.prepare_for_task("this-task-does-not-exist")
-                    # Duplicate task templates
-                    dset.info.task_templates = [task, task]
-                    dset.prepare_for_task("text-classification")
-                    # Invalid task type
-                    dset.prepare_for_task(1)
+                # Invalid task name
+                self.assertRaises(ValueError, dset.prepare_for_task, "this-task-does-not-exist")
+                # Invalid task templates with incompatible labels
+                task_with_wrong_labels = TextClassification(
+                    text_column="input_text", label_column="input_labels", labels=["neut"]
+                )
+                self.assertRaises(ValueError, dset.prepare_for_task, task_with_wrong_labels)
+                task_with_no_labels = TextClassification(
+                    text_column="input_text", label_column="input_labels", labels=None
+                )
+                self.assertRaises(ValueError, dset.prepare_for_task, task_with_no_labels)
+                # Duplicate task templates
+                dset.info.task_templates = [task, task]
+                self.assertRaises(ValueError, dset.prepare_for_task, "text-classification")
+                # Invalid task type
+                self.assertRaises(ValueError, dset.prepare_for_task, 1)
+
+    def test_task_templates_empty_after_preparation(self, in_memory):
+        features = Features(
+            {
+                "input_text": Value("string"),
+                "input_labels": ClassLabel(names=["pos", "neg"]),
+            }
+        )
+        task = TextClassification(text_column="input_text", label_column="input_labels")
+        info = DatasetInfo(
+            features=features,
+            task_templates=task,
+        )
+        data = {"input_text": ["i love transformers!"], "input_labels": [1]}
+        with tempfile.TemporaryDirectory() as tmp_dir, Dataset.from_dict(data, info=info) as dset:
+            with self._to(in_memory, tmp_dir, dset) as dset:
+                with dset.prepare_for_task(task="text-classification") as dset:
+                    self.assertIsNone(dset.info.task_templates)


 class MiscellaneousDatasetTest(TestCase):
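
# ---------------------------------------------------------------------------
# Note (illustrative sketch, not part of the patch): the incompatible-template tests
# above exercise the new ValueError raised by `Dataset.prepare_for_task` when a
# template's labels don't match the dataset's ClassLabel names. Roughly:
from datasets import ClassLabel, Dataset, Features, Value
from datasets.tasks import TextClassification

features = Features({"input_text": Value("string"), "input_labels": ClassLabel(names=["neg", "pos"])})
dset = Dataset.from_dict({"input_text": ["i love transformers!"], "input_labels": [1]}, features=features)
bad_template = TextClassification(text_column="input_text", label_column="input_labels", labels=["neut"])
try:
    dset.prepare_for_task(bad_template)
except ValueError as err:
    print(err)  # "Incompatible labels between the dataset and task template! ..."
# ---------------------------------------------------------------------------
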
diff --git a/tests/test_tasks.py b/tests/test_tasks.py
index 7a2fe205f3a..65e3b3ebb7f 100644
--- a/tests/test_tasks.py
+++ b/tests/test_tasks.py
@@ -14,7 +14,7 @@ def test_column_mapping(self):

     def test_from_dict(self):
         input_schema = Features({"text": Value("string")})
-        # Labels are cast to tuple during TextClassification init, so we do the same here
+        # Labels are cast to tuple during `TextClassification.__post_init__`, so we do the same here
         label_schema = Features({"labels": ClassLabel(names=tuple(self.labels))})
         template_dict = {"text_column": "input_text", "label_column": "input_labels", "labels": self.labels}
         task = TextClassification.from_dict(template_dict)
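
# ---------------------------------------------------------------------------
# Note (illustrative sketch, not part of the patch): the `test_from_dict` test above
# round-trips a template through a plain dict; in isolation that looks roughly like:
from datasets.tasks import TextClassification

template_dict = {"text_column": "input_text", "label_column": "input_labels", "labels": ["pos", "neg"]}
task = TextClassification.from_dict(template_dict)
print(task.labels)          # ('neg', 'pos') -- cast to a sorted tuple in __post_init__
print(task.column_mapping)  # {'input_text': 'text', 'input_labels': 'labels'}
# ---------------------------------------------------------------------------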