-
Notifications
You must be signed in to change notification settings - Fork 3.1k
Update text classification template labels in DatasetInfo __post_init__ #2392
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 19 commits
538f3be
c02a2e4
2eab30c
feaca48
188d02c
1e3e830
635e54d
a892fde
d85d73d
43f9d55
5d66b4f
e7b1f7a
6f3ff6d
1bf0b5b
0dda59e
654b2b0
812bd87
159a6f6
a580339
cff9d52
8146867
fa53dc5
f78d5c4
514890d
71d5ac8
40e4400
321e4e6
e79dfe4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -36,6 +36,8 @@ | |||||||||||||||||||||||||||
| from dataclasses import asdict, dataclass, field | ||||||||||||||||||||||||||||
| from typing import List, Optional, Union | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| from datasets.tasks.text_classification import TextClassification | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| from . import config | ||||||||||||||||||||||||||||
| from .features import Features, Value | ||||||||||||||||||||||||||||
| from .splits import SplitDict | ||||||||||||||||||||||||||||
|
|
@@ -154,6 +156,7 @@ def __post_init__(self): | |||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||
| self.supervised_keys = SupervisedKeysData(**self.supervised_keys) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| # Parse and make a list of templates | ||||||||||||||||||||||||||||
| if self.task_templates is not None: | ||||||||||||||||||||||||||||
| if isinstance(self.task_templates, (list, tuple)): | ||||||||||||||||||||||||||||
| templates = [ | ||||||||||||||||||||||||||||
|
|
@@ -167,6 +170,16 @@ def __post_init__(self): | |||||||||||||||||||||||||||
| template = task_template_from_dict(self.task_templates) | ||||||||||||||||||||||||||||
| self.task_templates = [template] if template is not None else [] | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| # Insert labels and mappings for text classification | ||||||||||||||||||||||||||||
| if self.task_templates is not None: | ||||||||||||||||||||||||||||
| self.task_templates = list(self.task_templates) | ||||||||||||||||||||||||||||
lhoestq marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||
| for idx, template in enumerate(self.task_templates): | ||||||||||||||||||||||||||||
| if isinstance(template, TextClassification) and self.features is not None: | ||||||||||||||||||||||||||||
| labels = self.features[template.label_column].names | ||||||||||||||||||||||||||||
| self.task_templates[idx] = TextClassification( | ||||||||||||||||||||||||||||
| text_column=template.text_column, label_column=template.label_column, labels=labels | ||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||
lhoestq marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||
| for idx, template in enumerate(self.task_templates): | |
| if isinstance(template, TextClassification) and self.features is not None: | |
| labels = self.features[template.label_column].names | |
| self.task_templates[idx] = TextClassification( | |
| text_column=template.text_column, label_column=template.label_column, labels=labels | |
| ) | |
| if self.features is not None: | |
| for idx, template in enumerate(self.task_templates): | |
| if isinstance(template, TextClassification): | |
| labels = self.features[template.label_column].names | |
| self.task_templates[idx] = TextClassification( | |
| text_column=template.text_column, label_column=template.label_column, labels=labels | |
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good idea! done.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,6 @@ | ||
| import copy | ||
| from dataclasses import dataclass | ||
| from typing import ClassVar, Dict, List | ||
| from typing import ClassVar, Dict, Optional, Tuple | ||
|
|
||
| from ..features import ClassLabel, Features, Value | ||
| from .base import TaskTemplate | ||
|
|
@@ -23,30 +24,26 @@ def __get__(self, obj, objtype=None): | |
|
|
||
| @dataclass(frozen=True) | ||
| class TextClassification(TaskTemplate): | ||
| task: ClassVar[str] = "text-classification" | ||
| # `task` is not a ClassVar since we want it to be part of the `asdict` output for JSON serialization | ||
| task: str = "text-classification" | ||
lhoestq marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| input_schema: ClassVar[Features] = Features({"text": Value("string")}) | ||
| # TODO(lewtun): Find a more elegant approach without descriptors. | ||
| label_schema: ClassVar[Features] = FeaturesWithLazyClassLabel(Features({"labels": ClassLabel})) | ||
| labels: List[str] | ||
| text_column: str = "text" | ||
| label_column: str = "labels" | ||
| labels: Optional[Tuple[str]] = None | ||
|
|
||
| def __post_init__(self): | ||
| assert len(self.labels) == len(set(self.labels)), "Labels must be unique" | ||
| # Cast labels to tuple to allow hashing | ||
| self.__dict__["labels"] = tuple(sorted(self.labels)) | ||
| if self.labels: | ||
| assert len(self.labels) == len(set(self.labels)), "Labels must be unique" | ||
| # Cast labels to tuple to allow hashing | ||
| self.__dict__["labels"] = tuple(sorted(self.labels)) | ||
| self.__dict__["label_schema"] = copy.deepcopy(self.label_schema) | ||
|
||
| self.label_schema["labels"] = ClassLabel(names=self.labels) | ||
|
|
||
| @property | ||
| def column_mapping(self) -> Dict[str, str]: | ||
| return { | ||
| self.text_column: "text", | ||
| self.label_column: "labels", | ||
| } | ||
|
|
||
| @property | ||
| def label2id(self): | ||
| return {label: idx for idx, label in enumerate(self.labels)} | ||
|
|
||
| @property | ||
| def id2label(self): | ||
| return {idx: label for idx, label in enumerate(self.labels)} | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,15 @@ | ||
| from typing import Callable | ||
|
|
||
|
|
||
| def is_documented_by(function_with_docstring: Callable): | ||
| """Decorator to share docstrings across common functions. | ||
|
|
||
| Args: | ||
| function_with_docstring (`Callable`): Name of the function with the docstring. | ||
| """ | ||
|
|
||
| def wrapper(target_function): | ||
| target_function.__doc__ = function_with_docstring.__doc__ | ||
| return target_function | ||
|
|
||
| return wrapper |
Uh oh!
There was an error while loading. Please reload this page.