Skip to content

Commit 633ddcd

Browse files
authored
Improve task api code quality (#2376)
* Improve task api code quality * Add todo deleted by accident * Lazy initialize label schema in text classification task
1 parent 74751e3 commit 633ddcd

3 files changed

Lines changed: 43 additions & 42 deletions

File tree

src/datasets/tasks/base.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
import abc
2+
import dataclasses
23
from dataclasses import dataclass
3-
from typing import ClassVar, Dict
4+
from typing import ClassVar, Dict, Type, TypeVar
45

56
from ..features import Features
67

78

9+
T = TypeVar("T", bound="TaskTemplate")
10+
11+
812
@dataclass(frozen=True)
913
class TaskTemplate(abc.ABC):
1014
task: ClassVar[str]
@@ -18,9 +22,9 @@ def features(self) -> Features:
1822
@property
1923
@abc.abstractmethod
2024
def column_mapping(self) -> Dict[str, str]:
21-
return NotImplemented
25+
raise NotImplementedError
2226

2327
@classmethod
24-
@abc.abstractmethod
25-
def from_dict(cls, template_dict: dict) -> "TaskTemplate":
26-
return NotImplemented
28+
def from_dict(cls: Type[T], template_dict: dict) -> T:
29+
field_names = set(f.name for f in dataclasses.fields(cls))
30+
return cls(**{k: v for k, v in template_dict.items() if k in field_names})
Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
from dataclasses import dataclass
2-
from typing import Dict
2+
from typing import ClassVar, Dict
33

44
from ..features import Features, Sequence, Value
55
from .base import TaskTemplate
66

77

88
@dataclass(frozen=True)
99
class QuestionAnswering(TaskTemplate):
10-
task = "question-answering"
11-
input_schema = Features({"question": Value("string"), "context": Value("string")})
12-
label_schema = Features(
10+
task: ClassVar[str] = "question-answering"
11+
input_schema: ClassVar[Features] = Features({"question": Value("string"), "context": Value("string")})
12+
label_schema: ClassVar[Features] = Features(
1313
{
1414
"answers": Sequence(
1515
{
@@ -23,19 +23,6 @@ class QuestionAnswering(TaskTemplate):
2323
context_column: str = "context"
2424
answers_column: str = "answers"
2525

26-
def __post_init__(self):
27-
object.__setattr__(self, "question_column", self.question_column)
28-
object.__setattr__(self, "context_column", self.context_column)
29-
object.__setattr__(self, "answers_column", self.answers_column)
30-
3126
@property
3227
def column_mapping(self) -> Dict[str, str]:
3328
return {self.question_column: "question", self.context_column: "context", self.answers_column: "answers"}
34-
35-
@classmethod
36-
def from_dict(cls, template_dict: dict) -> "QuestionAnswering":
37-
return cls(
38-
question_column=template_dict["question_column"],
39-
context_column=template_dict["context_column"],
40-
answers_column=template_dict["answers_column"],
41-
)
Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,40 @@
11
from dataclasses import dataclass
2-
from typing import Dict, List
2+
from typing import ClassVar, Dict, List
33

44
from ..features import ClassLabel, Features, Value
55
from .base import TaskTemplate
66

77

8+
class FeaturesWithLazyClassLabel:
9+
def __init__(self, features, label_column="labels"):
10+
assert label_column in features, f"Key '{label_column}' missing in features {features}"
11+
self._features = features
12+
self._label_column = label_column
13+
14+
def __get__(self, obj, objtype=None):
15+
if obj is None:
16+
return self._features
17+
18+
assert hasattr(obj, self._label_column), f"Object has no attribute '{self._label_column}'"
19+
features = self._features.copy()
20+
features["labels"] = ClassLabel(names=getattr(obj, self._label_column))
21+
return features
22+
23+
824
@dataclass(frozen=True)
925
class TextClassification(TaskTemplate):
10-
task = "text-classification"
11-
input_schema = Features({"text": Value("string")})
12-
# TODO(lewtun): Since we update this in __post_init__ do we need to set a default? We'll need it for __init__ so
13-
# investigate if there's a more elegant approach.
14-
label_schema = Features({"labels": ClassLabel})
26+
task: ClassVar[str] = "text-classification"
27+
input_schema: ClassVar[Features] = Features({"text": Value("string")})
28+
# TODO(lewtun): Find a more elegant approach without descriptors.
29+
label_schema: ClassVar[Features] = FeaturesWithLazyClassLabel(Features({"labels": ClassLabel}))
1530
labels: List[str]
1631
text_column: str = "text"
1732
label_column: str = "labels"
1833

1934
def __post_init__(self):
20-
assert sorted(set(self.labels)) == sorted(self.labels), "Labels must be unique"
35+
assert len(self.labels) == len(set(self.labels)), "Labels must be unique"
2136
# Cast labels to tuple to allow hashing
22-
object.__setattr__(self, "labels", tuple(sorted(self.labels)))
23-
object.__setattr__(self, "text_column", self.text_column)
24-
object.__setattr__(self, "label_column", self.label_column)
25-
self.label_schema["labels"] = ClassLabel(names=self.labels)
26-
object.__setattr__(self, "label2id", {label: idx for idx, label in enumerate(self.labels)})
27-
object.__setattr__(self, "id2label", {idx: label for label, idx in self.label2id.items()})
37+
self.__dict__["labels"] = tuple(sorted(self.labels))
2838

2939
@property
3040
def column_mapping(self) -> Dict[str, str]:
@@ -33,10 +43,10 @@ def column_mapping(self) -> Dict[str, str]:
3343
self.label_column: "labels",
3444
}
3545

36-
@classmethod
37-
def from_dict(cls, template_dict: dict) -> "TextClassification":
38-
return cls(
39-
text_column=template_dict["text_column"],
40-
label_column=template_dict["label_column"],
41-
labels=template_dict["labels"],
42-
)
46+
@property
47+
def label2id(self):
48+
return {label: idx for idx, label in enumerate(self.labels)}
49+
50+
@property
51+
def id2label(self):
52+
return {idx: label for idx, label in enumerate(self.labels)}

0 commit comments

Comments
 (0)