-
Notifications
You must be signed in to change notification settings - Fork 3k
Improve task api code quality #2376
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,30 +1,26 @@ | ||
| from dataclasses import dataclass | ||
| from typing import Dict, List | ||
| from typing import ClassVar, Dict, List | ||
|
|
||
| from ..features import ClassLabel, Features, Value | ||
| from .base import TaskTemplate | ||
|
|
||
|
|
||
| @dataclass(frozen=True) | ||
| class TextClassification(TaskTemplate): | ||
| task = "text-classification" | ||
| input_schema = Features({"text": Value("string")}) | ||
| task: ClassVar[str] = "text-classification" | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. thanks for spotting this! |
||
| input_schema: ClassVar[Features] = Features({"text": Value("string")}) | ||
| # TODO(lewtun): Since we update this in __post_init__ do we need to set a default? We'll need it for __init__ so | ||
| # investigate if there's a more elegant approach. | ||
| label_schema = Features({"labels": ClassLabel}) | ||
| label_schema: ClassVar[Features] = Features({"labels": ClassLabel}) | ||
| labels: List[str] | ||
| text_column: str = "text" | ||
| label_column: str = "labels" | ||
|
|
||
| def __post_init__(self): | ||
| assert sorted(set(self.labels)) == sorted(self.labels), "Labels must be unique" | ||
| assert len(self.labels) == len(set(self.labels)), "Labels must be unique" | ||
| # Cast labels to tuple to allow hashing | ||
| object.__setattr__(self, "labels", tuple(sorted(self.labels))) | ||
| object.__setattr__(self, "text_column", self.text_column) | ||
| object.__setattr__(self, "label_column", self.label_column) | ||
| self.__dict__["labels"] = tuple(sorted(self.labels)) | ||
| self.label_schema["labels"] = ClassLabel(names=self.labels) | ||
|
||
| object.__setattr__(self, "label2id", {label: idx for idx, label in enumerate(self.labels)}) | ||
| object.__setattr__(self, "id2label", {idx: label for label, idx in self.label2id.items()}) | ||
|
|
||
| @property | ||
| def column_mapping(self) -> Dict[str, str]: | ||
|
|
@@ -33,10 +29,10 @@ def column_mapping(self) -> Dict[str, str]: | |
| self.label_column: "labels", | ||
| } | ||
|
|
||
| @classmethod | ||
| def from_dict(cls, template_dict: dict) -> "TextClassification": | ||
| return cls( | ||
| text_column=template_dict["text_column"], | ||
| label_column=template_dict["label_column"], | ||
| labels=template_dict["labels"], | ||
| ) | ||
| @property | ||
| def label2id(self): | ||
| return {label: idx for idx, label in enumerate(self.labels)} | ||
|
|
||
| @property | ||
| def id2label(self): | ||
| return {idx: label for idx, label in enumerate(self.labels)} | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
very elegant approach - thanks!