Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
538f3be
Update labels in DatasetInfo __post_init__
lewtun May 21, 2021
c02a2e4
Add emotion example
lewtun May 21, 2021
2eab30c
Flush task templates before casting
lewtun May 23, 2021
feaca48
Add labels to TextClassification __post_init__
lewtun May 23, 2021
188d02c
Add comment about casting to tuple
lewtun May 23, 2021
1e3e830
Fix capitalisation
lewtun May 23, 2021
635e54d
Refactor tests to account for label update in `DatasetInfo`, add test
lewtun May 23, 2021
a892fde
Merge branch 'master' into refactor-text-clf-template
lewtun May 23, 2021
d85d73d
Merge branch 'master' into refactor-text-clf-template
lewtun May 25, 2021
43f9d55
Update label schema in post_init
lewtun May 25, 2021
5d66b4f
Use __dict__ instead of __setattr__ to update task template labels
lewtun May 25, 2021
e7b1f7a
Raise ValueError if TextClassification template has None or incompati…
lewtun May 25, 2021
6f3ff6d
Remove task templates from emotion demo
lewtun May 25, 2021
1bf0b5b
Add decorator to share docstrings across multiple functions
lewtun May 26, 2021
0dda59e
Update docstring for prepare_for_task
lewtun May 26, 2021
654b2b0
Reorder TextClassification args for better intuition
lewtun May 26, 2021
812bd87
fix missing "task" field in json + edit copy of objects instead of mo…
lhoestq May 26, 2021
159a6f6
style
lhoestq May 26, 2021
a580339
Fix failing tests due to new DatasetInfo.__post_init__
lewtun May 26, 2021
cff9d52
Refactor TextClassification test to cover templates w / w-out labels
lewtun May 27, 2021
8146867
Refactor use of label names in task template concatenation test
lewtun May 27, 2021
fa53dc5
Add separate test for template with labels in DatasetInfo
lewtun May 27, 2021
f78d5c4
Fix log message
lewtun May 27, 2021
514890d
Fix comments
lewtun May 27, 2021
71d5ac8
Merge branch 'master' into refactor-text-clf-template
lewtun May 28, 2021
40e4400
Remove custom feature with lazy classlabel
lewtun May 28, 2021
321e4e6
Move conditional check of features to outer if statement
lewtun May 28, 2021
e79dfe4
Move feature is not None check to inner if-statement
lewtun May 28, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions datasets/emotion/emotion.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import csv

import datasets
from datasets.tasks.text_classification import TextClassification


