Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions src/datasets/tasks/base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import abc
import dataclasses
from dataclasses import dataclass
from typing import ClassVar, Dict
from typing import ClassVar, Dict, Type, TypeVar

from ..features import Features


T = TypeVar("T", bound="TaskTemplate")


@dataclass(frozen=True)
class TaskTemplate(abc.ABC):
task: ClassVar[str]
Expand All @@ -18,9 +22,9 @@ def features(self) -> Features:
@property
@abc.abstractmethod
def column_mapping(self) -> Dict[str, str]:
return NotImplemented
raise NotImplementedError

@classmethod
@abc.abstractmethod
def from_dict(cls, template_dict: dict) -> "TaskTemplate":
return NotImplemented
def from_dict(cls: Type[T], template_dict: dict) -> T:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

very elegant approach - thanks!

field_names = set(f.name for f in dataclasses.fields(cls))
return cls(**{k: v for k, v in template_dict.items() if k in field_names})
21 changes: 4 additions & 17 deletions src/datasets/tasks/question_answering.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from dataclasses import dataclass
from typing import Dict
from typing import ClassVar, Dict

from ..features import Features, Sequence, Value
from .base import TaskTemplate


@dataclass(frozen=True)
class QuestionAnswering(TaskTemplate):
task = "question-answering"
input_schema = Features({"question": Value("string"), "context": Value("string")})
label_schema = Features(
task: ClassVar[str] = "question-answering"
input_schema: ClassVar[Features] = Features({"question": Value("string"), "context": Value("string")})
label_schema: ClassVar[Features] = Features(
{
"answers": Sequence(
{
Expand All @@ -23,19 +23,6 @@ class QuestionAnswering(TaskTemplate):
context_column: str = "context"
answers_column: str = "answers"

def __post_init__(self):
object.__setattr__(self, "question_column", self.question_column)
object.__setattr__(self, "context_column", self.context_column)
object.__setattr__(self, "answers_column", self.answers_column)

@property
def column_mapping(self) -> Dict[str, str]:
return {self.question_column: "question", self.context_column: "context", self.answers_column: "answers"}

@classmethod
def from_dict(cls, template_dict: dict) -> "QuestionAnswering":
return cls(
question_column=template_dict["question_column"],
context_column=template_dict["context_column"],
answers_column=template_dict["answers_column"],
)
50 changes: 30 additions & 20 deletions src/datasets/tasks/text_classification.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,40 @@
from dataclasses import dataclass
from typing import Dict, List
from typing import ClassVar, Dict, List

from ..features import ClassLabel, Features, Value
from .base import TaskTemplate


class FeaturesWithLazyClassLabel:
def __init__(self, features, label_column="labels"):
assert label_column in features, f"Key '{label_column}' missing in features {features}"
self._features = features
self._label_column = label_column

def __get__(self, obj, objtype=None):
if obj is None:
return self._features

assert hasattr(obj, self._label_column), f"Object has no attribute '{self._label_column}'"
features = self._features.copy()
features["labels"] = ClassLabel(names=getattr(obj, self._label_column))
return features


@dataclass(frozen=True)
class TextClassification(TaskTemplate):
task = "text-classification"
input_schema = Features({"text": Value("string")})
# TODO(lewtun): Since we update this in __post_init__ do we need to set a default? We'll need it for __init__ so
# investigate if there's a more elegant approach.
label_schema = Features({"labels": ClassLabel})
task: ClassVar[str] = "text-classification"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for spotting this!

input_schema: ClassVar[Features] = Features({"text": Value("string")})
# TODO(lewtun): Find a more elegant approach without descriptors.
label_schema: ClassVar[Features] = FeaturesWithLazyClassLabel(Features({"labels": ClassLabel}))
labels: List[str]
text_column: str = "text"
label_column: str = "labels"

def __post_init__(self):
assert sorted(set(self.labels)) == sorted(self.labels), "Labels must be unique"
assert len(self.labels) == len(set(self.labels)), "Labels must be unique"
# Cast labels to tuple to allow hashing
object.__setattr__(self, "labels", tuple(sorted(self.labels)))
object.__setattr__(self, "text_column", self.text_column)
object.__setattr__(self, "label_column", self.label_column)
self.label_schema["labels"] = ClassLabel(names=self.labels)
object.__setattr__(self, "label2id", {label: idx for idx, label in enumerate(self.labels)})
object.__setattr__(self, "id2label", {idx: label for label, idx in self.label2id.items()})
self.__dict__["labels"] = tuple(sorted(self.labels))

@property
def column_mapping(self) -> Dict[str, str]:
Expand All @@ -33,10 +43,10 @@ def column_mapping(self) -> Dict[str, str]:
self.label_column: "labels",
}

@classmethod
def from_dict(cls, template_dict: dict) -> "TextClassification":
return cls(
text_column=template_dict["text_column"],
label_column=template_dict["label_column"],
labels=template_dict["labels"],
)
@property
def label2id(self):
return {label: idx for idx, label in enumerate(self.labels)}

@property
def id2label(self):
return {idx: label for idx, label in enumerate(self.labels)}