diff --git a/src/datasets/tasks/__init__.py b/src/datasets/tasks/__init__.py
index f4b3021aa9d..94cca357a4d 100644
--- a/src/datasets/tasks/__init__.py
+++ b/src/datasets/tasks/__init__.py
@@ -3,10 +3,11 @@
 from ..utils.logging import get_logger
 from .base import TaskTemplate
 from .question_answering import QuestionAnsweringExtractive
+from .summarization import Summarization
 from .text_classification import TextClassification
 
 
-__all__ = ["TaskTemplate", "QuestionAnsweringExtractive", "TextClassification"]
+__all__ = ["TaskTemplate", "QuestionAnsweringExtractive", "TextClassification", "Summarization"]
 
 logger = get_logger(__name__)
 
@@ -14,6 +15,7 @@
 NAME2TEMPLATE = {
     QuestionAnsweringExtractive.task: QuestionAnsweringExtractive,
     TextClassification.task: TextClassification,
+    Summarization.task: Summarization,
 }
 
 
diff --git a/src/datasets/tasks/summarization.py b/src/datasets/tasks/summarization.py
new file mode 100644
index 00000000000..0e99b9d3b7d
--- /dev/null
+++ b/src/datasets/tasks/summarization.py
@@ -0,0 +1,28 @@
+from dataclasses import dataclass
+from typing import ClassVar, Dict
+
+from ..features import Features, Value
+from .base import TaskTemplate
+
+
+@dataclass(frozen=True)
+class Summarization(TaskTemplate):
+    """Task template for summarization.
+
+    `text_column` and `summary_column` are the dataset column names that
+    `column_mapping` renames to `text` and `summary`.
+    """
+
+    # Kept as a regular field (not a ClassVar) so that `asdict` includes it when serializing to JSON
+    task: str = "summarization"
+    input_schema: ClassVar[Features] = Features({"text": Value("string")})
+    label_schema: ClassVar[Features] = Features({"summary": Value("string")})
+    text_column: str = "text"
+    summary_column: str = "summary"
+
+    @property
+    def column_mapping(self) -> Dict[str, str]:
+        """Return the `{dataset column name: schema column name}` renaming for this template."""
+        mapping = {self.text_column: "text"}
+        mapping[self.summary_column] = "summary"
+        return mapping
diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py
index 23f917e8430..0f337f64acf 100644
--- a/tests/test_arrow_dataset.py
+++ b/tests/test_arrow_dataset.py
@@ -20,7 +20,7 @@
 from datasets.info import DatasetInfo
 from datasets.splits import NamedSplit
 from
