Skip to content

Commit 08dfdc9

Browse files
mariosaskolhoestq
authored andcommitted
Revert task removal in folder-based builders (#5051)
* Add AudioClassification task * Add classification task to folder based builders * Fix tests * Minor fix * Minor fix again
1 parent 0c84b71 commit 08dfdc9

File tree

8 files changed

+79
-2
lines changed

8 files changed

+79
-2
lines changed

docs/source/package_reference/task_templates.mdx

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ The tasks supported by [`Dataset.prepare_for_task`] and [`DatasetDict.prepare_fo
44

55
[[autodoc]] datasets.tasks.AutomaticSpeechRecognition
66

7+
[[autodoc]] datasets.tasks.AudioClassification
8+
79
[[autodoc]] datasets.tasks.ImageClassification
810
- align_with_features
911

src/datasets/packaged_modules/audiofolder/audiofolder.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import List
22

33
import datasets
4+
from datasets.tasks import AudioClassification
45

56
from ..folder_based_builder import folder_based_builder
67

@@ -20,6 +21,7 @@ class AudioFolder(folder_based_builder.FolderBasedBuilder):
2021
BASE_COLUMN_NAME = "audio"
2122
BUILDER_CONFIG_CLASS = AudioFolderConfig
2223
EXTENSIONS: List[str] # definition at the bottom of the script
24+
CLASSIFICATION_TASK = AudioClassification(audio_column="audio", label_column="label")
2325

2426

2527
# Obtained with:

src/datasets/packaged_modules/folder_based_builder/folder_based_builder.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,16 @@
22
import itertools
33
import os
44
from dataclasses import dataclass
5-
from typing import Any, List, Optional, Tuple
5+
from typing import List, Optional, Tuple
66

77
import pandas as pd
88
import pyarrow as pa
99
import pyarrow.compute as pc
1010
import pyarrow.json as paj
1111

1212
import datasets
13+
from datasets.features.features import FeatureType
14+
from datasets.tasks.base import TaskTemplate
1315

1416

1517
logger = datasets.utils.logging.get_logger(__name__)
@@ -62,12 +64,14 @@ class FolderBasedBuilder(datasets.GeneratorBasedBuilder):
6264
BUILDER_CONFIG_CLASS: builder config inherited from `folder_based_builder.FolderBasedBuilderConfig`
6365
EXTENSIONS: list of allowed extensions (only files with these extensions and METADATA_FILENAME files
6466
will be included in a dataset)
67+
CLASSIFICATION_TASK: classification task to use if labels are obtained from the folder structure
6568
"""
6669

67-
BASE_FEATURE: Any
70+
BASE_FEATURE: FeatureType
6871
BASE_COLUMN_NAME: str
6972
BUILDER_CONFIG_CLASS: FolderBasedBuilderConfig
7073
EXTENSIONS: List[str]
74+
CLASSIFICATION_TASK: TaskTemplate
7175

7276
SKIP_CHECKSUM_COMPUTATION_BY_DEFAULT: bool = True
7377
METADATA_FILENAMES: List[str] = ["metadata.csv", "metadata.jsonl"]
@@ -214,6 +218,7 @@ def analyze(files_or_archives, downloaded_files_or_dirs, split):
214218
"label": datasets.ClassLabel(names=sorted(labels)),
215219
}
216220
)
221+
self.info.task_templates = [self.CLASSIFICATION_TASK.align_with_features(self.info.features)]
217222
else:
218223
self.info.features = datasets.Features({self.BASE_COLUMN_NAME: self.BASE_FEATURE})
219224

src/datasets/packaged_modules/imagefolder/imagefolder.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import List
22

33
import datasets
4+
from datasets.tasks import ImageClassification
45

56
from ..folder_based_builder import folder_based_builder
67

@@ -20,6 +21,7 @@ class ImageFolder(folder_based_builder.FolderBasedBuilder):
2021
BASE_COLUMN_NAME = "image"
2122
BUILDER_CONFIG_CLASS = ImageFolderConfig
2223
EXTENSIONS: List[str] # definition at the bottom of the script
24+
CLASSIFICATION_TASK = ImageClassification(image_column="image", label_column="label")
2325

2426

2527
# Obtained with:

src/datasets/tasks/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Optional
22

33
from ..utils.logging import get_logger
4+
from .audio_classificiation import AudioClassification
45
from .automatic_speech_recognition import AutomaticSpeechRecognition
56
from .base import TaskTemplate
67
from .image_classification import ImageClassification
@@ -12,6 +13,7 @@
1213

1314
__all__ = [
1415
"AutomaticSpeechRecognition",
16+
"AudioClassification",
1517
"ImageClassification",
1618
"LanguageModeling",
1719
"QuestionAnsweringExtractive",
@@ -25,6 +27,7 @@
2527

2628
NAME2TEMPLATE = {
2729
AutomaticSpeechRecognition.task: AutomaticSpeechRecognition,
30+
AudioClassification.task: AudioClassification,
2831
ImageClassification.task: ImageClassification,
2932
LanguageModeling.task: LanguageModeling,
3033
QuestionAnsweringExtractive.task: QuestionAnsweringExtractive,
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import copy
2+
from dataclasses import dataclass
3+
from typing import ClassVar, Dict
4+
5+
from ..features import Audio, ClassLabel, Features
6+
from .base import TaskTemplate
7+
8+
9+
@dataclass(frozen=True)
10+
class AudioClassification(TaskTemplate):
11+
task: str = "audio-classification"
12+
input_schema: ClassVar[Features] = Features({"audio": Audio()})
13+
label_schema: ClassVar[Features] = Features({"labels": ClassLabel})
14+
audio_column: str = "audio"
15+
label_column: str = "labels"
16+
17+
def align_with_features(self, features):
18+
if self.label_column not in features:
19+
raise ValueError(f"Column {self.label_column} is not present in features.")
20+
if not isinstance(features[self.label_column], ClassLabel):
21+
raise ValueError(f"Column {self.label_column} is not a ClassLabel.")
22+
task_template = copy.deepcopy(self)
23+
label_schema = self.label_schema.copy()
24+
label_schema["labels"] = features[self.label_column]
25+
task_template.__dict__["label_schema"] = label_schema
26+
return task_template
27+
28+
@property
29+
def column_mapping(self) -> Dict[str, str]:
30+
return {
31+
self.audio_column: "audio",
32+
self.label_column: "labels",
33+
}

tests/packaged_modules/test_folder_based_builder.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,15 @@
1111
FolderBasedBuilder,
1212
FolderBasedBuilderConfig,
1313
)
14+
from datasets.tasks import TextClassification
1415

1516

1617
class DummyFolderBasedBuilder(FolderBasedBuilder):
1718
BASE_FEATURE = None
1819
BASE_COLUMN_NAME = "base"
1920
BUILDER_CONFIG_CLASS = FolderBasedBuilderConfig
2021
EXTENSIONS = [".txt"]
22+
CLASSIFICATION_TASK = TextClassification(text_column="base", label_column="label")
2123

2224

2325
@pytest.fixture

tests/test_tasks.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from datasets.features import Audio, ClassLabel, Features, Image, Sequence, Value
66
from datasets.info import DatasetInfo
77
from datasets.tasks import (
8+
AudioClassification,
89
AutomaticSpeechRecognition,
910
ImageClassification,
1011
LanguageModeling,
@@ -126,6 +127,33 @@ def test_from_dict(self):
126127
self.assertEqual(label_schema, task.label_schema)
127128

128129

130+
class AudioClassificationTest(TestCase):
131+
def setUp(self):
132+
self.labels = sorted(["pos", "neg"])
133+
134+
def test_column_mapping(self):
135+
task = AudioClassification(audio_column="input_audio", label_column="input_label")
136+
self.assertDictEqual({"input_audio": "audio", "input_label": "labels"}, task.column_mapping)
137+
138+
def test_from_dict(self):
139+
input_schema = Features({"audio": Audio()})
140+
label_schema = Features({"labels": ClassLabel})
141+
template_dict = {
142+
"audio_column": "input_image",
143+
"label_column": "input_label",
144+
}
145+
task = AudioClassification.from_dict(template_dict)
146+
self.assertEqual("audio-classification", task.task)
147+
self.assertEqual(input_schema, task.input_schema)
148+
self.assertEqual(label_schema, task.label_schema)
149+
150+
def test_align_with_features(self):
151+
task = AudioClassification(audio_column="input_audio", label_column="input_label")
152+
self.assertEqual(task.label_schema["labels"], ClassLabel)
153+
task = task.align_with_features(Features({"input_label": ClassLabel(names=self.labels)}))
154+
self.assertEqual(task.label_schema["labels"], ClassLabel(names=self.labels))
155+
156+
129157
class ImageClassificationTest(TestCase):
130158
def setUp(self):
131159
self.labels = sorted(["pos", "neg"])

0 commit comments

Comments
 (0)