Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions src/datasets/tasks/base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import abc
import dataclasses
from dataclasses import dataclass
from typing import ClassVar, Dict
from typing import ClassVar, Dict, Type, TypeVar

from ..features import Features


T = TypeVar("T", bound="TaskTemplate")


@dataclass(frozen=True)
class TaskTemplate(abc.ABC):
task: ClassVar[str]
Expand All @@ -18,9 +22,9 @@ def features(self) -> Features:
@property
@abc.abstractmethod
def column_mapping(self) -> Dict[str, str]:
return NotImplemented
raise NotImplementedError

@classmethod
@abc.abstractmethod
def from_dict(cls, template_dict: dict) -> "TaskTemplate":
return NotImplemented
def from_dict(cls: Type[T], template_dict: dict) -> T:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

very elegant approach - thanks!

field_names = set(f.name for f in dataclasses.fields(cls))
return cls(**{k: v for k, v in template_dict.items() if k in field_names})
21 changes: 4 additions & 17 deletions src/datasets/tasks/question_answering.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from dataclasses import dataclass
from typing import Dict
from typing import ClassVar, Dict

from ..features import Features, Sequence, Value
from .base import TaskTemplate


@dataclass(frozen=True)
class QuestionAnswering(TaskTemplate):
task = "question-answering"
input_schema = Features({"question": Value("string"), "context": Value("string")})
label_schema = Features(
task: ClassVar[str] = "question-answering"
input_schema: ClassVar[Features] = Features({"question": Value("string"), "context": Value("string")})
label_schema: ClassVar[Features] = Features(
{
"answers": Sequence(
{
Expand All @@ -23,19 +23,6 @@ class QuestionAnswering(TaskTemplate):
context_column: str = "context"
answers_column: str = "answers"

def __post_init__(self):
object.__setattr__(self, "question_column", self.question_column)
object.__setattr__(self, "context_column", self.context_column)
object.__setattr__(self, "answers_column", self.answers_column)

@property
def column_mapping(self) -> Dict[str, str]:
return {self.question_column: "question", self.context_column: "context", self.answers_column: "answers"}

@classmethod
def from_dict(cls, template_dict: dict) -> "QuestionAnswering":
return cls(
question_column=template_dict["question_column"],
context_column=template_dict["context_column"],
answers_column=template_dict["answers_column"],
)
30 changes: 13 additions & 17 deletions src/datasets/tasks/text_classification.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,26 @@
from dataclasses import dataclass
from typing import Dict, List
from typing import ClassVar, Dict, List

from ..features import ClassLabel, Features, Value
from .base import TaskTemplate


@dataclass(frozen=True)
class TextClassification(TaskTemplate):
task = "text-classification"
input_schema = Features({"text": Value("string")})
task: ClassVar[str] = "text-classification"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for spotting this!

input_schema: ClassVar[Features] = Features({"text": Value("string")})
# TODO(lewtun): Since we update this in __post_init__ do we need to set a default? We'll need it for __init__ so
# investigate if there's a more elegant approach.
label_schema = Features({"labels": ClassLabel})
label_schema: ClassVar[Features] = Features({"labels": ClassLabel})
labels: List[str]
text_column: str = "text"
label_column: str = "labels"

def __post_init__(self):
assert sorted(set(self.labels)) == sorted(self.labels), "Labels must be unique"
assert len(self.labels) == len(set(self.labels)), "Labels must be unique"
# Cast labels to tuple to allow hashing
object.__setattr__(self, "labels", tuple(sorted(self.labels)))
object.__setattr__(self, "text_column", self.text_column)
object.__setattr__(self, "label_column", self.label_column)
self.__dict__["labels"] = tuple(sorted(self.labels))
self.label_schema["labels"] = ClassLabel(names=self.labels)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line is an issue. Modifying a class attribute based on the value of an instance attribute doesn't make sense if multiple instances of the same class are allowed (if that's the case, the class attribute will have a valid value only for the instance that was initialized last). One way to fix this is with the help of descriptors:

class FeaturesWithLazyClassLabel:
    def __init__(self, features, label_column="labels"):
        assert label_column in features, f"Key '{label_column}' missing in features {features}"
        self._features = features
        self._label_column = label_column

    def __get__(self, obj, objtype=None):
        if obj is None:
            return self._features
     
        assert hasattr(obj, self._label_column), f"Object has no attribute '{self._label_column}'"
        # note: this part is not cached, but we can add this easily
        features = self._features.copy()
        features["labels"] = ClassLabel(names=getattr(obj, self._label_column))
        return features

and then in TextClassification:

label_schema: ClassVar[Features] = FeaturesWithLazyClassLabel(Features({"labels": ClassLabel}))

Copy link
Member

@lhoestq lhoestq May 19, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 on this !

The fix using FeaturesWithLazyClassLabel is fine IMO (though it would be nice if we find a simpler way to fix this)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice catch @mariosasko! i'm also happy with the FeatureWithLazyClassLabel fix for now

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is a temporary solution. I agree we should try to find a simpler way to fix this. The current API design does not seem to fit the task very well (having label_schema defined as a class attribute but depends on the value of instances?), so we should rethink that part IMO.

object.__setattr__(self, "label2id", {label: idx for idx, label in enumerate(self.labels)})
object.__setattr__(self, "id2label", {idx: label for label, idx in self.label2id.items()})

@property
def column_mapping(self) -> Dict[str, str]:
Expand All @@ -33,10 +29,10 @@ def column_mapping(self) -> Dict[str, str]:
self.label_column: "labels",
}

@classmethod
def from_dict(cls, template_dict: dict) -> "TextClassification":
return cls(
text_column=template_dict["text_column"],
label_column=template_dict["label_column"],
labels=template_dict["labels"],
)
@property
def label2id(self):
return {label: idx for idx, label in enumerate(self.labels)}

@property
def id2label(self):
return {idx: label for idx, label in enumerate(self.labels)}