Merged
28 commits
538f3be
Update labels in DatasetInfo __post_init__
lewtun May 21, 2021
c02a2e4
Add emotion example
lewtun May 21, 2021
2eab30c
Flush task templates before casting
lewtun May 23, 2021
feaca48
Add labels to TextClassification __post_init__
lewtun May 23, 2021
188d02c
Add comment about casting to tuple
lewtun May 23, 2021
1e3e830
Fix capitalisation
lewtun May 23, 2021
635e54d
Refactor tests to account for label update in `DatasetInfo`, add test
lewtun May 23, 2021
a892fde
Merge branch 'master' into refactor-text-clf-template
lewtun May 23, 2021
d85d73d
Merge branch 'master' into refactor-text-clf-template
lewtun May 25, 2021
43f9d55
Update label schema in post_init
lewtun May 25, 2021
5d66b4f
Use __dict__ instead of __setattr__ to update task template labels
lewtun May 25, 2021
e7b1f7a
Raise ValueError if TextClassification template has None or incompati…
lewtun May 25, 2021
6f3ff6d
Remove task templates from emotion demo
lewtun May 25, 2021
1bf0b5b
Add decorator to share docstrings across multiple functions
lewtun May 26, 2021
0dda59e
Update docstring for prepare_for_task
lewtun May 26, 2021
654b2b0
Reorder TextClassification args for better intuition
lewtun May 26, 2021
812bd87
fix missing "task" field in json + edit copy of objects instead of mo…
lhoestq May 26, 2021
159a6f6
style
lhoestq May 26, 2021
a580339
Fix failing tests due to new DatasetInfo.__post_init__
lewtun May 26, 2021
cff9d52
Refactor TextClassification test to cover templates w / w-out labels
lewtun May 27, 2021
8146867
Refactor use of label names in task template concatenation test
lewtun May 27, 2021
fa53dc5
Add separate test for template with labels in DatasetInfo
lewtun May 27, 2021
f78d5c4
Fix log message
lewtun May 27, 2021
514890d
Fix comments
lewtun May 27, 2021
71d5ac8
Merge branch 'master' into refactor-text-clf-template
lewtun May 28, 2021
40e4400
Remove custom feature with lazy classlabel
lewtun May 28, 2021
321e4e6
Move conditional check of features to outer if statement
lewtun May 28, 2021
e79dfe4
Move feature is not None check to inner if-statement
lewtun May 28, 2021
12 changes: 11 additions & 1 deletion src/datasets/arrow_dataset.py
@@ -39,6 +39,8 @@
from multiprocess import Pool, RLock
from tqdm.auto import tqdm

from datasets.tasks.text_classification import TextClassification

from . import config
from .arrow_reader import ArrowReader
from .arrow_writer import ArrowWriter, OptimizedTypedSequence
@@ -1388,7 +1390,7 @@ def with_transform(
def prepare_for_task(self, task: Union[str, TaskTemplate]) -> "Dataset":
"""Prepare a dataset for the given task by casting the dataset's :class:`Features` to standardized column names and types as detailed in :py:mod:`datasets.tasks`.

Casts :attr:`datasets.DatasetInfo.features` according to a task-specific schema.
Casts :attr:`datasets.DatasetInfo.features` according to a task-specific schema. Intended for single-use only, so all task templates are removed from :attr:`datasets.DatasetInfo.task_templates` after casting.

Args:
task (:obj:`Union[str, TaskTemplate]`): The task to prepare the dataset for during training and evaluation. If :obj:`str`, supported tasks include:
@@ -1416,10 +1418,18 @@ def prepare_for_task(self, task: Union[str, TaskTemplate]) -> "Dataset":
raise ValueError(
f"Expected a `str` or `datasets.tasks.TaskTemplate` object but got task {task} with type {type(task)}."
)
if isinstance(template, TextClassification) and self.info.features is not None:
dataset_labels = tuple(sorted(self.info.features[template.label_column].names))
if template.labels is None or template.labels != dataset_labels:
raise ValueError(
f"Incompatible labels between the dataset and task template! Expected labels {dataset_labels} but got {template.labels}. Please ensure that `datasets.tasks.TextClassification.labels` matches the features of the dataset."
)
column_mapping = template.column_mapping
columns_to_drop = [column for column in self.column_names if column not in column_mapping]
dataset = self.remove_columns(columns_to_drop)
dataset = dataset.rename_columns(column_mapping)
# We found a template so now flush `DatasetInfo` to skip the template update in `DatasetInfo.__post_init__`
dataset.info.task_templates = None
dataset = dataset.cast(features=template.features)
return dataset

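To make the new behaviour concrete, here is a minimal usage sketch of `prepare_for_task` (not part of the diff; the dataset, column names, and label names are hypothetical):

from datasets import ClassLabel, Dataset, DatasetInfo, Features, Value
from datasets.tasks.text_classification import TextClassification

# A hypothetical two-column dataset whose label column is a ClassLabel.
features = Features({"review": Value("string"), "sentiment": ClassLabel(names=["neg", "pos"])})
info = DatasetInfo(
    features=features,
    task_templates=[TextClassification(text_column="review", label_column="sentiment")],
)
dset = Dataset.from_dict({"review": ["great movie!"], "sentiment": [1]}, info=info)

# Renames and casts the columns to the standardized text-classification schema;
# the task templates are flushed from the returned dataset's info afterwards.
prepared = dset.prepare_for_task("text-classification")
print(prepared.column_names)         # the standardized "text" and "labels" columns
print(prepared.info.task_templates)  # None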
15 changes: 3 additions & 12 deletions src/datasets/dataset_dict.py
@@ -9,6 +9,8 @@
import fsspec
import numpy as np

from datasets.utils.doc_utils import is_documented_by

from .arrow_dataset import Dataset
from .features import Features
from .filesystems import extract_path_from_uri, is_remote_filesystem
@@ -792,18 +794,7 @@ def from_text(
path_or_paths, features=features, cache_dir=cache_dir, keep_in_memory=keep_in_memory, **kwargs
).read()

@is_documented_by(Dataset.prepare_for_task)
def prepare_for_task(self, task: Union[str, TaskTemplate]):
"""Prepare a dataset for the given task by casting the dataset's :class:`Features` to standardized column names and types as detailed in :py:mod:`datasets.tasks`.

Casts :attr:`datasets.DatasetInfo.features` according to a task-specific schema.

Args:
task (:obj:`Union[str, TaskTemplate]`): The task to prepare the dataset for during training and evaluation. If :obj:`str`, supported tasks include:

- :obj:`"text-classification"`
- :obj:`"question-answering"`

If :obj:`TaskTemplate`, must be one of the task templates in :py:mod:`datasets.tasks`.
"""
self._check_values_type()
return DatasetDict({k: dataset.prepare_for_task(task=task) for k, dataset in self.items()})
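Since `DatasetDict.prepare_for_task` now simply dispatches to `Dataset.prepare_for_task` for every split, a hypothetical call (reusing the `dset` built in the sketch above) would look like:

from datasets import DatasetDict

dsets = DatasetDict({"train": dset, "validation": dset})
prepared = dsets.prepare_for_task("text-classification")  # applied split by split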
11 changes: 10 additions & 1 deletion src/datasets/info.py
@@ -36,8 +36,10 @@
from dataclasses import asdict, dataclass, field
from typing import List, Optional, Union

from datasets.tasks.text_classification import TextClassification

from . import config
from .features import Features, Value
from .features import ClassLabel, Features, Value
from .splits import SplitDict
from .tasks import TaskTemplate, task_template_from_dict
from .utils import Version
@@ -166,6 +168,13 @@ def __post_init__(self):
else:
template = task_template_from_dict(self.task_templates)
self.task_templates = [template] if template is not None else []
# Insert labels and mappings for text classification
for idx, template in enumerate(self.task_templates):
if isinstance(template, TextClassification) and self.features is not None:
# Cast labels to tuple to enable hashing of task template
template.__dict__["labels"] = tuple(sorted(self.features[template.label_column].names))
template.label_schema["labels"] = ClassLabel(names=template.labels)
self.task_templates[idx] = template

def _license_path(self, dataset_info_dir):
return os.path.join(dataset_info_dir, config.LICENSE_FILENAME)
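The effect of the `__post_init__` change above can be sketched as follows (illustrative only; the column and label names are made up):

from datasets import ClassLabel, DatasetInfo, Features, Value
from datasets.tasks.text_classification import TextClassification

info = DatasetInfo(
    features=Features({"text": Value("string"), "labels": ClassLabel(names=["neg", "pos"])}),
    task_templates=[TextClassification(text_column="text", label_column="labels")],
)
# The label names are copied from the features into the template, sorted and
# cast to a tuple so the template remains hashable.
print(info.task_templates[0].labels)  # ('neg', 'pos')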
10 changes: 6 additions & 4 deletions src/datasets/tasks/text_classification.py
@@ -27,14 +27,16 @@ class TextClassification(TaskTemplate):
input_schema: ClassVar[Features] = Features({"text": Value("string")})
# TODO(lewtun): Find a more elegant approach without descriptors.
label_schema: ClassVar[Features] = FeaturesWithLazyClassLabel(Features({"labels": ClassLabel}))
labels: List[str]
text_column: str = "text"
label_column: str = "labels"
labels: List[str] = None

def __post_init__(self):
assert len(self.labels) == len(set(self.labels)), "Labels must be unique"
# Cast labels to tuple to allow hashing
self.__dict__["labels"] = tuple(sorted(self.labels))
if self.labels:
assert len(self.labels) == len(set(self.labels)), "Labels must be unique"
# Cast labels to tuple to allow hashing
self.__dict__["labels"] = tuple(sorted(self.labels))
self.label_schema["labels"] = ClassLabel(names=self.labels)

@property
def column_mapping(self) -> Dict[str, str]:
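With `labels` now optional, a template can be declared with or without them; a small sketch (hypothetical column names) of both cases:

from datasets.tasks.text_classification import TextClassification

# Without labels: they are filled in later from DatasetInfo.features.
template = TextClassification(text_column="review", label_column="rating")

# With labels: __post_init__ checks uniqueness and normalizes them to a sorted tuple.
template_with_labels = TextClassification(
    text_column="review", label_column="rating", labels=["pos", "neg"]
)
print(template_with_labels.labels)  # ('neg', 'pos')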
15 changes: 15 additions & 0 deletions src/datasets/utils/doc_utils.py
@@ -0,0 +1,15 @@
from typing import Callable


def is_documented_by(function_with_docstring: Callable):
    """Decorator to share docstrings across common functions.

    Args:
        function_with_docstring (`Callable`): The function whose docstring will be shared.
    """

    def wrapper(target_function):
        target_function.__doc__ = function_with_docstring.__doc__
        return target_function

    return wrapper
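A hypothetical use of the new decorator, mirroring how `DatasetDict.prepare_for_task` borrows the docstring of `Dataset.prepare_for_task`:

from datasets.utils.doc_utils import is_documented_by

def documented_source():
    """Shared docstring exposed by both functions."""

@is_documented_by(documented_source)
def documented_target():
    ...

print(documented_target.__doc__)  # "Shared docstring exposed by both functions."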
157 changes: 135 additions & 22 deletions tests/test_arrow_dataset.py
@@ -702,8 +702,11 @@ def test_concatenate_with_no_task_templates(self, in_memory):
self.assertEqual(dset_concat.info.task_templates, None)

def test_concatenate_with_equal_task_templates(self, in_memory):
task_template = TextClassification(text_column="text", label_column="labels", labels=["pos", "neg"])
info = DatasetInfo(task_templates=task_template)
task_template = TextClassification(text_column="text", label_column="labels")
info = DatasetInfo(
features=Features({"text": Value("string"), "labels": ClassLabel(names=["pos", "neg"])}),
task_templates=task_template,
)
data = {"text": ["i love transformers!"], "labels": [1]}
with tempfile.TemporaryDirectory() as tmp_dir:
with Dataset.from_dict(data, info=info) as dset1, Dataset.from_dict(
@@ -716,10 +719,42 @@ def test_concatenate_with_equal_task_templates(self, in_memory):
self.assertListEqual(dset_concat.info.task_templates, [task_template])

def test_concatenate_with_mixed_task_templates_in_common(self, in_memory):
tc_template = TextClassification(text_column="text", label_column="labels", labels=["pos", "neg"])
tc_template = TextClassification(text_column="text", label_column="labels")
qa_template = QuestionAnswering(question_column="question", context_column="context", answers_column="answers")
info1 = DatasetInfo(task_templates=[qa_template])
info2 = DatasetInfo(task_templates=[qa_template, tc_template])
info1 = DatasetInfo(
task_templates=[qa_template],
features=Features(
{
"text": Value("string"),
"labels": ClassLabel(names=["pos", "neg"]),
"context": Value("string"),
"question": Value("string"),
"answers": Sequence(
{
"text": Value("string"),
"answer_start": Value("int32"),
}
),
}
),
)
info2 = DatasetInfo(
task_templates=[qa_template, tc_template],
features=Features(
{
"text": Value("string"),
"labels": ClassLabel(names=["pos", "neg"]),
"context": Value("string"),
"question": Value("string"),
"answers": Sequence(
{
"text": Value("string"),
"answer_start": Value("int32"),
}
),
}
),
)
data = {
"text": ["i love transformers!"],
"labels": [1],
@@ -738,15 +773,67 @@ def test_concatenate_with_mixed_task_templates_in_common(self, in_memory):
self.assertListEqual(dset_concat.info.task_templates, [qa_template])

def test_concatenate_with_no_mixed_task_templates_in_common(self, in_memory):
tc_template1 = TextClassification(text_column="text", label_column="labels", labels=["pos", "neg"])
tc_template2 = TextClassification(text_column="text", label_column="labels", labels=["pos", "neg", "neutral"])
tc_template1 = TextClassification(text_column="text", label_column="labels")
tc_template2 = TextClassification(text_column="text", label_column="sentiment")
qa_template = QuestionAnswering(question_column="question", context_column="context", answers_column="answers")
info1 = DatasetInfo(task_templates=[tc_template1])
info2 = DatasetInfo(task_templates=[tc_template2])
info3 = DatasetInfo(task_templates=[qa_template])
info1 = DatasetInfo(
features=Features(
{
"text": Value("string"),
"labels": ClassLabel(names=["pos", "neg"]),
"sentiment": ClassLabel(names=["pos", "neg", "neutral"]),
"context": Value("string"),
"question": Value("string"),
"answers": Sequence(
{
"text": Value("string"),
"answer_start": Value("int32"),
}
),
}
),
task_templates=[tc_template1],
)
info2 = DatasetInfo(
features=Features(
{
"text": Value("string"),
"labels": ClassLabel(names=["pos", "neg"]),
"sentiment": ClassLabel(names=["pos", "neg", "neutral"]),
"context": Value("string"),
"question": Value("string"),
"answers": Sequence(
{
"text": Value("string"),
"answer_start": Value("int32"),
}
),
}
),
task_templates=[tc_template2],
)
info3 = DatasetInfo(
features=Features(
{
"text": Value("string"),
"labels": ClassLabel(names=["pos", "neg"]),
"sentiment": ClassLabel(names=["pos", "neg", "neutral"]),
"context": Value("string"),
"question": Value("string"),
"answers": Sequence(
{
"text": Value("string"),
"answer_start": Value("int32"),
}
),
}
),
task_templates=[qa_template],
)
data = {
"text": ["i love transformers!"],
"labels": [1],
"sentiment": [0],
"context": ["huggingface is going to the moon!"],
"question": ["where is huggingface going?"],
"answers": [{"text": ["to the moon!"], "answer_start": [2]}],
@@ -1929,7 +2016,7 @@ def test_task_text_classification(self, in_memory):
features_before_cast = Features(
{
"input_text": Value("string"),
"input_labels": Value("int32"),
"input_labels": ClassLabel(names=labels),
}
)
# Labels are cast to tuple during TextClassification init, so we do the same here
@@ -1939,7 +2026,7 @@
"labels": ClassLabel(names=tuple(labels)),
}
)
task = TextClassification(text_column="input_text", label_column="input_labels", labels=labels)
task = TextClassification(text_column="input_text", label_column="input_labels")
info = DatasetInfo(
features=features_before_cast,
task_templates=task,
@@ -2030,25 +2117,51 @@ def test_task_with_incompatible_templates(self, in_memory):
features = Features(
{
"input_text": Value("string"),
"input_labels": Value("int32"),
"input_labels": ClassLabel(names=labels),
}
)
task = TextClassification(text_column="input_text", label_column="input_labels", labels=labels)
task = TextClassification(text_column="input_text", label_column="input_labels")
info = DatasetInfo(
features=features,
task_templates=task,
)
data = {"input_text": ["i love transformers!"], "input_labels": [1]}
with tempfile.TemporaryDirectory() as tmp_dir, Dataset.from_dict(data, info=info) as dset:
with self._to(in_memory, tmp_dir, dset) as dset:
with self.assertRaises(ValueError):
# Invalid task name
dset.prepare_for_task("this-task-does-not-exist")
# Duplicate task templates
dset.info.task_templates = [task, task]
dset.prepare_for_task("text-classification")
# Invalid task type
dset.prepare_for_task(1)
# Invalid task name
self.assertRaises(ValueError, dset.prepare_for_task, "this-task-does-not-exist")
# Invalid task templates with incompatible labels
task_with_wrong_labels = TextClassification(
text_column="input_text", label_column="input_labels", labels=["neut"]
)
self.assertRaises(ValueError, dset.prepare_for_task, task_with_wrong_labels)
task_with_no_labels = TextClassification(
text_column="input_text", label_column="input_labels", labels=None
)
self.assertRaises(ValueError, dset.prepare_for_task, task_with_no_labels)
# Duplicate task templates
dset.info.task_templates = [task, task]
self.assertRaises(ValueError, dset.prepare_for_task, "text-classification")
# Invalid task type
self.assertRaises(ValueError, dset.prepare_for_task, 1)

def test_task_templates_empty_after_preparation(self, in_memory):
features = Features(
{
"input_text": Value("string"),
"input_labels": ClassLabel(names=["pos", "neg"]),
}
)
task = TextClassification(text_column="input_text", label_column="input_labels")
info = DatasetInfo(
features=features,
task_templates=task,
)
data = {"input_text": ["i love transformers!"], "input_labels": [1]}
with tempfile.TemporaryDirectory() as tmp_dir, Dataset.from_dict(data, info=info) as dset:
with self._to(in_memory, tmp_dir, dset) as dset:
with dset.prepare_for_task(task="text-classification") as dset:
self.assertIsNone(dset.info.task_templates)


class MiscellaneousDatasetTest(TestCase):
2 changes: 1 addition & 1 deletion tests/test_tasks.py
@@ -14,7 +14,7 @@ def test_column_mapping(self):

def test_from_dict(self):
input_schema = Features({"text": Value("string")})
# Labels are cast to tuple during TextClassification init, so we do the same here
# Labels are cast to tuple during TextClassification __post_init__, so we do the same here
label_schema = Features({"labels": ClassLabel(names=tuple(self.labels))})
template_dict = {"text_column": "input_text", "label_column": "input_labels", "labels": self.labels}
task = TextClassification.from_dict(template_dict)