datasets.table import ConcatenationTable, InMemoryTable, MemoryMappedTable -from datasets.tasks import QuestionAnsweringExtractive, TextClassification +from datasets.tasks import QuestionAnsweringExtractive, Summarization, TextClassification from datasets.utils.logging import WARNING from .conftest import s3_test_bucket_name @@ -705,172 +705,6 @@ def test_concatenate_pickle(self, in_memory): self.assertEqual(dset_concat.info.description, "Dataset1\n\nDataset2\n\n") del dset1, dset2, dset3 - def test_concatenate_with_no_task_templates(self, in_memory): - info = DatasetInfo(task_templates=None) - data = {"text": ["i love transformers!"], "labels": [1]} - with tempfile.TemporaryDirectory() as tmp_dir: - with Dataset.from_dict(data, info=info) as dset1, Dataset.from_dict( - data, info=info - ) as dset2, Dataset.from_dict(data, info=info) as dset3: - with self._to(in_memory, tmp_dir, dset1) as dset1, self._to( - in_memory, tmp_dir, dset2 - ) as dset2, self._to(in_memory, tmp_dir, dset3) as dset3: - with concatenate_datasets([dset1, dset2, dset3]) as dset_concat: - self.assertEqual(dset_concat.info.task_templates, None) - - def test_concatenate_with_equal_task_templates(self, in_memory): - labels = ["neg", "pos"] - task_template = TextClassification(text_column="text", label_column="labels", labels=labels) - info = DatasetInfo( - features=Features({"text": Value("string"), "labels": ClassLabel(names=labels)}), - # Label names are added in `DatasetInfo.__post_init__` so not included here - task_templates=TextClassification(text_column="text", label_column="labels"), - ) - data = {"text": ["i love transformers!"], "labels": [1]} - with tempfile.TemporaryDirectory() as tmp_dir: - with Dataset.from_dict(data, info=info) as dset1, Dataset.from_dict( - data, info=info - ) as dset2, Dataset.from_dict(data, info=info) as dset3: - with self._to(in_memory, tmp_dir, dset1) as dset1, self._to( - in_memory, tmp_dir, dset2 - ) as dset2, self._to(in_memory, tmp_dir, dset3) as dset3: - 
with concatenate_datasets([dset1, dset2, dset3]) as dset_concat: - self.assertListEqual(dset_concat.info.task_templates, [task_template]) - - def test_concatenate_with_mixed_task_templates_in_common(self, in_memory): - tc_template = TextClassification(text_column="text", label_column="labels") - qa_template = QuestionAnsweringExtractive( - question_column="question", context_column="context", answers_column="answers" - ) - info1 = DatasetInfo( - task_templates=[qa_template], - features=Features( - { - "text": Value("string"), - "labels": ClassLabel(names=["pos", "neg"]), - "context": Value("string"), - "question": Value("string"), - "answers": Sequence( - { - "text": Value("string"), - "answer_start": Value("int32"), - } - ), - } - ), - ) - info2 = DatasetInfo( - task_templates=[qa_template, tc_template], - features=Features( - { - "text": Value("string"), - "labels": ClassLabel(names=["pos", "neg"]), - "context": Value("string"), - "question": Value("string"), - "answers": Sequence( - { - "text": Value("string"), - "answer_start": Value("int32"), - } - ), - } - ), - ) - data = { - "text": ["i love transformers!"], - "labels": [1], - "context": ["huggingface is going to the moon!"], - "question": ["where is huggingface going?"], - "answers": [{"text": ["to the moon!"], "answer_start": [2]}], - } - with tempfile.TemporaryDirectory() as tmp_dir: - with Dataset.from_dict(data, info=info1) as dset1, Dataset.from_dict( - data, info=info2 - ) as dset2, Dataset.from_dict(data, info=info2) as dset3: - with self._to(in_memory, tmp_dir, dset1) as dset1, self._to( - in_memory, tmp_dir, dset2 - ) as dset2, self._to(in_memory, tmp_dir, dset3) as dset3: - with concatenate_datasets([dset1, dset2, dset3]) as dset_concat: - self.assertListEqual(dset_concat.info.task_templates, [qa_template]) - - def test_concatenate_with_no_mixed_task_templates_in_common(self, in_memory): - tc_template1 = TextClassification(text_column="text", label_column="labels") - tc_template2 = 
TextClassification(text_column="text", label_column="sentiment") - qa_template = QuestionAnsweringExtractive( - question_column="question", context_column="context", answers_column="answers" - ) - info1 = DatasetInfo( - features=Features( - { - "text": Value("string"), - "labels": ClassLabel(names=["pos", "neg"]), - "sentiment": ClassLabel(names=["pos", "neg", "neutral"]), - "context": Value("string"), - "question": Value("string"), - "answers": Sequence( - { - "text": Value("string"), - "answer_start": Value("int32"), - } - ), - } - ), - task_templates=[tc_template1], - ) - info2 = DatasetInfo( - features=Features( - { - "text": Value("string"), - "labels": ClassLabel(names=["pos", "neg"]), - "sentiment": ClassLabel(names=["pos", "neg", "neutral"]), - "context": Value("string"), - "question": Value("string"), - "answers": Sequence( - { - "text": Value("string"), - "answer_start": Value("int32"), - } - ), - } - ), - task_templates=[tc_template2], - ) - info3 = DatasetInfo( - features=Features( - { - "text": Value("string"), - "labels": ClassLabel(names=["pos", "neg"]), - "sentiment": ClassLabel(names=["pos", "neg", "neutral"]), - "context": Value("string"), - "question": Value("string"), - "answers": Sequence( - { - "text": Value("string"), - "answer_start": Value("int32"), - } - ), - } - ), - task_templates=[qa_template], - ) - data = { - "text": ["i love transformers!"], - "labels": [1], - "sentiment": [0], - "context": ["huggingface is going to the moon!"], - "question": ["where is huggingface going?"], - "answers": [{"text": ["to the moon!"], "answer_start": [2]}], - } - with tempfile.TemporaryDirectory() as tmp_dir: - with Dataset.from_dict(data, info=info1) as dset1, Dataset.from_dict( - data, info=info2 - ) as dset2, Dataset.from_dict(data, info=info3) as dset3: - with self._to(in_memory, tmp_dir, dset1) as dset1, self._to( - in_memory, tmp_dir, dset2 - ) as dset2, self._to(in_memory, tmp_dir, dset3) as dset3: - with concatenate_datasets([dset1, dset2, 
dset3]) as dset_concat: - self.assertEqual(dset_concat.info.task_templates, None) - def test_flatten(self, in_memory): with tempfile.TemporaryDirectory() as tmp_dir: with Dataset.from_dict( @@ -2034,193 +1868,6 @@ def test_with_transform(self, in_memory): self.assertNotEqual(dset.format, dset2.format) self.assertNotEqual(dset._fingerprint, dset2._fingerprint) - def test_task_text_classification(self, in_memory): - labels = sorted(["pos", "neg"]) - features_before_cast = Features( - { - "input_text": Value("string"), - "input_labels": ClassLabel(names=labels), - } - ) - # Labels are cast to tuple during `TextClassification.__post_init_`, so we do the same here - features_after_cast = Features( - { - "text": Value("string"), - "labels": ClassLabel(names=tuple(labels)), - } - ) - # Label names are added in `DatasetInfo.__post_init__` so not needed here - task_without_labels = TextClassification(text_column="input_text", label_column="input_labels") - info1 = DatasetInfo( - features=features_before_cast, - task_templates=task_without_labels, - ) - # Label names are required when passing a TextClassification template directly to `Dataset.prepare_for_task` - # However they also can be used to define `DatasetInfo` so we include a test for this too - task_with_labels = TextClassification(text_column="input_text", label_column="input_labels", labels=labels) - info2 = DatasetInfo( - features=features_before_cast, - task_templates=task_with_labels, - ) - data = {"input_text": ["i love transformers!"], "input_labels": [1]} - # Test we can load from task name when label names not included in template (default behaviour) - with tempfile.TemporaryDirectory() as tmp_dir, Dataset.from_dict(data, info=info1) as dset: - with self._to(in_memory, tmp_dir, dset) as dset: - self.assertSetEqual(set(["input_text", "input_labels"]), set(dset.column_names)) - self.assertDictEqual(features_before_cast, dset.features) - with dset.prepare_for_task(task="text-classification") as dset: - 
self.assertSetEqual(set(["labels", "text"]), set(dset.column_names)) - self.assertDictEqual(features_after_cast, dset.features) - # Test we can load from task name when label names included in template - with tempfile.TemporaryDirectory() as tmp_dir, Dataset.from_dict(data, info=info2) as dset: - with self._to(in_memory, tmp_dir, dset) as dset: - self.assertSetEqual(set(["input_text", "input_labels"]), set(dset.column_names)) - self.assertDictEqual(features_before_cast, dset.features) - with dset.prepare_for_task(task="text-classification") as dset: - self.assertSetEqual(set(["labels", "text"]), set(dset.column_names)) - self.assertDictEqual(features_after_cast, dset.features) - # Test we can load from TextClassification template - info1.task_templates = None - with tempfile.TemporaryDirectory() as tmp_dir, Dataset.from_dict(data, info=info1) as dset: - with dset.prepare_for_task(task=task_with_labels) as dset: - self.assertSetEqual(set(["labels", "text"]), set(dset.column_names)) - self.assertDictEqual(features_after_cast, dset.features) - - def test_task_question_answering(self, in_memory): - features_before_cast = Features( - { - "input_context": Value("string"), - "input_question": Value("string"), - "input_answers": Sequence( - { - "text": Value("string"), - "answer_start": Value("int32"), - } - ), - } - ) - features_after_cast = Features( - { - "context": Value("string"), - "question": Value("string"), - "answers": Sequence( - { - "text": Value("string"), - "answer_start": Value("int32"), - } - ), - } - ) - task = QuestionAnsweringExtractive( - context_column="input_context", question_column="input_question", answers_column="input_answers" - ) - info = DatasetInfo(features=features_before_cast, task_templates=task) - data = { - "input_context": ["huggingface is going to the moon!"], - "input_question": ["where is huggingface going?"], - "input_answers": [{"text": ["to the moon!"], "answer_start": [2]}], - } - # Test we can load from task name - with 
tempfile.TemporaryDirectory() as tmp_dir, Dataset.from_dict(data, info=info) as dset: - with self._to(in_memory, tmp_dir, dset) as dset: - self.assertSetEqual( - set(["input_context", "input_question", "input_answers.text", "input_answers.answer_start"]), - set(dset.flatten().column_names), - ) - self.assertDictEqual(features_before_cast, dset.features) - with dset.prepare_for_task(task="question-answering-extractive") as dset: - self.assertSetEqual( - set(["context", "question", "answers.text", "answers.answer_start"]), - set(dset.flatten().column_names), - ) - self.assertDictEqual(features_after_cast, dset.features) - # Test we can load from QuestionAnsweringExtractive template - info.task_templates = None - with tempfile.TemporaryDirectory() as tmp_dir, Dataset.from_dict(data, info=info) as dset: - with dset.prepare_for_task(task=task) as dset: - self.assertSetEqual( - set(["context", "question", "answers.text", "answers.answer_start"]), - set(dset.flatten().column_names), - ) - self.assertDictEqual(features_after_cast, dset.features) - - def test_task_with_no_template(self, in_memory): - data = {"input_text": ["i love transformers!"], "input_labels": [1]} - with tempfile.TemporaryDirectory() as tmp_dir, Dataset.from_dict(data) as dset: - with self._to(in_memory, tmp_dir, dset) as dset: - with self.assertRaises(ValueError): - dset.prepare_for_task("text-classification") - - def test_task_with_incompatible_templates(self, in_memory): - labels = sorted(["pos", "neg"]) - features = Features( - { - "input_text": Value("string"), - "input_labels": ClassLabel(names=labels), - } - ) - task = TextClassification(text_column="input_text", label_column="input_labels") - info = DatasetInfo( - features=features, - task_templates=task, - ) - data = {"input_text": ["i love transformers!"], "input_labels": [1]} - with tempfile.TemporaryDirectory() as tmp_dir, Dataset.from_dict(data, info=info) as dset: - with self._to(in_memory, tmp_dir, dset) as dset: - # Invalid task name - 
self.assertRaises(ValueError, dset.prepare_for_task, "this-task-does-not-exist") - # Invalid task templates with incompatible labels - task_with_wrong_labels = TextClassification( - text_column="input_text", label_column="input_labels", labels=["neut"] - ) - self.assertRaises(ValueError, dset.prepare_for_task, task_with_wrong_labels) - task_with_no_labels = TextClassification( - text_column="input_text", label_column="input_labels", labels=None - ) - self.assertRaises(ValueError, dset.prepare_for_task, task_with_no_labels) - # Duplicate task templates - dset.info.task_templates = [task, task] - self.assertRaises(ValueError, dset.prepare_for_task, "text-classification") - # Invalid task type - self.assertRaises(ValueError, dset.prepare_for_task, 1) - - def test_task_templates_empty_after_preparation(self, in_memory): - features = Features( - { - "input_text": Value("string"), - "input_labels": ClassLabel(names=["pos", "neg"]), - } - ) - task = TextClassification(text_column="input_text", label_column="input_labels") - info = DatasetInfo( - features=features, - task_templates=task, - ) - data = {"input_text": ["i love transformers!"], "input_labels": [1]} - with tempfile.TemporaryDirectory() as tmp_dir, Dataset.from_dict(data, info=info) as dset: - with self._to(in_memory, tmp_dir, dset) as dset: - with dset.prepare_for_task(task="text-classification") as dset: - self.assertIsNone(dset.info.task_templates) - - def test_align_labels_with_mapping(self, in_memory): - features = Features( - { - "input_text": Value("string"), - "input_labels": ClassLabel(num_classes=3, names=["entailment", "neutral", "contradiction"]), - } - ) - data = {"input_text": ["a", "a", "b", "b", "c", "c"], "input_labels": [0, 0, 1, 1, 2, 2]} - label2id = {"CONTRADICTION": 0, "ENTAILMENT": 2, "NEUTRAL": 1} - id2label = {v: k for k, v in label2id.items()} - expected_labels = [2, 2, 1, 1, 0, 0] - expected_label_names = [id2label[idx] for idx in expected_labels] - with tempfile.TemporaryDirectory() 
as tmp_dir, Dataset.from_dict(data, features=features) as dset: - with self._to(in_memory, tmp_dir, dset) as dset: - with dset.align_labels_with_mapping(label2id, "input_labels") as dset: - self.assertListEqual(expected_labels, dset["input_labels"]) - aligned_label_names = [dset.features["input_labels"].int2str(idx) for idx in dset["input_labels"]] - self.assertListEqual(expected_label_names, aligned_label_names) - class MiscellaneousDatasetTest(TestCase): def test_from_pandas(self): @@ -2763,3 +2410,366 @@ def test_dummy_dataset_serialize_s3(s3, dataset): assert dataset.features == features assert dataset[0]["id"] == 0 assert dataset["id"][0] == 0 + + +class TaskTemplatesTest(TestCase): + def test_task_text_classification(self): + labels = sorted(["pos", "neg"]) + features_before_cast = Features( + { + "input_text": Value("string"), + "input_labels": ClassLabel(names=labels), + } + ) + # Labels are cast to tuple during `TextClassification.__post_init_`, so we do the same here + features_after_cast = Features( + { + "text": Value("string"), + "labels": ClassLabel(names=tuple(labels)), + } + ) + # Label names are added in `DatasetInfo.__post_init__` so not needed here + task_without_labels = TextClassification(text_column="input_text", label_column="input_labels") + info1 = DatasetInfo( + features=features_before_cast, + task_templates=task_without_labels, + ) + # Label names are required when passing a TextClassification template directly to `Dataset.prepare_for_task` + # However they also can be used to define `DatasetInfo` so we include a test for this too + task_with_labels = TextClassification(text_column="input_text", label_column="input_labels", labels=labels) + info2 = DatasetInfo( + features=features_before_cast, + task_templates=task_with_labels, + ) + data = {"input_text": ["i love transformers!"], "input_labels": [1]} + # Test we can load from task name when label names not included in template (default behaviour) + with Dataset.from_dict(data, 
info=info1) as dset: + self.assertSetEqual(set(["input_text", "input_labels"]), set(dset.column_names)) + self.assertDictEqual(features_before_cast, dset.features) + with dset.prepare_for_task(task="text-classification") as dset: + self.assertSetEqual(set(["labels", "text"]), set(dset.column_names)) + self.assertDictEqual(features_after_cast, dset.features) + # Test we can load from task name when label names included in template + with Dataset.from_dict(data, info=info2) as dset: + self.assertSetEqual(set(["input_text", "input_labels"]), set(dset.column_names)) + self.assertDictEqual(features_before_cast, dset.features) + with dset.prepare_for_task(task="text-classification") as dset: + self.assertSetEqual(set(["labels", "text"]), set(dset.column_names)) + self.assertDictEqual(features_after_cast, dset.features) + # Test we can load from TextClassification template + info1.task_templates = None + with Dataset.from_dict(data, info=info1) as dset: + with dset.prepare_for_task(task=task_with_labels) as dset: + self.assertSetEqual(set(["labels", "text"]), set(dset.column_names)) + self.assertDictEqual(features_after_cast, dset.features) + + def test_task_question_answering(self): + features_before_cast = Features( + { + "input_context": Value("string"), + "input_question": Value("string"), + "input_answers": Sequence( + { + "text": Value("string"), + "answer_start": Value("int32"), + } + ), + } + ) + features_after_cast = Features( + { + "context": Value("string"), + "question": Value("string"), + "answers": Sequence( + { + "text": Value("string"), + "answer_start": Value("int32"), + } + ), + } + ) + task = QuestionAnsweringExtractive( + context_column="input_context", question_column="input_question", answers_column="input_answers" + ) + info = DatasetInfo(features=features_before_cast, task_templates=task) + data = { + "input_context": ["huggingface is going to the moon!"], + "input_question": ["where is huggingface going?"], + "input_answers": [{"text": ["to the 
moon!"], "answer_start": [2]}], + } + # Test we can load from task name + with Dataset.from_dict(data, info=info) as dset: + self.assertSetEqual( + set(["input_context", "input_question", "input_answers.text", "input_answers.answer_start"]), + set(dset.flatten().column_names), + ) + self.assertDictEqual(features_before_cast, dset.features) + with dset.prepare_for_task(task="question-answering-extractive") as dset: + self.assertSetEqual( + set(["context", "question", "answers.text", "answers.answer_start"]), + set(dset.flatten().column_names), + ) + self.assertDictEqual(features_after_cast, dset.features) + # Test we can load from QuestionAnsweringExtractive template + info.task_templates = None + with Dataset.from_dict(data, info=info) as dset: + with dset.prepare_for_task(task=task) as dset: + self.assertSetEqual( + set(["context", "question", "answers.text", "answers.answer_start"]), + set(dset.flatten().column_names), + ) + self.assertDictEqual(features_after_cast, dset.features) + + def test_task_summarization(self): + # Include a dummy extra column `dummy` to test we drop it correctly + features_before_cast = Features( + {"input_text": Value("string"), "input_summary": Value("string"), "dummy": Value("string")} + ) + features_after_cast = Features({"text": Value("string"), "summary": Value("string")}) + task = Summarization(text_column="input_text", summary_column="input_summary") + info = DatasetInfo(features=features_before_cast, task_templates=task) + data = { + "input_text": ["jack and jill took a taxi to attend a super duper party in the city."], + "input_summary": ["jack and jill attend party"], + "dummy": ["123456"], + } + # Test we can load from task name + with Dataset.from_dict(data, info=info) as dset: + with dset.prepare_for_task(task="summarization") as dset: + self.assertSetEqual( + set(["text", "summary"]), + set(dset.column_names), + ) + self.assertDictEqual(features_after_cast, dset.features) + # Test we can load from Summarization template + 
info.task_templates = None + with Dataset.from_dict(data, info=info) as dset: + with dset.prepare_for_task(task=task) as dset: + self.assertSetEqual( + set(["text", "summary"]), + set(dset.column_names), + ) + self.assertDictEqual(features_after_cast, dset.features) + + def test_task_with_no_template(self): + data = {"input_text": ["i love transformers!"], "input_labels": [1]} + with Dataset.from_dict(data) as dset: + with self.assertRaises(ValueError): + dset.prepare_for_task("text-classification") + + def test_task_with_incompatible_templates(self): + labels = sorted(["pos", "neg"]) + features = Features( + { + "input_text": Value("string"), + "input_labels": ClassLabel(names=labels), + } + ) + task = TextClassification(text_column="input_text", label_column="input_labels") + info = DatasetInfo( + features=features, + task_templates=task, + ) + data = {"input_text": ["i love transformers!"], "input_labels": [1]} + with Dataset.from_dict(data, info=info) as dset: + # Invalid task name + self.assertRaises(ValueError, dset.prepare_for_task, "this-task-does-not-exist") + # Invalid task templates with incompatible labels + task_with_wrong_labels = TextClassification( + text_column="input_text", label_column="input_labels", labels=["neut"] + ) + self.assertRaises(ValueError, dset.prepare_for_task, task_with_wrong_labels) + task_with_no_labels = TextClassification( + text_column="input_text", label_column="input_labels", labels=None + ) + self.assertRaises(ValueError, dset.prepare_for_task, task_with_no_labels) + # Duplicate task templates + dset.info.task_templates = [task, task] + self.assertRaises(ValueError, dset.prepare_for_task, "text-classification") + # Invalid task type + self.assertRaises(ValueError, dset.prepare_for_task, 1) + + def test_task_templates_empty_after_preparation(self): + features = Features( + { + "input_text": Value("string"), + "input_labels": ClassLabel(names=["pos", "neg"]), + } + ) + task = TextClassification(text_column="input_text", 
label_column="input_labels") + info = DatasetInfo( + features=features, + task_templates=task, + ) + data = {"input_text": ["i love transformers!"], "input_labels": [1]} + with Dataset.from_dict(data, info=info) as dset: + with dset.prepare_for_task(task="text-classification") as dset: + self.assertIsNone(dset.info.task_templates) + + def test_align_labels_with_mapping(self): + features = Features( + { + "input_text": Value("string"), + "input_labels": ClassLabel(num_classes=3, names=["entailment", "neutral", "contradiction"]), + } + ) + data = {"input_text": ["a", "a", "b", "b", "c", "c"], "input_labels": [0, 0, 1, 1, 2, 2]} + label2id = {"CONTRADICTION": 0, "ENTAILMENT": 2, "NEUTRAL": 1} + id2label = {v: k for k, v in label2id.items()} + expected_labels = [2, 2, 1, 1, 0, 0] + expected_label_names = [id2label[idx] for idx in expected_labels] + with Dataset.from_dict(data, features=features) as dset: + with dset.align_labels_with_mapping(label2id, "input_labels") as dset: + self.assertListEqual(expected_labels, dset["input_labels"]) + aligned_label_names = [dset.features["input_labels"].int2str(idx) for idx in dset["input_labels"]] + self.assertListEqual(expected_label_names, aligned_label_names) + + def test_concatenate_with_no_task_templates(self): + info = DatasetInfo(task_templates=None) + data = {"text": ["i love transformers!"], "labels": [1]} + with Dataset.from_dict(data, info=info) as dset1, Dataset.from_dict( + data, info=info + ) as dset2, Dataset.from_dict(data, info=info) as dset3: + with concatenate_datasets([dset1, dset2, dset3]) as dset_concat: + self.assertEqual(dset_concat.info.task_templates, None) + + def test_concatenate_with_equal_task_templates(self): + labels = ["neg", "pos"] + task_template = TextClassification(text_column="text", label_column="labels", labels=labels) + info = DatasetInfo( + features=Features({"text": Value("string"), "labels": ClassLabel(names=labels)}), + # Label names are added in `DatasetInfo.__post_init__` so not 
included here + task_templates=TextClassification(text_column="text", label_column="labels"), + ) + data = {"text": ["i love transformers!"], "labels": [1]} + with Dataset.from_dict(data, info=info) as dset1, Dataset.from_dict( + data, info=info + ) as dset2, Dataset.from_dict(data, info=info) as dset3: + with concatenate_datasets([dset1, dset2, dset3]) as dset_concat: + self.assertListEqual(dset_concat.info.task_templates, [task_template]) + + def test_concatenate_with_mixed_task_templates_in_common(self): + tc_template = TextClassification(text_column="text", label_column="labels") + qa_template = QuestionAnsweringExtractive( + question_column="question", context_column="context", answers_column="answers" + ) + info1 = DatasetInfo( + task_templates=[qa_template], + features=Features( + { + "text": Value("string"), + "labels": ClassLabel(names=["pos", "neg"]), + "context": Value("string"), + "question": Value("string"), + "answers": Sequence( + { + "text": Value("string"), + "answer_start": Value("int32"), + } + ), + } + ), + ) + info2 = DatasetInfo( + task_templates=[qa_template, tc_template], + features=Features( + { + "text": Value("string"), + "labels": ClassLabel(names=["pos", "neg"]), + "context": Value("string"), + "question": Value("string"), + "answers": Sequence( + { + "text": Value("string"), + "answer_start": Value("int32"), + } + ), + } + ), + ) + data = { + "text": ["i love transformers!"], + "labels": [1], + "context": ["huggingface is going to the moon!"], + "question": ["where is huggingface going?"], + "answers": [{"text": ["to the moon!"], "answer_start": [2]}], + } + with Dataset.from_dict(data, info=info1) as dset1, Dataset.from_dict( + data, info=info2 + ) as dset2, Dataset.from_dict(data, info=info2) as dset3: + with concatenate_datasets([dset1, dset2, dset3]) as dset_concat: + self.assertListEqual(dset_concat.info.task_templates, [qa_template]) + + def test_concatenate_with_no_mixed_task_templates_in_common(self): + tc_template1 = 
TextClassification(text_column="text", label_column="labels") + tc_template2 = TextClassification(text_column="text", label_column="sentiment") + qa_template = QuestionAnsweringExtractive( + question_column="question", context_column="context", answers_column="answers" + ) + info1 = DatasetInfo( + features=Features( + { + "text": Value("string"), + "labels": ClassLabel(names=["pos", "neg"]), + "sentiment": ClassLabel(names=["pos", "neg", "neutral"]), + "context": Value("string"), + "question": Value("string"), + "answers": Sequence( + { + "text": Value("string"), + "answer_start": Value("int32"), + } + ), + } + ), + task_templates=[tc_template1], + ) + info2 = DatasetInfo( + features=Features( + { + "text": Value("string"), + "labels": ClassLabel(names=["pos", "neg"]), + "sentiment": ClassLabel(names=["pos", "neg", "neutral"]), + "context": Value("string"), + "question": Value("string"), + "answers": Sequence( + { + "text": Value("string"), + "answer_start": Value("int32"), + } + ), + } + ), + task_templates=[tc_template2], + ) + info3 = DatasetInfo( + features=Features( + { + "text": Value("string"), + "labels": ClassLabel(names=["pos", "neg"]), + "sentiment": ClassLabel(names=["pos", "neg", "neutral"]), + "context": Value("string"), + "question": Value("string"), + "answers": Sequence( + { + "text": Value("string"), + "answer_start": Value("int32"), + } + ), + } + ), + task_templates=[qa_template], + ) + data = { + "text": ["i love transformers!"], + "labels": [1], + "sentiment": [0], + "context": ["huggingface is going to the moon!"], + "question": ["where is huggingface going?"], + "answers": [{"text": ["to the moon!"], "answer_start": [2]}], + } + with Dataset.from_dict(data, info=info1) as dset1, Dataset.from_dict( + data, info=info2 + ) as dset2, Dataset.from_dict(data, info=info3) as dset3: + with concatenate_datasets([dset1, dset2, dset3]) as dset_concat: + self.assertEqual(dset_concat.info.task_templates, None) diff --git a/tests/test_tasks.py 
b/tests/test_tasks.py
index 0d3878dd135..161ef16934a 100644
--- a/tests/test_tasks.py
+++ b/tests/test_tasks.py
@@ -1,7 +1,7 @@
 from unittest.case import TestCase
 
 from datasets.features import ClassLabel, Features, Sequence, Value