_CITATION = """\
Expand Down Expand Up @@ -44,6 +45,12 @@ def _info(self):
supervised_keys=("text", "label"),
homepage=_URL,
citation=_CITATION,
task_templates=[
TextClassification(
text_column="text",
label_column="label",
)
],
)

def _split_generators(self, dl_manager):
Expand Down
2 changes: 2 additions & 0 deletions src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1420,6 +1420,8 @@ def prepare_for_task(self, task: Union[str, TaskTemplate]) -> "Dataset":
columns_to_drop = [column for column in self.column_names if column not in column_mapping]
dataset = self.remove_columns(columns_to_drop)
dataset = dataset.rename_columns(column_mapping)
# We found a template so now flush `DatasetInfo` to skip the template update in `DatasetInfo.__post_init__`
dataset.info.task_templates = None
dataset = dataset.cast(features=template.features)
return dataset

Expand Down
11 changes: 10 additions & 1 deletion src/datasets/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,10 @@
from dataclasses import asdict, dataclass, field
from typing import List, Optional, Union

from datasets.tasks.text_classification import TextClassification

from . import config
from .features import Features, Value
from .features import ClassLabel, Features, Value
from .splits import SplitDict
from .tasks import TaskTemplate, task_template_from_dict
from .utils import Version
Expand Down Expand Up @@ -166,6 +168,13 @@ def __post_init__(self):
else:
template = task_template_from_dict(self.task_templates)
self.task_templates = [template] if template is not None else []
# Insert labels and mappings for text classification
for idx, template in enumerate(self.task_templates):
if isinstance(template, TextClassification) and self.features is not None:
# Cast labels to tuple to enable hashing of task template
object.__setattr__(template, "labels", tuple(sorted(self.features[template.label_column].names)))
template.label_schema["labels"] = ClassLabel(names=template.labels)
self.task_templates[idx] = template

def _license_path(self, dataset_info_dir):
return os.path.join(dataset_info_dir, config.LICENSE_FILENAME)
Expand Down
19 changes: 12 additions & 7 deletions src/datasets/tasks/text_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,16 @@ class TextClassification(TaskTemplate):
# TODO(lewtun): Since we update this in __post_init__ do we need to set a default? We'll need it for __init__ so
# investigate if there's a more elegant approach.
label_schema = Features({"labels": ClassLabel})
labels: List[str]
labels: List[str] = None
text_column: str = "text"
label_column: str = "labels"

def __post_init__(self):
assert sorted(set(self.labels)) == sorted(self.labels), "Labels must be unique"
# Cast labels to tuple to allow hashing
object.__setattr__(self, "labels", tuple(sorted(self.labels)))
object.__setattr__(self, "text_column", self.text_column)
object.__setattr__(self, "label_column", self.label_column)
self.label_schema["labels"] = ClassLabel(names=self.labels)
object.__setattr__(self, "label2id", {label: idx for idx, label in enumerate(self.labels)})
object.__setattr__(self, "id2label", {idx: label for label, idx in self.label2id.items()})
if self.labels:
object.__setattr__(self, "labels", tuple(sorted(self.labels)))
self.label_schema["labels"] = ClassLabel(names=self.labels)

@property
def column_mapping(self) -> Dict[str, str]:
Expand All @@ -40,3 +37,11 @@ def from_dict(cls, template_dict: dict) -> "TextClassification":
label_column=template_dict["label_column"],
labels=template_dict["labels"],
)

@property
def label2id(self):
return {label: idx for idx, label in enumerate(self.labels)}

@property
def id2label(self):
return {idx: label for idx, label in enumerate(self.labels)}
133 changes: 119 additions & 14 deletions tests/test_arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,8 +702,11 @@ def test_concatenate_with_no_task_templates(self, in_memory):
self.assertEqual(dset_concat.info.task_templates, None)

def test_concatenate_with_equal_task_templates(self, in_memory):
task_template = TextClassification(text_column="text", label_column="labels", labels=["pos", "neg"])
info = DatasetInfo(task_templates=task_template)
task_template = TextClassification(text_column="text", label_column="labels")
info = DatasetInfo(
features=Features({"text": Value("string"), "labels": ClassLabel(names=["pos", "neg"])}),
task_templates=task_template,
)
data = {"text": ["i love transformers!"], "labels": [1]}
with tempfile.TemporaryDirectory() as tmp_dir:
with Dataset.from_dict(data, info=info) as dset1, Dataset.from_dict(
Expand All @@ -716,10 +719,42 @@ def test_concatenate_with_equal_task_templates(self, in_memory):
self.assertListEqual(dset_concat.info.task_templates, [task_template])

def test_concatenate_with_mixed_task_templates_in_common(self, in_memory):
tc_template = TextClassification(text_column="text", label_column="labels", labels=["pos", "neg"])
tc_template = TextClassification(text_column="text", label_column="labels")
qa_template = QuestionAnswering(question_column="question", context_column="context", answers_column="answers")
info1 = DatasetInfo(task_templates=[qa_template])
info2 = DatasetInfo(task_templates=[qa_template, tc_template])
info1 = DatasetInfo(
task_templates=[qa_template],
features=Features(
{
"text": Value("string"),
"labels": ClassLabel(names=["pos", "neg"]),
"context": Value("string"),
"question": Value("string"),
"answers": Sequence(
{
"text": Value("string"),
"answer_start": Value("int32"),
}
),
}
),
)
info2 = DatasetInfo(
task_templates=[qa_template, tc_template],
features=Features(
{
"text": Value("string"),
"labels": ClassLabel(names=["pos", "neg"]),
"context": Value("string"),
"question": Value("string"),
"answers": Sequence(
{
"text": Value("string"),
"answer_start": Value("int32"),
}
),
}
),
)
data = {
"text": ["i love transformers!"],
"labels": [1],
Expand All @@ -738,15 +773,67 @@ def test_concatenate_with_mixed_task_templates_in_common(self, in_memory):
self.assertListEqual(dset_concat.info.task_templates, [qa_template])

def test_concatenate_with_no_mixed_task_templates_in_common(self, in_memory):
tc_template1 = TextClassification(text_column="text", label_column="labels", labels=["pos", "neg"])
tc_template2 = TextClassification(text_column="text", label_column="labels", labels=["pos", "neg", "neutral"])
tc_template1 = TextClassification(text_column="text", label_column="labels")
tc_template2 = TextClassification(text_column="text", label_column="sentiment")
qa_template = QuestionAnswering(question_column="question", context_column="context", answers_column="answers")
info1 = DatasetInfo(task_templates=[tc_template1])
info2 = DatasetInfo(task_templates=[tc_template2])
info3 = DatasetInfo(task_templates=[qa_template])
info1 = DatasetInfo(
features=Features(
{
"text": Value("string"),
"labels": ClassLabel(names=["pos", "neg"]),
"sentiment": ClassLabel(names=["pos", "neg", "neutral"]),
"context": Value("string"),
"question": Value("string"),
"answers": Sequence(
{
"text": Value("string"),
"answer_start": Value("int32"),
}
),
}
),
task_templates=[tc_template1],
)
info2 = DatasetInfo(
features=Features(
{
"text": Value("string"),
"labels": ClassLabel(names=["pos", "neg"]),
"sentiment": ClassLabel(names=["pos", "neg", "neutral"]),
"context": Value("string"),
"question": Value("string"),
"answers": Sequence(
{
"text": Value("string"),
"answer_start": Value("int32"),
}
),
}
),
task_templates=[tc_template2],
)
info3 = DatasetInfo(
features=Features(
{
"text": Value("string"),
"labels": ClassLabel(names=["pos", "neg"]),
"sentiment": ClassLabel(names=["pos", "neg", "neutral"]),
"context": Value("string"),
"question": Value("string"),
"answers": Sequence(
{
"text": Value("string"),
"answer_start": Value("int32"),
}
),
}
),
task_templates=[qa_template],
)
data = {
"text": ["i love transformers!"],
"labels": [1],
"sentiment": [0],
"context": ["huggingface is going to the moon!"],
"question": ["where is huggingface going?"],
"answers": [{"text": ["to the moon!"], "answer_start": [2]}],
Expand Down Expand Up @@ -1929,7 +2016,7 @@ def test_task_text_classification(self, in_memory):
features_before_cast = Features(
{
"input_text": Value("string"),
"input_labels": Value("int32"),
"input_labels": ClassLabel(names=labels),
}
)
# Labels are cast to tuple during TextClassification init, so we do the same here
Expand All @@ -1939,7 +2026,7 @@ def test_task_text_classification(self, in_memory):
"labels": ClassLabel(names=tuple(labels)),
}
)
task = TextClassification(text_column="input_text", label_column="input_labels", labels=labels)
task = TextClassification(text_column="input_text", label_column="input_labels")
info = DatasetInfo(
features=features_before_cast,
task_templates=task,
Expand Down Expand Up @@ -2030,10 +2117,10 @@ def test_task_with_incompatible_templates(self, in_memory):
features = Features(
{
"input_text": Value("string"),
"input_labels": Value("int32"),
"input_labels": ClassLabel(names=labels),
}
)
task = TextClassification(text_column="input_text", label_column="input_labels", labels=labels)
task = TextClassification(text_column="input_text", label_column="input_labels")
info = DatasetInfo(
features=features,
task_templates=task,
Expand All @@ -2050,6 +2137,24 @@ def test_task_with_incompatible_templates(self, in_memory):
# Invalid task type
dset.prepare_for_task(1)

def test_task_templates_empty_after_preparation(self, in_memory):
features = Features(
{
"input_text": Value("string"),
"input_labels": ClassLabel(names=["pos", "neg"]),
}
)
task = TextClassification(text_column="input_text", label_column="input_labels")
info = DatasetInfo(
features=features,
task_templates=task,
)
data = {"input_text": ["i love transformers!"], "input_labels": [1]}
with tempfile.TemporaryDirectory() as tmp_dir, Dataset.from_dict(data, info=info) as dset:
with self._to(in_memory, tmp_dir, dset) as dset:
with dset.prepare_for_task(task="text-classification") as dset:
self.assertIsNone(dset.info.task_templates)


class MiscellaneousDatasetTest(TestCase):
def test_from_pandas(self):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def test_column_mapping(self):

def test_from_dict(self):
input_schema = Features({"text": Value("string")})
# Labels are cast to tuple during TextClassification init, so we do the same here
# Labels are cast to tuple during TextClassification __post_init__, so we do the same here
label_schema = Features({"labels": ClassLabel(names=tuple(self.labels))})
template_dict = {"text_column": "input_text", "label_column": "input_labels", "labels": self.labels}
task = TextClassification.from_dict(template_dict)
Expand Down