From 538f3be410febb6382d6cf005f6eacd4ac364e97 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Fri, 21 May 2021 17:06:55 +0200 Subject: [PATCH 01/25] Update labels in DatasetInfo __post_init__ --- src/datasets/info.py | 13 ++++++++++++- src/datasets/tasks/text_classification.py | 17 ++++++++++------- 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/src/datasets/info.py b/src/datasets/info.py index 8e323bd4c33..3fd3527ce54 100644 --- a/src/datasets/info.py +++ b/src/datasets/info.py @@ -36,8 +36,10 @@ from dataclasses import asdict, dataclass, field from typing import List, Optional, Union +from datasets.tasks.text_classification import TextClassification + from . import config -from .features import Features, Value +from .features import Features, Value, ClassLabel from .splits import SplitDict from .tasks import TaskTemplate, task_template_from_dict from .utils import Version @@ -165,6 +167,15 @@ def __post_init__(self): else: template = task_template_from_dict(self.task_templates) self.task_templates = [template] if template is not None else [] + # insert labels and mappings for text classification + for idx, template in enumerate(self.task_templates): + if isinstance(template, TextClassification) and self.features is not None: + # This introduces state and raises a KeyError when we call Dataset.prepare_for_task :( + # The reason is that Dataset.prepare_for_task calls Dataset.cast which converts the + # DatasetInfo.features to the new schema and thus template.label_column is no longer a valid key + object.__setattr__(template, "labels", tuple(self.features[template.label_column].names)) + template.label_schema["labels"] = ClassLabel(names=template.labels) + self.task_templates[idx] = template def _license_path(self, dataset_info_dir): return os.path.join(dataset_info_dir, config.LICENSE_FILENAME) diff --git a/src/datasets/tasks/text_classification.py b/src/datasets/tasks/text_classification.py index 04cc26d78ef..4f727afb299 100644 --- a/src/datasets/tasks/text_classification.py +++ b/src/datasets/tasks/text_classification.py @@ -12,19 +12,14 @@ class TextClassification(TaskTemplate): # TODO(lewtun): Since we update this in __post_init__ do we need to set a default? We'll need it for __init__ so # investigate if there's a more elegant approach. 
label_schema = Features({"labels": ClassLabel}) - labels: List[str] + labels: List[str] = None text_column: str = "text" label_column: str = "labels" def __post_init__(self): - assert sorted(set(self.labels)) == sorted(self.labels), "Labels must be unique" - # Cast labels to tuple to allow hashing - object.__setattr__(self, "labels", tuple(sorted(self.labels))) + object.__setattr__(self, "labels", self.labels) object.__setattr__(self, "text_column", self.text_column) object.__setattr__(self, "label_column", self.label_column) - self.label_schema["labels"] = ClassLabel(names=self.labels) - object.__setattr__(self, "label2id", {label: idx for idx, label in enumerate(self.labels)}) - object.__setattr__(self, "id2label", {idx: label for label, idx in self.label2id.items()}) @property def column_mapping(self) -> Dict[str, str]: @@ -40,3 +35,11 @@ def from_dict(cls, template_dict: dict) -> "TextClassification": label_column=template_dict["label_column"], labels=template_dict["labels"], ) + + @property + def label2id(self): + return {label: idx for idx, label in enumerate(self.labels)} + + @property + def id2label(self): + return {idx: label for idx, label in enumerate(self.labels)} From c02a2e4a36f369a0230fa91fe8e63a71692260b6 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Fri, 21 May 2021 17:30:45 +0200 Subject: [PATCH 02/25] Add emotion example --- datasets/emotion/emotion.py | 7 +++++++ src/datasets/info.py | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/datasets/emotion/emotion.py b/datasets/emotion/emotion.py index 56f05160b7d..7bdfda80cab 100644 --- a/datasets/emotion/emotion.py +++ b/datasets/emotion/emotion.py @@ -1,6 +1,7 @@ import csv import datasets +from datasets.tasks.text_classification import TextClassification _CITATION = """\ @@ -44,6 +45,12 @@ def _info(self): supervised_keys=("text", "label"), homepage=_URL, citation=_CITATION, + task_templates=[ + TextClassification( + text_column="text", + label_column="label", + ) + ], ) def _split_generators(self, dl_manager): diff --git a/src/datasets/info.py b/src/datasets/info.py index 3fd3527ce54..17548373777 100644 --- a/src/datasets/info.py +++ b/src/datasets/info.py @@ -39,7 +39,7 @@ from datasets.tasks.text_classification import TextClassification from . 
import config -from .features import Features, Value, ClassLabel +from .features import ClassLabel, Features, Value from .splits import SplitDict from .tasks import TaskTemplate, task_template_from_dict from .utils import Version From 2eab30c1bec05eb5083dd5ca491644b13b199303 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Sun, 23 May 2021 10:16:50 +0200 Subject: [PATCH 03/25] Flush task templates before casting --- src/datasets/arrow_dataset.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 59c011b79d0..b57302e4621 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -1420,6 +1420,8 @@ def prepare_for_task(self, task: Union[str, TaskTemplate]) -> "Dataset": columns_to_drop = [column for column in self.column_names if column not in column_mapping] dataset = self.remove_columns(columns_to_drop) dataset = dataset.rename_columns(column_mapping) + # We found a template so now flush `DatasetInfo` to skip the template update in `DatasetInfo.__post_init__` + dataset.info.task_templates = None dataset = dataset.cast(features=template.features) return dataset From feaca484ffc50b7b5c6001d30e2eb46ab06b49f9 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Sun, 23 May 2021 11:01:32 +0200 Subject: [PATCH 04/25] Add labels to TextClassification __post_init__ --- src/datasets/tasks/text_classification.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/datasets/tasks/text_classification.py b/src/datasets/tasks/text_classification.py index 4f727afb299..8afe86aafd8 100644 --- a/src/datasets/tasks/text_classification.py +++ b/src/datasets/tasks/text_classification.py @@ -17,9 +17,11 @@ class TextClassification(TaskTemplate): label_column: str = "labels" def __post_init__(self): - object.__setattr__(self, "labels", self.labels) object.__setattr__(self, "text_column", self.text_column) object.__setattr__(self, "label_column", self.label_column) + if self.labels: + object.__setattr__(self, "labels", tuple(sorted(self.labels))) + self.label_schema["labels"] = ClassLabel(names=self.labels) @property def column_mapping(self) -> Dict[str, str]: From 188d02cec3fb41dcbdbde73b0e6b64a28fc97a2d Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Sun, 23 May 2021 11:02:10 +0200 Subject: [PATCH 05/25] Add comment about casting to tuple --- src/datasets/info.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/datasets/info.py b/src/datasets/info.py index 17548373777..7ebc039692d 100644 --- a/src/datasets/info.py +++ b/src/datasets/info.py @@ -170,10 +170,8 @@ def __post_init__(self): # insert labels and mappings for text classification for idx, template in enumerate(self.task_templates): if isinstance(template, TextClassification) and self.features is not None: - # This introduces state and raises a KeyError when we call Dataset.prepare_for_task :( - # The reason is that Dataset.prepare_for_task calls Dataset.cast which converts the - # DatasetInfo.features to the new schema and thus template.label_column is no longer a valid key - object.__setattr__(template, "labels", tuple(self.features[template.label_column].names)) + # Cast labels to tuple to enable hashing of task template + object.__setattr__(template, "labels", tuple(sorted(self.features[template.label_column].names))) template.label_schema["labels"] = ClassLabel(names=template.labels) self.task_templates[idx] = template From 1e3e830415549f3b0080fd4cef45989f55936a3a Mon Sep 17 00:00:00 2001 From: Lewis 
Tunstall Date: Sun, 23 May 2021 11:04:23 +0200 Subject: [PATCH 06/25] Fix capitalisation --- src/datasets/info.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/datasets/info.py b/src/datasets/info.py index 7ebc039692d..ac07ca1c451 100644 --- a/src/datasets/info.py +++ b/src/datasets/info.py @@ -167,7 +167,7 @@ def __post_init__(self): else: template = task_template_from_dict(self.task_templates) self.task_templates = [template] if template is not None else [] - # insert labels and mappings for text classification + # Insert labels and mappings for text classification for idx, template in enumerate(self.task_templates): if isinstance(template, TextClassification) and self.features is not None: # Cast labels to tuple to enable hashing of task template From 635e54dcba26c7e9baa9653db8f9de147c0dcf14 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Sun, 23 May 2021 11:05:13 +0200 Subject: [PATCH 07/25] Refactor tests to account for label update in `DatasetInfo`, add test --- tests/test_arrow_dataset.py | 133 ++++++++++++++++++++++++++++++++---- tests/test_tasks.py | 2 +- 2 files changed, 120 insertions(+), 15 deletions(-) diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py index 0bd295dc9d3..68ea8da0214 100644 --- a/tests/test_arrow_dataset.py +++ b/tests/test_arrow_dataset.py @@ -702,8 +702,11 @@ def test_concatenate_with_no_task_templates(self, in_memory): self.assertEqual(dset_concat.info.task_templates, None) def test_concatenate_with_equal_task_templates(self, in_memory): - task_template = TextClassification(text_column="text", label_column="labels", labels=["pos", "neg"]) - info = DatasetInfo(task_templates=task_template) + task_template = TextClassification(text_column="text", label_column="labels") + info = DatasetInfo( + features=Features({"text": Value("string"), "labels": ClassLabel(names=["pos", "neg"])}), + task_templates=task_template, + ) data = {"text": ["i love transformers!"], "labels": [1]} with tempfile.TemporaryDirectory() as tmp_dir: with Dataset.from_dict(data, info=info) as dset1, Dataset.from_dict( @@ -716,10 +719,42 @@ def test_concatenate_with_equal_task_templates(self, in_memory): self.assertListEqual(dset_concat.info.task_templates, [task_template]) def test_concatenate_with_mixed_task_templates_in_common(self, in_memory): - tc_template = TextClassification(text_column="text", label_column="labels", labels=["pos", "neg"]) + tc_template = TextClassification(text_column="text", label_column="labels") qa_template = QuestionAnswering(question_column="question", context_column="context", answers_column="answers") - info1 = DatasetInfo(task_templates=[qa_template]) - info2 = DatasetInfo(task_templates=[qa_template, tc_template]) + info1 = DatasetInfo( + task_templates=[qa_template], + features=Features( + { + "text": Value("string"), + "labels": ClassLabel(names=["pos", "neg"]), + "context": Value("string"), + "question": Value("string"), + "answers": Sequence( + { + "text": Value("string"), + "answer_start": Value("int32"), + } + ), + } + ), + ) + info2 = DatasetInfo( + task_templates=[qa_template, tc_template], + features=Features( + { + "text": Value("string"), + "labels": ClassLabel(names=["pos", "neg"]), + "context": Value("string"), + "question": Value("string"), + "answers": Sequence( + { + "text": Value("string"), + "answer_start": Value("int32"), + } + ), + } + ), + ) data = { "text": ["i love transformers!"], "labels": [1], @@ -738,15 +773,67 @@ def test_concatenate_with_mixed_task_templates_in_common(self, 
in_memory): self.assertListEqual(dset_concat.info.task_templates, [qa_template]) def test_concatenate_with_no_mixed_task_templates_in_common(self, in_memory): - tc_template1 = TextClassification(text_column="text", label_column="labels", labels=["pos", "neg"]) - tc_template2 = TextClassification(text_column="text", label_column="labels", labels=["pos", "neg", "neutral"]) + tc_template1 = TextClassification(text_column="text", label_column="labels") + tc_template2 = TextClassification(text_column="text", label_column="sentiment") qa_template = QuestionAnswering(question_column="question", context_column="context", answers_column="answers") - info1 = DatasetInfo(task_templates=[tc_template1]) - info2 = DatasetInfo(task_templates=[tc_template2]) - info3 = DatasetInfo(task_templates=[qa_template]) + info1 = DatasetInfo( + features=Features( + { + "text": Value("string"), + "labels": ClassLabel(names=["pos", "neg"]), + "sentiment": ClassLabel(names=["pos", "neg", "neutral"]), + "context": Value("string"), + "question": Value("string"), + "answers": Sequence( + { + "text": Value("string"), + "answer_start": Value("int32"), + } + ), + } + ), + task_templates=[tc_template1], + ) + info2 = DatasetInfo( + features=Features( + { + "text": Value("string"), + "labels": ClassLabel(names=["pos", "neg"]), + "sentiment": ClassLabel(names=["pos", "neg", "neutral"]), + "context": Value("string"), + "question": Value("string"), + "answers": Sequence( + { + "text": Value("string"), + "answer_start": Value("int32"), + } + ), + } + ), + task_templates=[tc_template2], + ) + info3 = DatasetInfo( + features=Features( + { + "text": Value("string"), + "labels": ClassLabel(names=["pos", "neg"]), + "sentiment": ClassLabel(names=["pos", "neg", "neutral"]), + "context": Value("string"), + "question": Value("string"), + "answers": Sequence( + { + "text": Value("string"), + "answer_start": Value("int32"), + } + ), + } + ), + task_templates=[qa_template], + ) data = { "text": ["i love transformers!"], "labels": [1], + "sentiment": [0], "context": ["huggingface is going to the moon!"], "question": ["where is huggingface going?"], "answers": [{"text": ["to the moon!"], "answer_start": [2]}], @@ -1929,7 +2016,7 @@ def test_task_text_classification(self, in_memory): features_before_cast = Features( { "input_text": Value("string"), - "input_labels": Value("int32"), + "input_labels": ClassLabel(names=labels), } ) # Labels are cast to tuple during TextClassification init, so we do the same here @@ -1939,7 +2026,7 @@ def test_task_text_classification(self, in_memory): "labels": ClassLabel(names=tuple(labels)), } ) - task = TextClassification(text_column="input_text", label_column="input_labels", labels=labels) + task = TextClassification(text_column="input_text", label_column="input_labels") info = DatasetInfo( features=features_before_cast, task_templates=task, @@ -2030,10 +2117,10 @@ def test_task_with_incompatible_templates(self, in_memory): features = Features( { "input_text": Value("string"), - "input_labels": Value("int32"), + "input_labels": ClassLabel(names=labels), } ) - task = TextClassification(text_column="input_text", label_column="input_labels", labels=labels) + task = TextClassification(text_column="input_text", label_column="input_labels") info = DatasetInfo( features=features, task_templates=task, @@ -2050,6 +2137,24 @@ def test_task_with_incompatible_templates(self, in_memory): # Invalid task type dset.prepare_for_task(1) + def test_task_templates_empty_after_preparation(self, in_memory): + features = Features( + 
{ + "input_text": Value("string"), + "input_labels": ClassLabel(names=["pos", "neg"]), + } + ) + task = TextClassification(text_column="input_text", label_column="input_labels") + info = DatasetInfo( + features=features, + task_templates=task, + ) + data = {"input_text": ["i love transformers!"], "input_labels": [1]} + with tempfile.TemporaryDirectory() as tmp_dir, Dataset.from_dict(data, info=info) as dset: + with self._to(in_memory, tmp_dir, dset) as dset: + with dset.prepare_for_task(task="text-classification") as dset: + self.assertIsNone(dset.info.task_templates) + class MiscellaneousDatasetTest(TestCase): def test_from_pandas(self): diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 7a2fe205f3a..bdb5413ec4a 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -14,7 +14,7 @@ def test_column_mapping(self): def test_from_dict(self): input_schema = Features({"text": Value("string")}) - # Labels are cast to tuple during TextClassification init, so we do the same here + # Labels are cast to tuple during TextClassification __post_init__, so we do the same here label_schema = Features({"labels": ClassLabel(names=tuple(self.labels))}) template_dict = {"text_column": "input_text", "label_column": "input_labels", "labels": self.labels} task = TextClassification.from_dict(template_dict) From 43f9d55977810427b0f7cb996551f527a6c443e7 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Tue, 25 May 2021 17:50:56 +0200 Subject: [PATCH 08/25] Update label schema in post_init --- src/datasets/tasks/text_classification.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/datasets/tasks/text_classification.py b/src/datasets/tasks/text_classification.py index 140e7526ec5..6e0c2ee612d 100644 --- a/src/datasets/tasks/text_classification.py +++ b/src/datasets/tasks/text_classification.py @@ -27,14 +27,16 @@ class TextClassification(TaskTemplate): input_schema: ClassVar[Features] = Features({"text": Value("string")}) # TODO(lewtun): Find a more elegant approach without descriptors. 
label_schema: ClassVar[Features] = FeaturesWithLazyClassLabel(Features({"labels": ClassLabel})) - labels: List[str] + labels: List[str] = None text_column: str = "text" label_column: str = "labels" def __post_init__(self): - assert len(self.labels) == len(set(self.labels)), "Labels must be unique" - # Cast labels to tuple to allow hashing - self.__dict__["labels"] = tuple(sorted(self.labels)) + if self.labels: + assert len(self.labels) == len(set(self.labels)), "Labels must be unique" + # Cast labels to tuple to allow hashing + self.__dict__["labels"] = tuple(sorted(self.labels)) + self.label_schema["labels"] = ClassLabel(names=self.labels) @property def column_mapping(self) -> Dict[str, str]: From 5d66b4f03bf629e77c9fc7cc5d15b09fba551634 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Tue, 25 May 2021 17:55:23 +0200 Subject: [PATCH 09/25] Use __dict__ instead of __setattr__ to update task template labels --- src/datasets/info.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/datasets/info.py b/src/datasets/info.py index 52cd3a2eee4..df1377e0c6c 100644 --- a/src/datasets/info.py +++ b/src/datasets/info.py @@ -172,7 +172,7 @@ def __post_init__(self): for idx, template in enumerate(self.task_templates): if isinstance(template, TextClassification) and self.features is not None: # Cast labels to tuple to enable hashing of task template - object.__setattr__(template, "labels", tuple(sorted(self.features[template.label_column].names))) + template.__dict__["labels"] = tuple(sorted(self.features[template.label_column].names)) template.label_schema["labels"] = ClassLabel(names=template.labels) self.task_templates[idx] = template From e7b1f7a863b2f309e9335de7418edbdbb5d618b9 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Tue, 25 May 2021 22:31:24 +0200 Subject: [PATCH 10/25] Raise ValueError if TextClassification template has None or incompatible labels --- src/datasets/arrow_dataset.py | 8 ++++++++ tests/test_arrow_dataset.py | 24 ++++++++++++++++-------- 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index b57302e4621..08f9f30464e 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -39,6 +39,8 @@ from multiprocess import Pool, RLock from tqdm.auto import tqdm +from datasets.tasks.text_classification import TextClassification + from . import config from .arrow_reader import ArrowReader from .arrow_writer import ArrowWriter, OptimizedTypedSequence @@ -1416,6 +1418,12 @@ def prepare_for_task(self, task: Union[str, TaskTemplate]) -> "Dataset": raise ValueError( f"Expected a `str` or `datasets.tasks.TaskTemplate` object but got task {task} with type {type(task)}." ) + if isinstance(template, TextClassification) and self.info.features is not None: + dataset_labels = tuple(sorted(self.info.features[template.label_column].names)) + if template.labels is None or template.labels != dataset_labels: + raise ValueError( + f"Incompatible labels between the dataset and task template! Expected labels {dataset_labels} but got {template.labels}. Please ensure that `datasets.tasks.TextClassification.labels` matches the features of the dataset." 
+ ) column_mapping = template.column_mapping columns_to_drop = [column for column in self.column_names if column not in column_mapping] dataset = self.remove_columns(columns_to_drop) diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py index 68ea8da0214..e6b1dfa7faa 100644 --- a/tests/test_arrow_dataset.py +++ b/tests/test_arrow_dataset.py @@ -2128,14 +2128,22 @@ def test_task_with_incompatible_templates(self, in_memory): data = {"input_text": ["i love transformers!"], "input_labels": [1]} with tempfile.TemporaryDirectory() as tmp_dir, Dataset.from_dict(data, info=info) as dset: with self._to(in_memory, tmp_dir, dset) as dset: - with self.assertRaises(ValueError): - # Invalid task name - dset.prepare_for_task("this-task-does-not-exist") - # Duplicate task templates - dset.info.task_templates = [task, task] - dset.prepare_for_task("text-classification") - # Invalid task type - dset.prepare_for_task(1) + # Invalid task name + self.assertRaises(ValueError, dset.prepare_for_task, "this-task-does-not-exist") + # Invalid task templates with incompatible labels + task_with_wrong_labels = TextClassification( + text_column="input_text", label_column="input_labels", labels=["neut"] + ) + self.assertRaises(ValueError, dset.prepare_for_task, task_with_wrong_labels) + task_with_no_labels = TextClassification( + text_column="input_text", label_column="input_labels", labels=None + ) + self.assertRaises(ValueError, dset.prepare_for_task, task_with_no_labels) + # Duplicate task templates + dset.info.task_templates = [task, task] + self.assertRaises(ValueError, dset.prepare_for_task, "text-classification") + # Invalid task type + self.assertRaises(ValueError, dset.prepare_for_task, 1) def test_task_templates_empty_after_preparation(self, in_memory): features = Features( From 6f3ff6d3c715157db06b8cbc431469a5047b08a2 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Tue, 25 May 2021 22:38:21 +0200 Subject: [PATCH 11/25] Remove task templates from emotion demo --- datasets/emotion/emotion.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/datasets/emotion/emotion.py b/datasets/emotion/emotion.py index 7bdfda80cab..56f05160b7d 100644 --- a/datasets/emotion/emotion.py +++ b/datasets/emotion/emotion.py @@ -1,7 +1,6 @@ import csv import datasets -from datasets.tasks.text_classification import TextClassification _CITATION = """\ @@ -45,12 +44,6 @@ def _info(self): supervised_keys=("text", "label"), homepage=_URL, citation=_CITATION, - task_templates=[ - TextClassification( - text_column="text", - label_column="label", - ) - ], ) def _split_generators(self, dl_manager): From 1bf0b5bc42f66624aaf601e55c3d519113701f9a Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Wed, 26 May 2021 10:36:17 +0200 Subject: [PATCH 12/25] Add decorator to share docstrings across multiple functions --- src/datasets/dataset_dict.py | 15 +++------------ src/datasets/utils/doc_utils.py | 15 +++++++++++++++ 2 files changed, 18 insertions(+), 12 deletions(-) create mode 100644 src/datasets/utils/doc_utils.py diff --git a/src/datasets/dataset_dict.py b/src/datasets/dataset_dict.py index 0223704ac75..0d46bc3d801 100644 --- a/src/datasets/dataset_dict.py +++ b/src/datasets/dataset_dict.py @@ -9,6 +9,8 @@ import fsspec import numpy as np +from datasets.utils.doc_utils import is_documented_by + from .arrow_dataset import Dataset from .features import Features from .filesystems import extract_path_from_uri, is_remote_filesystem @@ -792,18 +794,7 @@ def from_text( path_or_paths, features=features, cache_dir=cache_dir, 
keep_in_memory=keep_in_memory, **kwargs ).read() + @is_documented_by(Dataset.prepare_for_task) def prepare_for_task(self, task: Union[str, TaskTemplate]): - """Prepare a dataset for the given task by casting the dataset's :class:`Features` to standardized column names and types as detailed in :py:mod:`datasets.tasks`. - - Casts :attr:`datasets.DatasetInfo.features` according to a task-specific schema. - - Args: - task (:obj:`Union[str, TaskTemplate]`): The task to prepare the dataset for during training and evaluation. If :obj:`str`, supported tasks include: - - - :obj:`"text-classification"` - - :obj:`"question-answering"` - - If :obj:`TaskTemplate`, must be one of the task templates in :py:mod:`datasets.tasks`. - """ self._check_values_type() return DatasetDict({k: dataset.prepare_for_task(task=task) for k, dataset in self.items()}) diff --git a/src/datasets/utils/doc_utils.py b/src/datasets/utils/doc_utils.py new file mode 100644 index 00000000000..6ef8bcb4e70 --- /dev/null +++ b/src/datasets/utils/doc_utils.py @@ -0,0 +1,15 @@ +from typing import Callable + + +def is_documented_by(function_with_docstring: Callable): + """Decorator to share docstrings across common functions. + + Args: + function_with_docstring (`Callable`): Name of the function with the docstring. + """ + + def wrapper(target_function): + target_function.__doc__ = function_with_docstring.__doc__ + return target_function + + return wrapper From 0dda59e02431da4d047ff259948542c17b5818e7 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Wed, 26 May 2021 10:42:35 +0200 Subject: [PATCH 13/25] Update docstring for prepare_for_task --- src/datasets/arrow_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 08f9f30464e..eae646859fd 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -1390,7 +1390,7 @@ def with_transform( def prepare_for_task(self, task: Union[str, TaskTemplate]) -> "Dataset": """Prepare a dataset for the given task by casting the dataset's :class:`Features` to standardized column names and types as detailed in :py:mod:`datasets.tasks`. - Casts :attr:`datasets.DatasetInfo.features` according to a task-specific schema. + Casts :attr:`datasets.DatasetInfo.features` according to a task-specific schema. Intended for single-use only, so all task templates are removed from :attr:`datasets.DatasetInfo.task_templates` after casting. Args: task (:obj:`Union[str, TaskTemplate]`): The task to prepare the dataset for during training and evaluation. If :obj:`str`, supported tasks include: From 654b2b0e9ffddb8c08ccf51b76b289a7b56a408e Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Wed, 26 May 2021 10:43:16 +0200 Subject: [PATCH 14/25] Reorder TextClassification args for better intuition --- src/datasets/tasks/text_classification.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/datasets/tasks/text_classification.py b/src/datasets/tasks/text_classification.py index 6e0c2ee612d..4742fb4d42a 100644 --- a/src/datasets/tasks/text_classification.py +++ b/src/datasets/tasks/text_classification.py @@ -27,9 +27,9 @@ class TextClassification(TaskTemplate): input_schema: ClassVar[Features] = Features({"text": Value("string")}) # TODO(lewtun): Find a more elegant approach without descriptors. 
label_schema: ClassVar[Features] = FeaturesWithLazyClassLabel(Features({"labels": ClassLabel})) - labels: List[str] = None text_column: str = "text" label_column: str = "labels" + labels: List[str] = None def __post_init__(self): if self.labels: From 812bd8790fbf83c9315450f066a667e41bdbc073 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Wed, 26 May 2021 16:45:29 +0200 Subject: [PATCH 15/25] fix missing "task" field in json + edit copy of objects instead of modifying in-place --- src/datasets/info.py | 14 +++++++++----- src/datasets/tasks/__init__.py | 6 ++++-- src/datasets/tasks/base.py | 3 ++- src/datasets/tasks/question_answering.py | 3 ++- src/datasets/tasks/text_classification.py | 17 ++++++----------- 5 files changed, 23 insertions(+), 20 deletions(-) diff --git a/src/datasets/info.py b/src/datasets/info.py index df1377e0c6c..86f84b8deb2 100644 --- a/src/datasets/info.py +++ b/src/datasets/info.py @@ -156,6 +156,7 @@ def __post_init__(self): else: self.supervised_keys = SupervisedKeysData(**self.supervised_keys) + # Parse and make a list of templates if self.task_templates is not None: if isinstance(self.task_templates, (list, tuple)): templates = [ @@ -168,13 +169,16 @@ def __post_init__(self): else: template = task_template_from_dict(self.task_templates) self.task_templates = [template] if template is not None else [] - # Insert labels and mappings for text classification + + # Insert labels and mappings for text classification + if self.task_templates is not None: + self.task_templates = list(self.task_templates) for idx, template in enumerate(self.task_templates): if isinstance(template, TextClassification) and self.features is not None: - # Cast labels to tuple to enable hashing of task template - template.__dict__["labels"] = tuple(sorted(self.features[template.label_column].names)) - template.label_schema["labels"] = ClassLabel(names=template.labels) - self.task_templates[idx] = template + labels = self.features[template.label_column].names + self.task_templates[idx] = TextClassification( + text_column=template.text_column, label_column=template.label_column, labels=labels + ) def _license_path(self, dataset_info_dir): return os.path.join(dataset_info_dir, config.LICENSE_FILENAME) diff --git a/src/datasets/tasks/__init__.py b/src/datasets/tasks/__init__.py index 1591fd318fa..e1fe44a04c2 100644 --- a/src/datasets/tasks/__init__.py +++ b/src/datasets/tasks/__init__.py @@ -1,5 +1,6 @@ from typing import Optional +from ..utils.logging import get_logger from .base import TaskTemplate from .question_answering import QuestionAnswering from .text_classification import TextClassification @@ -7,6 +8,8 @@ __all__ = ["TaskTemplate", "QuestionAnswering", "TextClassification"] +logger = get_logger(__name__) + NAME2TEMPLATE = {QuestionAnswering.task: QuestionAnswering, TextClassification.task: TextClassification} @@ -15,8 +18,7 @@ def task_template_from_dict(task_template_dict: dict) -> Optional[TaskTemplate]: """Create one of the supported task templates in :py:mod:`datasets.tasks` from a dictionary.""" task_name = task_template_dict.get("task") if task_name is None: + logger.warning(f"Couldn't find template for rasl '{task_name}'. 
Available templates: {list(NAME2TEMPLATE)}") return None template = NAME2TEMPLATE.get(task_name) - if template is None: - return None return template.from_dict(task_template_dict) diff --git a/src/datasets/tasks/base.py b/src/datasets/tasks/base.py index 13d3c8ab279..f75d2e4d228 100644 --- a/src/datasets/tasks/base.py +++ b/src/datasets/tasks/base.py @@ -11,7 +11,8 @@ @dataclass(frozen=True) class TaskTemplate(abc.ABC): - task: ClassVar[str] + # `task` is not a ClassVar since we want it to be part of the `asdict` output for JSON serialization + task: str input_schema: ClassVar[Features] label_schema: ClassVar[Features] diff --git a/src/datasets/tasks/question_answering.py b/src/datasets/tasks/question_answering.py index 145be969633..b62fdd64d03 100644 --- a/src/datasets/tasks/question_answering.py +++ b/src/datasets/tasks/question_answering.py @@ -7,7 +7,8 @@ @dataclass(frozen=True) class QuestionAnswering(TaskTemplate): - task: ClassVar[str] = "question-answering" + # `task` is not a ClassVar since we want it to be part of the `asdict` output for JSON serialization + task: str = "question-answering" input_schema: ClassVar[Features] = Features({"question": Value("string"), "context": Value("string")}) label_schema: ClassVar[Features] = Features( { diff --git a/src/datasets/tasks/text_classification.py b/src/datasets/tasks/text_classification.py index 4742fb4d42a..8608b1ce7a5 100644 --- a/src/datasets/tasks/text_classification.py +++ b/src/datasets/tasks/text_classification.py @@ -1,5 +1,6 @@ +import copy from dataclasses import dataclass -from typing import ClassVar, Dict, List +from typing import ClassVar, Dict, Optional, Tuple from ..features import ClassLabel, Features, Value from .base import TaskTemplate @@ -23,19 +24,21 @@ def __get__(self, obj, objtype=None): @dataclass(frozen=True) class TextClassification(TaskTemplate): - task: ClassVar[str] = "text-classification" + # `task` is not a ClassVar since we want it to be part of the `asdict` output for JSON serialization + task: str = "text-classification" input_schema: ClassVar[Features] = Features({"text": Value("string")}) # TODO(lewtun): Find a more elegant approach without descriptors. label_schema: ClassVar[Features] = FeaturesWithLazyClassLabel(Features({"labels": ClassLabel})) text_column: str = "text" label_column: str = "labels" - labels: List[str] = None + labels: Optional[Tuple[str]] = None def __post_init__(self): if self.labels: assert len(self.labels) == len(set(self.labels)), "Labels must be unique" # Cast labels to tuple to allow hashing self.__dict__["labels"] = tuple(sorted(self.labels)) + self.__dict__["label_schema"] = copy.deepcopy(self.label_schema) self.label_schema["labels"] = ClassLabel(names=self.labels) @property @@ -44,11 +47,3 @@ def column_mapping(self) -> Dict[str, str]: self.text_column: "text", self.label_column: "labels", } - - @property - def label2id(self): - return {label: idx for idx, label in enumerate(self.labels)} - - @property - def id2label(self): - return {idx: label for idx, label in enumerate(self.labels)} From 159a6f62974ded9c313e34af64a1368559fb1b36 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Wed, 26 May 2021 16:50:14 +0200 Subject: [PATCH 16/25] style --- src/datasets/info.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/datasets/info.py b/src/datasets/info.py index 86f84b8deb2..8a3fbf6ff35 100644 --- a/src/datasets/info.py +++ b/src/datasets/info.py @@ -39,7 +39,7 @@ from datasets.tasks.text_classification import TextClassification from . 
import config -from .features import ClassLabel, Features, Value +from .features import Features, Value from .splits import SplitDict from .tasks import TaskTemplate, task_template_from_dict from .utils import Version From a580339f7f28740706a8216eaf6f5551004ce1c6 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Wed, 26 May 2021 17:34:35 +0200 Subject: [PATCH 17/25] Fix failing tests due to new DatasetInfo.__post_init__ --- tests/test_arrow_dataset.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py index e6b1dfa7faa..fa10beb4655 100644 --- a/tests/test_arrow_dataset.py +++ b/tests/test_arrow_dataset.py @@ -702,9 +702,10 @@ def test_concatenate_with_no_task_templates(self, in_memory): self.assertEqual(dset_concat.info.task_templates, None) def test_concatenate_with_equal_task_templates(self, in_memory): - task_template = TextClassification(text_column="text", label_column="labels") + labels = ["neg", "pos"] + task_template = TextClassification(text_column="text", label_column="labels", labels=labels) info = DatasetInfo( - features=Features({"text": Value("string"), "labels": ClassLabel(names=["pos", "neg"])}), + features=Features({"text": Value("string"), "labels": ClassLabel(names=labels)}), task_templates=task_template, ) data = {"text": ["i love transformers!"], "labels": [1]} @@ -2026,7 +2027,7 @@ def test_task_text_classification(self, in_memory): "labels": ClassLabel(names=tuple(labels)), } ) - task = TextClassification(text_column="input_text", label_column="input_labels") + task = TextClassification(text_column="input_text", label_column="input_labels", labels=labels) info = DatasetInfo( features=features_before_cast, task_templates=task, From cff9d52eaccf3ba8e71ee1c0f4b16f12068cc12d Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Thu, 27 May 2021 11:11:39 +0200 Subject: [PATCH 18/25] Refactor TextClassification test to cover templates w / w-out labels --- tests/test_arrow_dataset.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py index fa10beb4655..b0e6708320f 100644 --- a/tests/test_arrow_dataset.py +++ b/tests/test_arrow_dataset.py @@ -2027,10 +2027,11 @@ def test_task_text_classification(self, in_memory): "labels": ClassLabel(names=tuple(labels)), } ) - task = TextClassification(text_column="input_text", label_column="input_labels", labels=labels) + # Label names are added in `DatasetInfo.__post_init__` so not needed here + task_without_labels = TextClassification(text_column="input_text", label_column="input_labels") info = DatasetInfo( features=features_before_cast, - task_templates=task, + task_templates=task_without_labels, ) data = {"input_text": ["i love transformers!"], "input_labels": [1]} # Test we can load from task name @@ -2041,10 +2042,12 @@ def test_task_text_classification(self, in_memory): with dset.prepare_for_task(task="text-classification") as dset: self.assertSetEqual(set(["labels", "text"]), set(dset.column_names)) self.assertDictEqual(features_after_cast, dset.features) - # Test we can load from TaskTemplate + # Test we can load from TextClassification template info.task_templates = None + # Label names are required when passing a TextClassification template directly to `Dataset.prepare_for_task` + task_with_labels = TextClassification(text_column="input_text", label_column="input_labels", labels=labels) with tempfile.TemporaryDirectory() as tmp_dir, Dataset.from_dict(data, 
info=info) as dset: - with dset.prepare_for_task(task=task) as dset: + with dset.prepare_for_task(task=task_with_labels) as dset: self.assertSetEqual(set(["labels", "text"]), set(dset.column_names)) self.assertDictEqual(features_after_cast, dset.features) @@ -2096,7 +2099,7 @@ def test_task_question_answering(self, in_memory): set(dset.flatten().column_names), ) self.assertDictEqual(features_after_cast, dset.features) - # Test we can load from TaskTemplate + # Test we can load from QuestionAnswering template info.task_templates = None with tempfile.TemporaryDirectory() as tmp_dir, Dataset.from_dict(data, info=info) as dset: with dset.prepare_for_task(task=task) as dset: From 81468678fbc0edece0c4a23716129d79590526dd Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Thu, 27 May 2021 11:18:09 +0200 Subject: [PATCH 19/25] Refactor use of label names in task template concatenation test --- tests/test_arrow_dataset.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py index b0e6708320f..8d14282d8bf 100644 --- a/tests/test_arrow_dataset.py +++ b/tests/test_arrow_dataset.py @@ -706,7 +706,8 @@ def test_concatenate_with_equal_task_templates(self, in_memory): task_template = TextClassification(text_column="text", label_column="labels", labels=labels) info = DatasetInfo( features=Features({"text": Value("string"), "labels": ClassLabel(names=labels)}), - task_templates=task_template, + # Label names are added in `DatasetInfo.__post_init__` so not included here + task_templates=TextClassification(text_column="text", label_column="labels"), ) data = {"text": ["i love transformers!"], "labels": [1]} with tempfile.TemporaryDirectory() as tmp_dir: @@ -2020,7 +2021,7 @@ def test_task_text_classification(self, in_memory): "input_labels": ClassLabel(names=labels), } ) - # Labels are cast to tuple during TextClassification init, so we do the same here + # Labels are cast to tuple during `TextClassification.__init_`, so we do the same here features_after_cast = Features( { "text": Value("string"), From fa53dc50baec00d48497078ee60b031a638cf96a Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Thu, 27 May 2021 11:25:53 +0200 Subject: [PATCH 20/25] Add separate test for template with labels in DatasetInfo --- tests/test_arrow_dataset.py | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py index 8d14282d8bf..e2e014015af 100644 --- a/tests/test_arrow_dataset.py +++ b/tests/test_arrow_dataset.py @@ -2030,13 +2030,28 @@ def test_task_text_classification(self, in_memory): ) # Label names are added in `DatasetInfo.__post_init__` so not needed here task_without_labels = TextClassification(text_column="input_text", label_column="input_labels") - info = DatasetInfo( + info1 = DatasetInfo( features=features_before_cast, task_templates=task_without_labels, ) + # Label names are required when passing a TextClassification template directly to `Dataset.prepare_for_task` + # However they also can be used to define `DatasetInfo` so we include a test for this too + task_with_labels = TextClassification(text_column="input_text", label_column="input_labels", labels=labels) + info2 = DatasetInfo( + features=features_before_cast, + task_templates=task_with_labels, + ) data = {"input_text": ["i love transformers!"], "input_labels": [1]} - # Test we can load from task name - with tempfile.TemporaryDirectory() as tmp_dir, Dataset.from_dict(data, info=info) as 
dset: + # Test we can load from task name when label names not included in template (default behaviour) + with tempfile.TemporaryDirectory() as tmp_dir, Dataset.from_dict(data, info=info1) as dset: + with self._to(in_memory, tmp_dir, dset) as dset: + self.assertSetEqual(set(["input_text", "input_labels"]), set(dset.column_names)) + self.assertDictEqual(features_before_cast, dset.features) + with dset.prepare_for_task(task="text-classification") as dset: + self.assertSetEqual(set(["labels", "text"]), set(dset.column_names)) + self.assertDictEqual(features_after_cast, dset.features) + # Test we can load from task name when label names included in template + with tempfile.TemporaryDirectory() as tmp_dir, Dataset.from_dict(data, info=info2) as dset: with self._to(in_memory, tmp_dir, dset) as dset: self.assertSetEqual(set(["input_text", "input_labels"]), set(dset.column_names)) self.assertDictEqual(features_before_cast, dset.features) @@ -2044,10 +2059,8 @@ def test_task_text_classification(self, in_memory): self.assertSetEqual(set(["labels", "text"]), set(dset.column_names)) self.assertDictEqual(features_after_cast, dset.features) # Test we can load from TextClassification template - info.task_templates = None - # Label names are required when passing a TextClassification template directly to `Dataset.prepare_for_task` - task_with_labels = TextClassification(text_column="input_text", label_column="input_labels", labels=labels) - with tempfile.TemporaryDirectory() as tmp_dir, Dataset.from_dict(data, info=info) as dset: + info1.task_templates = None + with tempfile.TemporaryDirectory() as tmp_dir, Dataset.from_dict(data, info=info1) as dset: with dset.prepare_for_task(task=task_with_labels) as dset: self.assertSetEqual(set(["labels", "text"]), set(dset.column_names)) self.assertDictEqual(features_after_cast, dset.features) From f78d5c43d4a5d69030ea20605b91d1c64899e032 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Thu, 27 May 2021 11:30:19 +0200 Subject: [PATCH 21/25] Fix log message --- src/datasets/tasks/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/datasets/tasks/__init__.py b/src/datasets/tasks/__init__.py index e1fe44a04c2..da86ad103d3 100644 --- a/src/datasets/tasks/__init__.py +++ b/src/datasets/tasks/__init__.py @@ -18,7 +18,7 @@ def task_template_from_dict(task_template_dict: dict) -> Optional[TaskTemplate]: """Create one of the supported task templates in :py:mod:`datasets.tasks` from a dictionary.""" task_name = task_template_dict.get("task") if task_name is None: - logger.warning(f"Couldn't find template for rasl '{task_name}'. Available templates: {list(NAME2TEMPLATE)}") + logger.warning(f"Couldn't find template for task '{task_name}'. 
Available templates: {list(NAME2TEMPLATE)}") return None template = NAME2TEMPLATE.get(task_name) return template.from_dict(task_template_dict) From 514890dd65fefa8a91d51f7a151b667733190c8a Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Thu, 27 May 2021 11:33:23 +0200 Subject: [PATCH 22/25] Fix comments --- tests/test_arrow_dataset.py | 2 +- tests/test_tasks.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py index e2e014015af..24a2bf783d1 100644 --- a/tests/test_arrow_dataset.py +++ b/tests/test_arrow_dataset.py @@ -2021,7 +2021,7 @@ def test_task_text_classification(self, in_memory): "input_labels": ClassLabel(names=labels), } ) - # Labels are cast to tuple during `TextClassification.__init_`, so we do the same here + # Labels are cast to tuple during `TextClassification.__post_init_`, so we do the same here features_after_cast = Features( { "text": Value("string"), diff --git a/tests/test_tasks.py b/tests/test_tasks.py index bdb5413ec4a..65e3b3ebb7f 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -14,7 +14,7 @@ def test_column_mapping(self): def test_from_dict(self): input_schema = Features({"text": Value("string")}) - # Labels are cast to tuple during TextClassification __post_init__, so we do the same here + # Labels are cast to tuple during `TextClassification.__post_init__`, so we do the same here label_schema = Features({"labels": ClassLabel(names=tuple(self.labels))}) template_dict = {"text_column": "input_text", "label_column": "input_labels", "labels": self.labels} task = TextClassification.from_dict(template_dict) From 40e440080e8f94b2e80b9f93c0d23bee2826ff90 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Fri, 28 May 2021 09:25:28 +0200 Subject: [PATCH 23/25] Remove custom feature with lazy classlabel No longer needed since we create a new instance of the task template during the `DatasetInfo.__post_init__` --- src/datasets/tasks/text_classification.py | 21 ++------------------- 1 file changed, 2 insertions(+), 19 deletions(-) diff --git a/src/datasets/tasks/text_classification.py b/src/datasets/tasks/text_classification.py index 8608b1ce7a5..e9002024768 100644 --- a/src/datasets/tasks/text_classification.py +++ b/src/datasets/tasks/text_classification.py @@ -1,4 +1,3 @@ -import copy from dataclasses import dataclass from typing import ClassVar, Dict, Optional, Tuple @@ -6,29 +5,13 @@ from .base import TaskTemplate -class FeaturesWithLazyClassLabel: - def __init__(self, features, label_column="labels"): - assert label_column in features, f"Key '{label_column}' missing in features {features}" - self._features = features - self._label_column = label_column - - def __get__(self, obj, objtype=None): - if obj is None: - return self._features - - assert hasattr(obj, self._label_column), f"Object has no attribute '{self._label_column}'" - features = self._features.copy() - features["labels"] = ClassLabel(names=getattr(obj, self._label_column)) - return features - - @dataclass(frozen=True) class TextClassification(TaskTemplate): # `task` is not a ClassVar since we want it to be part of the `asdict` output for JSON serialization task: str = "text-classification" input_schema: ClassVar[Features] = Features({"text": Value("string")}) # TODO(lewtun): Find a more elegant approach without descriptors. 
- label_schema: ClassVar[Features] = FeaturesWithLazyClassLabel(Features({"labels": ClassLabel})) + label_schema: ClassVar[Features] = Features({"labels": ClassLabel}) text_column: str = "text" label_column: str = "labels" labels: Optional[Tuple[str]] = None @@ -38,7 +21,7 @@ def __post_init__(self): assert len(self.labels) == len(set(self.labels)), "Labels must be unique" # Cast labels to tuple to allow hashing self.__dict__["labels"] = tuple(sorted(self.labels)) - self.__dict__["label_schema"] = copy.deepcopy(self.label_schema) + self.__dict__["label_schema"] = self.label_schema.copy() self.label_schema["labels"] = ClassLabel(names=self.labels) @property From 321e4e61490476a07e1c0cac022a35d942f848c1 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Fri, 28 May 2021 09:32:23 +0200 Subject: [PATCH 24/25] Move conditional check of features to outer if statement --- src/datasets/info.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/datasets/info.py b/src/datasets/info.py index 8a3fbf6ff35..27871ef27b0 100644 --- a/src/datasets/info.py +++ b/src/datasets/info.py @@ -171,10 +171,10 @@ def __post_init__(self): self.task_templates = [template] if template is not None else [] # Insert labels and mappings for text classification - if self.task_templates is not None: + if self.task_templates is not None and self.features is not None: self.task_templates = list(self.task_templates) for idx, template in enumerate(self.task_templates): - if isinstance(template, TextClassification) and self.features is not None: + if isinstance(template, TextClassification): labels = self.features[template.label_column].names self.task_templates[idx] = TextClassification( text_column=template.text_column, label_column=template.label_column, labels=labels From e79dfe4e6b7e6a214171bec88ec98ce31f19fff6 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Fri, 28 May 2021 09:35:18 +0200 Subject: [PATCH 25/25] Move feature is not None check to inner if-statement --- src/datasets/info.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/datasets/info.py b/src/datasets/info.py index 27871ef27b0..0d1931d8bb1 100644 --- a/src/datasets/info.py +++ b/src/datasets/info.py @@ -171,14 +171,15 @@ def __post_init__(self): self.task_templates = [template] if template is not None else [] # Insert labels and mappings for text classification - if self.task_templates is not None and self.features is not None: + if self.task_templates is not None: self.task_templates = list(self.task_templates) - for idx, template in enumerate(self.task_templates): - if isinstance(template, TextClassification): - labels = self.features[template.label_column].names - self.task_templates[idx] = TextClassification( - text_column=template.text_column, label_column=template.label_column, labels=labels - ) + if self.features is not None: + for idx, template in enumerate(self.task_templates): + if isinstance(template, TextClassification): + labels = self.features[template.label_column].names + self.task_templates[idx] = TextClassification( + text_column=template.text_column, label_column=template.label_column, labels=labels + ) def _license_path(self, dataset_info_dir): return os.path.join(dataset_info_dir, config.LICENSE_FILENAME)