Merged

28 commits
538f3be
Update labels in DatasetInfo __post_init__
lewtun May 21, 2021
c02a2e4
Add emotion example
lewtun May 21, 2021
2eab30c
Flush task templates before casting
lewtun May 23, 2021
feaca48
Add labels to TextClassification __post_init__
lewtun May 23, 2021
188d02c
Add comment about casting to tuple
lewtun May 23, 2021
1e3e830
Fix capitalisation
lewtun May 23, 2021
635e54d
Refactor tests to account for label update in `DatasetInfo`, add test
lewtun May 23, 2021
a892fde
Merge branch 'master' into refactor-text-clf-template
lewtun May 23, 2021
d85d73d
Merge branch 'master' into refactor-text-clf-template
lewtun May 25, 2021
43f9d55
Update label schema in post_init
lewtun May 25, 2021
5d66b4f
Use __dict__ instead of __setattr__ to update task template labels
lewtun May 25, 2021
e7b1f7a
Raise ValueError if TextClassification template has None or incompati…
lewtun May 25, 2021
6f3ff6d
Remove task templates from emotion demo
lewtun May 25, 2021
1bf0b5b
Add decorator to share docstrings across multiple functions
lewtun May 26, 2021
0dda59e
Update docstring for prepare_for_task
lewtun May 26, 2021
654b2b0
Reorder TextClassification args for better intuition
lewtun May 26, 2021
812bd87
fix missing "task" field in json + edit copy of objects instead of mo…
lhoestq May 26, 2021
159a6f6
style
lhoestq May 26, 2021
a580339
Fix failing tests due to new DatasetInfo.__post_init__
lewtun May 26, 2021
cff9d52
Refactor TextClassification test to cover templates w / w-out labels
lewtun May 27, 2021
8146867
Refactor use of label names in task template concatenation test
lewtun May 27, 2021
fa53dc5
Add separate test for template with labels in DatasetInfo
lewtun May 27, 2021
f78d5c4
Fix log message
lewtun May 27, 2021
514890d
Fix comments
lewtun May 27, 2021
71d5ac8
Merge branch 'master' into refactor-text-clf-template
lewtun May 28, 2021
40e4400
Remove custom feature with lazy classlabel
lewtun May 28, 2021
321e4e6
Move conditional check of features to outer if statement
lewtun May 28, 2021
e79dfe4
Move feature is not None check to inner if-statement
lewtun May 28, 2021
12 changes: 11 additions & 1 deletion src/datasets/arrow_dataset.py
@@ -39,6 +39,8 @@
from multiprocess import Pool, RLock
from tqdm.auto import tqdm

from datasets.tasks.text_classification import TextClassification

from . import config
from .arrow_reader import ArrowReader
from .arrow_writer import ArrowWriter, OptimizedTypedSequence
@@ -1391,7 +1393,7 @@ def with_transform(
def prepare_for_task(self, task: Union[str, TaskTemplate]) -> "Dataset":
"""Prepare a dataset for the given task by casting the dataset's :class:`Features` to standardized column names and types as detailed in :py:mod:`datasets.tasks`.

Casts :attr:`datasets.DatasetInfo.features` according to a task-specific schema.
Casts :attr:`datasets.DatasetInfo.features` according to a task-specific schema. Intended for single-use only, so all task templates are removed from :attr:`datasets.DatasetInfo.task_templates` after casting.

Args:
task (:obj:`Union[str, TaskTemplate]`): The task to prepare the dataset for during training and evaluation. If :obj:`str`, supported tasks include:
@@ -1419,10 +1421,18 @@ def prepare_for_task(self, task: Union[str, TaskTemplate]) -> "Dataset":
raise ValueError(
f"Expected a `str` or `datasets.tasks.TaskTemplate` object but got task {task} with type {type(task)}."
)
if isinstance(template, TextClassification) and self.info.features is not None:
dataset_labels = tuple(sorted(self.info.features[template.label_column].names))
if template.labels is None or template.labels != dataset_labels:
raise ValueError(
f"Incompatible labels between the dataset and task template! Expected labels {dataset_labels} but got {template.labels}. Please ensure that `datasets.tasks.TextClassification.labels` matches the features of the dataset."
)
column_mapping = template.column_mapping
columns_to_drop = [column for column in self.column_names if column not in column_mapping]
dataset = self.remove_columns(columns_to_drop)
dataset = dataset.rename_columns(column_mapping)
# We found a template so now flush `DatasetInfo` to skip the template update in `DatasetInfo.__post_init__`
dataset.info.task_templates = None
dataset = dataset.cast(features=template.features)
return dataset

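For context (not part of the diff), a minimal sketch of how the updated `prepare_for_task` is expected to behave with the new label-compatibility check. The toy dataset and column names below are illustrative assumptions, not code from this PR:

```python
from datasets import ClassLabel, Dataset, Features, Value
from datasets.tasks import TextClassification

# Toy dataset standing in for a real one; column names are illustrative.
features = Features({"tweet": Value("string"), "sentiment": ClassLabel(names=["neg", "pos"])})
ds = Dataset.from_dict({"tweet": ["great!", "awful..."], "sentiment": [1, 0]}, features=features)

# Compatible labels: the dataset is cast to the standardized ("text", "labels") schema,
# extra columns are dropped, and the task templates are flushed so the cast is one-shot.
template = TextClassification(text_column="tweet", label_column="sentiment", labels=["pos", "neg"])
prepared = ds.prepare_for_task(template)
print(prepared.column_names)  # ['text', 'labels']

# Labels that do not match the dataset's ClassLabel names should now fail fast.
bad = TextClassification(text_column="tweet", label_column="sentiment", labels=["yes", "no"])
ds.prepare_for_task(bad)  # ValueError: Incompatible labels between the dataset and task template! ...
```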
15 changes: 3 additions & 12 deletions src/datasets/dataset_dict.py
@@ -9,6 +9,8 @@
import fsspec
import numpy as np

from datasets.utils.doc_utils import is_documented_by

from .arrow_dataset import Dataset
from .features import Features
from .filesystems import extract_path_from_uri, is_remote_filesystem
@@ -793,18 +795,7 @@ def from_text(
path_or_paths, features=features, cache_dir=cache_dir, keep_in_memory=keep_in_memory, **kwargs
).read()

@is_documented_by(Dataset.prepare_for_task)
def prepare_for_task(self, task: Union[str, TaskTemplate]):
"""Prepare a dataset for the given task by casting the dataset's :class:`Features` to standardized column names and types as detailed in :py:mod:`datasets.tasks`.

Casts :attr:`datasets.DatasetInfo.features` according to a task-specific schema.

Args:
task (:obj:`Union[str, TaskTemplate]`): The task to prepare the dataset for during training and evaluation. If :obj:`str`, supported tasks include:

- :obj:`"text-classification"`
- :obj:`"question-answering"`

If :obj:`TaskTemplate`, must be one of the task templates in :py:mod:`datasets.tasks`.
"""
self._check_values_type()
return DatasetDict({k: dataset.prepare_for_task(task=task) for k, dataset in self.items()})
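As a quick sanity check (illustrative, not in the diff), the docstring sharing introduced by `@is_documented_by` can be verified directly:

```python
from datasets import Dataset, DatasetDict

# DatasetDict.prepare_for_task simply maps Dataset.prepare_for_task over every split,
# and now borrows its docstring instead of keeping a duplicated copy.
assert DatasetDict.prepare_for_task.__doc__ == Dataset.prepare_for_task.__doc__
```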
14 changes: 14 additions & 0 deletions src/datasets/info.py
@@ -36,6 +36,8 @@
from dataclasses import asdict, dataclass, field
from typing import List, Optional, Union

from datasets.tasks.text_classification import TextClassification

from . import config
from .features import Features, Value
from .splits import SplitDict
@@ -154,6 +156,7 @@ def __post_init__(self):
else:
self.supervised_keys = SupervisedKeysData(**self.supervised_keys)

# Parse and make a list of templates
if self.task_templates is not None:
if isinstance(self.task_templates, (list, tuple)):
templates = [
@@ -167,6 +170,17 @@
template = task_template_from_dict(self.task_templates)
self.task_templates = [template] if template is not None else []

# Insert labels and mappings for text classification
if self.task_templates is not None:
self.task_templates = list(self.task_templates)
if self.features is not None:
for idx, template in enumerate(self.task_templates):
if isinstance(template, TextClassification):
labels = self.features[template.label_column].names
self.task_templates[idx] = TextClassification(
text_column=template.text_column, label_column=template.label_column, labels=labels
)

def _license_path(self, dataset_info_dir):
return os.path.join(dataset_info_dir, config.LICENSE_FILENAME)

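A sketch (not part of the diff) of how the new `__post_init__` step is expected to back-fill template labels from the features; the feature and column names are assumptions for illustration:

```python
from datasets import ClassLabel, DatasetInfo, Features, Value
from datasets.tasks import TextClassification

features = Features({"text": Value("string"), "label": ClassLabel(names=["neg", "pos"])})
info = DatasetInfo(
    features=features,
    task_templates=[TextClassification(text_column="text", label_column="label")],
)

# __post_init__ rebuilds the TextClassification template with the labels read
# from the ClassLabel feature, so downstream code can rely on label2id/id2label.
template = info.task_templates[0]
print(template.labels)    # ('neg', 'pos')
print(template.label2id)  # {'neg': 0, 'pos': 1}
```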
6 changes: 4 additions & 2 deletions src/datasets/tasks/__init__.py
@@ -1,12 +1,15 @@
from typing import Optional

from ..utils.logging import get_logger
from .base import TaskTemplate
from .question_answering import QuestionAnswering
from .text_classification import TextClassification


__all__ = ["TaskTemplate", "QuestionAnswering", "TextClassification"]

logger = get_logger(__name__)


NAME2TEMPLATE = {QuestionAnswering.task: QuestionAnswering, TextClassification.task: TextClassification}

@@ -15,8 +18,7 @@ def task_template_from_dict(task_template_dict: dict) -> Optional[TaskTemplate]:
"""Create one of the supported task templates in :py:mod:`datasets.tasks` from a dictionary."""
task_name = task_template_dict.get("task")
if task_name is None:
logger.warning(f"Couldn't find template for task '{task_name}'. Available templates: {list(NAME2TEMPLATE)}")
return None
template = NAME2TEMPLATE.get(task_name)
if template is None:
return None
return template.from_dict(task_template_dict)
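A short usage sketch (illustrative) of `task_template_from_dict`; the exact keys of a serialized template dict are assumptions here:

```python
from datasets.tasks import task_template_from_dict

# A recognised task name reconstructs the concrete template from its dict form.
template = task_template_from_dict(
    {"task": "text-classification", "text_column": "text", "label_column": "labels"}
)
print(type(template).__name__)  # TextClassification

# A missing or unsupported task name yields None (with a warning) so callers can skip it.
assert task_template_from_dict({"task": "not-a-real-task"}) is None
```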
3 changes: 2 additions & 1 deletion src/datasets/tasks/base.py
@@ -11,7 +11,8 @@

@dataclass(frozen=True)
class TaskTemplate(abc.ABC):
task: ClassVar[str]
# `task` is not a ClassVar since we want it to be part of the `asdict` output for JSON serialization
task: str
input_schema: ClassVar[Features]
label_schema: ClassVar[Features]

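The reason `task` is demoted from a `ClassVar` to a regular field is that `dataclasses.asdict` only serializes instance fields, so the `"task"` key would otherwise be missing from the JSON metadata. A quick sketch using the text-classification template:

```python
from dataclasses import asdict

from datasets.tasks import TextClassification

template = TextClassification(text_column="text", label_column="label")
# With `task` as a regular field it survives the round trip to a dict (and hence JSON).
print(asdict(template))
# {'task': 'text-classification', 'text_column': 'text', 'label_column': 'label', 'labels': None}
```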
3 changes: 2 additions & 1 deletion src/datasets/tasks/question_answering.py
@@ -7,7 +7,8 @@

@dataclass(frozen=True)
class QuestionAnswering(TaskTemplate):
task: ClassVar[str] = "question-answering"
# `task` is not a ClassVar since we want it to be part of the `asdict` output for JSON serialization
task: str = "question-answering"
input_schema: ClassVar[Features] = Features({"question": Value("string"), "context": Value("string")})
label_schema: ClassVar[Features] = Features(
{
42 changes: 11 additions & 31 deletions src/datasets/tasks/text_classification.py
@@ -1,52 +1,32 @@
from dataclasses import dataclass
from typing import ClassVar, Dict, List
from typing import ClassVar, Dict, Optional, Tuple

from ..features import ClassLabel, Features, Value
from .base import TaskTemplate


class FeaturesWithLazyClassLabel:
def __init__(self, features, label_column="labels"):
assert label_column in features, f"Key '{label_column}' missing in features {features}"
self._features = features
self._label_column = label_column

def __get__(self, obj, objtype=None):
if obj is None:
return self._features

assert hasattr(obj, self._label_column), f"Object has no attribute '{self._label_column}'"
features = self._features.copy()
features["labels"] = ClassLabel(names=getattr(obj, self._label_column))
return features


@dataclass(frozen=True)
class TextClassification(TaskTemplate):
task: ClassVar[str] = "text-classification"
# `task` is not a ClassVar since we want it to be part of the `asdict` output for JSON serialization
task: str = "text-classification"
input_schema: ClassVar[Features] = Features({"text": Value("string")})
# TODO(lewtun): Find a more elegant approach without descriptors.
label_schema: ClassVar[Features] = FeaturesWithLazyClassLabel(Features({"labels": ClassLabel}))
labels: List[str]
label_schema: ClassVar[Features] = Features({"labels": ClassLabel})
text_column: str = "text"
label_column: str = "labels"
labels: Optional[Tuple[str]] = None

def __post_init__(self):
assert len(self.labels) == len(set(self.labels)), "Labels must be unique"
# Cast labels to tuple to allow hashing
self.__dict__["labels"] = tuple(sorted(self.labels))
if self.labels:
assert len(self.labels) == len(set(self.labels)), "Labels must be unique"
# Cast labels to tuple to allow hashing
self.__dict__["labels"] = tuple(sorted(self.labels))
self.__dict__["label_schema"] = self.label_schema.copy()
self.label_schema["labels"] = ClassLabel(names=self.labels)

@property
def column_mapping(self) -> Dict[str, str]:
return {
self.text_column: "text",
self.label_column: "labels",
}

@property
def label2id(self):
return {label: idx for idx, label in enumerate(self.labels)}

@property
def id2label(self):
return {idx: label for idx, label in enumerate(self.labels)}
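A hedged sketch (not from the diff) of the refactored `TextClassification` template in isolation, showing the optional-labels behavior; the column names are illustrative:

```python
from datasets.tasks import TextClassification

# Without labels: the template stays hashable and serializable, and the labels are
# expected to be filled in later from DatasetInfo.features.
bare = TextClassification(text_column="tweet", label_column="sentiment")
print(bare.labels)          # None
print(bare.column_mapping)  # {'tweet': 'text', 'sentiment': 'labels'}

# With labels: they must be unique (asserted), are sorted and cast to a tuple for
# hashing, and the label schema's ClassLabel is rebuilt from them.
labelled = TextClassification(text_column="tweet", label_column="sentiment", labels=["pos", "neg"])
print(labelled.labels)                              # ('neg', 'pos')
print(labelled.label2id)                            # {'neg': 0, 'pos': 1}
print(labelled.label_schema["labels"].num_classes)  # 2
```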
15 changes: 15 additions & 0 deletions src/datasets/utils/doc_utils.py
@@ -0,0 +1,15 @@
from typing import Callable


def is_documented_by(function_with_docstring: Callable):
"""Decorator to share docstrings across common functions.

Args:
function_with_docstring (`Callable`): The function whose docstring should be shared with the decorated function.
"""

def wrapper(target_function):
target_function.__doc__ = function_with_docstring.__doc__
return target_function

return wrapper
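Usage of the new decorator is straightforward; a small self-contained example (function names here are illustrative, not from the PR):

```python
from datasets.utils.doc_utils import is_documented_by


def add(a, b):
    """Return the sum of two numbers."""
    return a + b


# `add_many` reuses `add`'s docstring instead of duplicating it, mirroring how
# DatasetDict.prepare_for_task now reuses the docstring of Dataset.prepare_for_task.
@is_documented_by(add)
def add_many(*args):
    return sum(args)


print(add_many.__doc__)  # Return the sum of two numbers.
```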