-from datasets.tasks import QuestionAnsweringExtractive, TextClassification
+from datasets.tasks import QuestionAnsweringExtractive, Summarization, TextClassification
 
 
 class TextClassificationTest(TestCase):
@@ -53,3 +53,24 @@ def test_from_dict(self):
         self.assertEqual("question-answering-extractive", task.task)
         self.assertEqual(input_schema, task.input_schema)
         self.assertEqual(label_schema, task.label_schema)
+
+
+class SummarizationTest(TestCase):
+    """Unit tests for the `Summarization` task template."""
+
+    def test_column_mapping(self):
+        # Custom dataset column names must map onto the fixed `text` / `summary` schema names
+        task = Summarization(text_column="input_text", summary_column="input_summary")
+        self.assertDictEqual({"input_text": "text", "input_summary": "summary"}, task.column_mapping)
+
+    def test_from_dict(self):
+        input_schema = Features({"text": Value("string")})
+        label_schema = Features({"summary": Value("string")})
+        template_dict = {"text_column": "input_text", "summary_column": "input_summary"}
+        task = Summarization.from_dict(template_dict)
+        self.assertEqual("summarization", task.task)
+        # The task name and schemas are defaults, so also check that the
+        # values supplied in `template_dict` actually reached the template.
+        self.assertEqual("input_text", task.text_column)
+        self.assertEqual("input_summary", task.summary_column)
+        self.assertEqual(input_schema, task.input_schema)
+        self.assertEqual(label_schema, task.label_schema)