-
Notifications
You must be signed in to change notification settings - Fork 3.1k
Task casting for text classification & question answering #2255
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 63 commits
9a19cf2
7ab394a
e5cef8d
4564eb9
29650f8
1516ae6
2157bf5
9302893
02bcfa2
531cf94
d3bccca
6dfd7f2
b0899b1
3cf039d
b2a02c5
28e6ab1
7f0f683
3a30b62
2ec62a2
5df3b65
35e6110
e21d728
8d5427d
c344557
6220e7f
388c5d4
d759f65
6b851aa
95d395e
6d74541
0f6513a
d3cdf78
34aa06e
fe61d2c
46beea9
c83e501
db52ce7
3ad17f3
14e6e5b
316cdb0
8e2299a
40de679
b67ed1f
944b22f
1cb8c93
6858125
6c6b3c9
4d40558
7076902
81674cc
3e39872
7abd2f0
441a36e
894a509
56635e7
1274aa5
5e48403
5a6021a
a868d35
99200c7
f895389
290b583
3578dbd
7a0de24
6bf7970
2dc8d26
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -57,6 +57,7 @@ | |
| from .search import IndexableMixin | ||
| from .splits import NamedSplit | ||
| from .table import ConcatenationTable, InMemoryTable, MemoryMappedTable, Table, concat_tables, list_table_cache_files | ||
| from .tasks import TaskTemplate | ||
| from .utils import map_nested | ||
| from .utils.deprecation_utils import deprecated | ||
| from .utils.file_utils import estimate_dataset_size | ||
|
|
@@ -1384,6 +1385,44 @@ def with_transform( | |
| dataset.set_transform(transform=transform, columns=columns, output_all_columns=output_all_columns) | ||
| return dataset | ||
|
|
||
| def prepare_for_task(self, task: Union[str, TaskTemplate]) -> "Dataset": | ||
| """Prepare a dataset for the given task. | ||
|
|
||
| Casts :attr:`datasets.DatasetInfo.features` according to a task-specific schema. | ||
lewtun marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| Args: | ||
| task (:obj:`Union[str, TaskTemplate]`): The task to prepare the dataset for during training and evaluation. If :obj:`str`, supported tasks include: | ||
|
|
||
| - :obj:`"text-classification"` | ||
| - :obj:`"question-answering"` | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
we can keep a generic Or we can also propose to use
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe we could have
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually I see that people are more organizing in terms of general and sub-tasks, for instance on paperwithcode: https://paperswithcode.com/area/natural-language-processing and on nlpprogress: https://github.com/sebastianruder/NLP-progress/blob/master/english/question_answering.md#squad Probably the best is to align with one of these in terms of denomination, PaperWithCode is probably the most active and maintained and we work with them as well.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. our idea was to start by following the pipeline taxonomy in but you're 100% correct that we should already start considering the abstractive case or grouping in terms of the sub-domains that PwC adopts. since our focus right now is on getting
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. PWC uses "Question Answering" (meaning extractive) vs "Generative Question Answering" (includes abstractive) which encapsulates most of the QA tasks except Open/Closed Domain QA and Knowledge base QA (they require no context since the knowledge is not part of the query). I'm fine with these names, or simply extractive and abstractive |
||
|
|
||
| If :obj:`TaskTemplate`, must be one of the task templates in :obj:`datasets.tasks`. | ||
| """ | ||
| # TODO(lewtun): Add support for casting nested features like answers.text and answers.answer_start in SQuAD | ||
| if isinstance(task, str): | ||
| tasks = [template.task for template in (self.info.task_templates or [])] | ||
| compatible_templates = [template for template in (self.info.task_templates or []) if template.task == task] | ||
| if not compatible_templates: | ||
| raise ValueError(f"Task {task} is not compatible with this dataset! Available tasks: {tasks}") | ||
|
|
||
| if len(compatible_templates) > 1: | ||
| raise ValueError( | ||
| f"Expected 1 task template but found {len(compatible_templates)}! Please ensure that `datasets.DatasetInfo.task_templates` contains a unique set of task types." | ||
| ) | ||
| template = compatible_templates[0] | ||
| elif isinstance(task, TaskTemplate): | ||
| template = task | ||
| else: | ||
| raise ValueError( | ||
| f"Expected a `str` or `datasets.tasks.TaskTemplate` object but got task {task} with type {type(task)}." | ||
| ) | ||
| column_mapping = template.column_mapping | ||
| columns_to_drop = [column for column in self.column_names if column not in column_mapping] | ||
| dataset = self.remove_columns(columns_to_drop) | ||
| dataset = dataset.rename_columns(column_mapping) | ||
| dataset = dataset.cast(features=template.features) | ||
| return dataset | ||
|
|
||
| def _getitem( | ||
| self, | ||
| key: Union[int, slice, str], | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -38,6 +38,7 @@ | |||||
| from .metric import Metric | ||||||
| from .packaged_modules import _PACKAGED_DATASETS_MODULES, hash_python_lines | ||||||
| from .splits import Split | ||||||
| from .tasks import TaskTemplate | ||||||
| from .utils.download_manager import GenerateMode | ||||||
| from .utils.file_utils import ( | ||||||
| DownloadConfig, | ||||||
|
|
@@ -635,6 +636,7 @@ def load_dataset( | |||||
| save_infos: bool = False, | ||||||
| script_version: Optional[Union[str, Version]] = None, | ||||||
| use_auth_token: Optional[Union[bool, str]] = None, | ||||||
| task: Optional[Union[str, TaskTemplate]] = None, | ||||||
| **config_kwargs, | ||||||
| ) -> Union[DatasetDict, Dataset]: | ||||||
| """Load a dataset. | ||||||
|
|
@@ -694,6 +696,7 @@ def load_dataset( | |||||
| You can specify a different version that the default "main" by using a commit sha or a git tag of the dataset repository. | ||||||
| use_auth_token (``str`` or ``bool``, optional): Optional string or boolean to use as Bearer token for remote files on the Datasets Hub. | ||||||
| If True, will get token from `"~/.huggingface"`. | ||||||
| task (``str``): The task to prepare the dataset for during training and evaluation. Casts the dataset's :class:`Features` according to one of the schemas in `~tasks`. | ||||||
|
||||||
| task (``str``): The task to prepare the dataset for during training and evaluation. Casts the dataset's :class:`Features` according to one of the schemas in `~tasks`. | |
| task (``str``): The task to prepare the dataset for during training and evaluation. Casts the dataset's :class:`Features` according to standardized column names and types as detailed in `~tasks`. |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,21 @@ | ||
| from typing import Optional | ||
|
|
||
| from .base import TaskTemplate | ||
| from .question_answering import QuestionAnswering | ||
| from .text_classification import TextClassification | ||
|
|
||
|
|
||
| __all__ = ["TaskTemplate", "QuestionAnswering", "TextClassification"] | ||
|
|
||
|
|
||
| NAME2TEMPLATE = {QuestionAnswering.task: QuestionAnswering, TextClassification.task: TextClassification} | ||
|
|
||
|
|
||
| def task_template_from_dict(task_template_dict: dict) -> Optional[TaskTemplate]: | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this function supposed to be used by the user? If so it should have a doc string and doc
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. no this is not supposed to be used by the user (at least for now). still no harm in having a docstring, so i've added:
|
||
| task_name = task_template_dict.get("task") | ||
| if task_name is None: | ||
| return None | ||
| template = NAME2TEMPLATE.get(task_name) | ||
| if template is None: | ||
| return None | ||
| return template.from_dict(task_template_dict) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,26 @@ | ||
| import abc | ||
| from dataclasses import dataclass | ||
| from typing import ClassVar, Dict | ||
|
|
||
| from ..features import Features | ||
|
|
||
|
|
||
| @dataclass(frozen=True) | ||
| class TaskTemplate(abc.ABC): | ||
| task: ClassVar[str] | ||
| input_schema: ClassVar[Features] | ||
| label_schema: ClassVar[Features] | ||
|
|
||
| @property | ||
| def features(self) -> Features: | ||
| return Features(**self.input_schema, **self.label_schema) | ||
|
|
||
| @property | ||
| @abc.abstractmethod | ||
| def column_mapping(self) -> Dict[str, str]: | ||
| return NotImplemented | ||
|
|
||
| @classmethod | ||
| @abc.abstractmethod | ||
| def from_dict(cls, template_dict: dict) -> "TaskTemplate": | ||
| return NotImplemented |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,41 @@ | ||
| from dataclasses import dataclass | ||
| from typing import Dict | ||
|
|
||
| from ..features import Features, Sequence, Value | ||
| from .base import TaskTemplate | ||
|
|
||
|
|
||
| @dataclass(frozen=True) | ||
| class QuestionAnswering(TaskTemplate): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. thanks for the suggestion! i'll keep this unchanged for now since it will soon be overhauled by #2371 |
||
| task = "question-answering" | ||
| input_schema = Features({"question": Value("string"), "context": Value("string")}) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe you want to check with a few QA datasets that this schema make sense. Typically NaturalQuestions, TriviaQA and can be good second datasets to compare to and be sure of the generality of the schema. A good recent list of QA datasets to compare the schemas among, is for instance in the UnitedQA paper: https://arxiv.org/abs/2101.00178
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. thanks for the tip! added to #2371 |
||
| label_schema = Features( | ||
| { | ||
| "answers": Sequence( | ||
| { | ||
| "text": Value("string"), | ||
| "answer_start": Value("int32"), | ||
| } | ||
| ) | ||
| } | ||
| ) | ||
| question_column: str = "question" | ||
| context_column: str = "context" | ||
| answers_column: str = "answers" | ||
|
|
||
| def __post_init__(self): | ||
| object.__setattr__(self, "question_column", self.question_column) | ||
| object.__setattr__(self, "context_column", self.context_column) | ||
| object.__setattr__(self, "answers_column", self.answers_column) | ||
|
|
||
| @property | ||
| def column_mapping(self) -> Dict[str, str]: | ||
| return {self.question_column: "question", self.context_column: "context", self.answers_column: "answers"} | ||
|
|
||
| @classmethod | ||
| def from_dict(cls, template_dict: dict) -> "QuestionAnswering": | ||
| return cls( | ||
| question_column=template_dict["question_column"], | ||
| context_column=template_dict["context_column"], | ||
| answers_column=template_dict["answers_column"], | ||
| ) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,42 @@ | ||
| from dataclasses import dataclass | ||
| from typing import Dict, List | ||
|
|
||
| from ..features import ClassLabel, Features, Value | ||
| from .base import TaskTemplate | ||
|
|
||
|
|
||
| @dataclass(frozen=True) | ||
| class TextClassification(TaskTemplate): | ||
| task = "text-classification" | ||
| input_schema = Features({"text": Value("string")}) | ||
| # TODO(lewtun): Since we update this in __post_init__ do we need to set a default? We'll need it for __init__ so | ||
| # investigate if there's a more elegant approach. | ||
| label_schema = Features({"labels": ClassLabel}) | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: Since we update this in
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. good catch! i'll fix it :)
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ah actually we need to declare a default because |
||
| labels: List[str] | ||
| text_column: str = "text" | ||
| label_column: str = "labels" | ||
|
|
||
| def __post_init__(self): | ||
| assert sorted(set(self.labels)) == sorted(self.labels), "Labels must be unique" | ||
| # Cast labels to tuple to allow hashing | ||
| object.__setattr__(self, "labels", tuple(sorted(self.labels))) | ||
| object.__setattr__(self, "text_column", self.text_column) | ||
| object.__setattr__(self, "label_column", self.label_column) | ||
| self.label_schema["labels"] = ClassLabel(names=self.labels) | ||
| object.__setattr__(self, "label2id", {label: idx for idx, label in enumerate(self.labels)}) | ||
| object.__setattr__(self, "id2label", {idx: label for label, idx in self.label2id.items()}) | ||
|
|
||
| @property | ||
| def column_mapping(self) -> Dict[str, str]: | ||
| return { | ||
| self.text_column: "text", | ||
| self.label_column: "labels", | ||
| } | ||
|
|
||
| @classmethod | ||
| def from_dict(cls, template_dict: dict) -> "TextClassification": | ||
| return cls( | ||
| text_column=template_dict["text_column"], | ||
| label_column=template_dict["label_column"], | ||
| labels=template_dict["labels"], | ||
| ) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's good to explain a bit what are the operations contained in the expression "prepare a dataset"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good idea! i opted for the following: