Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/package_reference/task_templates.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ The tasks supported by [`Dataset.prepare_for_task`] and [`DatasetDict.prepare_fo

[[autodoc]] datasets.tasks.AutomaticSpeechRecognition

[[autodoc]] datasets.tasks.AudioClassification

[[autodoc]] datasets.tasks.ImageClassification
- align_with_features

Expand Down
2 changes: 2 additions & 0 deletions src/datasets/packaged_modules/audiofolder/audiofolder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import List

import datasets
from datasets.tasks import AudioClassification

from ..folder_based_builder import folder_based_builder

Expand All @@ -20,6 +21,7 @@ class AudioFolder(folder_based_builder.FolderBasedBuilder):
BASE_COLUMN_NAME = "audio"
BUILDER_CONFIG_CLASS = AudioFolderConfig
EXTENSIONS: List[str] # definition at the bottom of the script
CLASSIFICATION_TASK = AudioClassification(audio_column="audio", label_column="label")


# Obtained with:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import pyarrow.json as paj

import datasets
from datasets.tasks.base import TaskTemplate


logger = datasets.utils.logging.get_logger(__name__)
Expand Down Expand Up @@ -62,12 +63,14 @@ class FolderBasedBuilder(datasets.GeneratorBasedBuilder):
BUILDER_CONFIG_CLASS: builder config inherited from `folder_based_builder.FolderBasedBuilderConfig`
EXTENSIONS: list of allowed extensions (only files with these extensions and METADATA_FILENAME files
will be included in a dataset)
CLASSIFICATION_TASK: classification task to use if labels are obtained from the folder structure
"""

BASE_FEATURE: Any
BASE_COLUMN_NAME: str
BUILDER_CONFIG_CLASS: FolderBasedBuilderConfig
EXTENSIONS: List[str]
CLASSIFICATION_TASK: TaskTemplate

SKIP_CHECKSUM_COMPUTATION_BY_DEFAULT: bool = True
METADATA_FILENAMES: List[str] = ["metadata.csv", "metadata.jsonl"]
Expand Down Expand Up @@ -214,6 +217,7 @@ def analyze(files_or_archives, downloaded_files_or_dirs, split):
"label": datasets.ClassLabel(names=sorted(labels)),
}
)
self.info.task_templates = [self.CLASSIFICATION_TASK]
else:
self.info.features = datasets.Features({self.BASE_COLUMN_NAME: self.BASE_FEATURE})

Expand Down
2 changes: 2 additions & 0 deletions src/datasets/packaged_modules/imagefolder/imagefolder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import List

import datasets
from datasets.tasks import ImageClassification

from ..folder_based_builder import folder_based_builder

Expand All @@ -20,6 +21,7 @@ class ImageFolder(folder_based_builder.FolderBasedBuilder):
BASE_COLUMN_NAME = "image"
BUILDER_CONFIG_CLASS = ImageFolderConfig
EXTENSIONS: List[str] # definition at the bottom of the script
CLASSIFICATION_TASK = ImageClassification(image_column="image", label_column="label")


# Obtained with:
Expand Down
3 changes: 3 additions & 0 deletions src/datasets/tasks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Optional

from ..utils.logging import get_logger
from .audio_classificiation import AudioClassification
from .automatic_speech_recognition import AutomaticSpeechRecognition
from .base import TaskTemplate
from .image_classification import ImageClassification
Expand All @@ -12,6 +13,7 @@

__all__ = [
"AutomaticSpeechRecognition",
"AudioClassification",
"ImageClassification",
"LanguageModeling",
"QuestionAnsweringExtractive",
Expand All @@ -25,6 +27,7 @@

NAME2TEMPLATE = {
AutomaticSpeechRecognition.task: AutomaticSpeechRecognition,
AudioClassification.task: AudioClassification,
ImageClassification.task: ImageClassification,
LanguageModeling.task: LanguageModeling,
QuestionAnsweringExtractive.task: QuestionAnsweringExtractive,
Expand Down
33 changes: 33 additions & 0 deletions src/datasets/tasks/audio_classificiation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import copy
from dataclasses import dataclass
from typing import ClassVar, Dict

from ..features import Audio, ClassLabel, Features
from .base import TaskTemplate


@dataclass(frozen=True)
class AudioClassification(TaskTemplate):
    """Task template describing an audio classification dataset.

    The template maps a dataset's audio column and label column onto the
    canonical ``"audio"`` and ``"labels"`` names declared by the schemas below.
    """

    task: str = "audio-classification"
    # Class-level schemas; `align_with_features` overrides `label_schema` per instance.
    input_schema: ClassVar[Features] = Features({"audio": Audio()})
    label_schema: ClassVar[Features] = Features({"labels": ClassLabel})
    audio_column: str = "audio"
    label_column: str = "labels"

    def align_with_features(self, features):
        """Return a copy of this template whose label schema matches ``features``.

        Args:
            features: mapping of column names to feature types; it must contain
                ``self.label_column`` as a `ClassLabel` instance.

        Returns:
            A deep copy of this template whose ``label_schema["labels"]`` entry is
            the feature found under ``self.label_column`` in ``features``.

        Raises:
            ValueError: if the label column is missing from ``features`` or is
                not a `ClassLabel` instance.
        """
        if self.label_column not in features:
            raise ValueError(f"Column {self.label_column} is not present in features.")
        label_feature = features[self.label_column]
        if not isinstance(label_feature, ClassLabel):
            raise ValueError(f"Column {self.label_column} is not a ClassLabel.")

        aligned_template = copy.deepcopy(self)
        aligned_schema = self.label_schema.copy()
        aligned_schema["labels"] = label_feature
        # The dataclass is frozen, so regular attribute assignment is rejected;
        # writing to the instance `__dict__` shadows the class-level schema.
        aligned_template.__dict__["label_schema"] = aligned_schema
        return aligned_template

    @property
    def column_mapping(self) -> Dict[str, str]:
        """Mapping from this template's column names to the canonical schema names."""
        mapping = {self.audio_column: "audio", self.label_column: "labels"}
        return mapping
1 change: 1 addition & 0 deletions tests/packaged_modules/test_folder_based_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class DummyFolderBasedBuilder(FolderBasedBuilder):
BASE_COLUMN_NAME = "base"
BUILDER_CONFIG_CLASS = FolderBasedBuilderConfig
EXTENSIONS = [".txt"]
CLASSIFICATION_TASK = None


@pytest.fixture
Expand Down
28 changes: 28 additions & 0 deletions tests/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from datasets.features import Audio, ClassLabel, Features, Image, Sequence, Value
from datasets.info import DatasetInfo
from datasets.tasks import (
AudioClassification,
AutomaticSpeechRecognition,
ImageClassification,
LanguageModeling,
Expand Down Expand Up @@ -126,6 +127,33 @@ def test_from_dict(self):
self.assertEqual(label_schema, task.label_schema)


class AudioClassificationTest(TestCase):
    """Unit tests for the `AudioClassification` task template."""

    def setUp(self):
        # Sorted label names used to build the `ClassLabel` feature in the tests.
        self.labels = sorted(["pos", "neg"])

    def test_column_mapping(self):
        """Custom column names map onto the canonical "audio" / "labels" names."""
        task = AudioClassification(audio_column="input_audio", label_column="input_label")
        self.assertDictEqual({"input_audio": "audio", "input_label": "labels"}, task.column_mapping)

    def test_from_dict(self):
        """A template built with `from_dict` exposes the expected task name and schemas."""
        input_schema = Features({"audio": Audio()})
        label_schema = Features({"labels": ClassLabel})
        template_dict = {
            # Was "input_image" (copied from the image test); use the audio
            # column name the other tests in this class use.
            "audio_column": "input_audio",
            "label_column": "input_label",
        }
        task = AudioClassification.from_dict(template_dict)
        self.assertEqual("audio-classification", task.task)
        self.assertEqual(input_schema, task.input_schema)
        self.assertEqual(label_schema, task.label_schema)

    def test_align_with_features(self):
        """Aligning replaces the generic `ClassLabel` with the dataset's own label feature."""
        task = AudioClassification(audio_column="input_audio", label_column="input_label")
        # Before alignment the schema holds the bare `ClassLabel` class.
        self.assertEqual(task.label_schema["labels"], ClassLabel)
        task = task.align_with_features(Features({"input_label": ClassLabel(names=self.labels)}))
        self.assertEqual(task.label_schema["labels"], ClassLabel(names=self.labels))


class ImageClassificationTest(TestCase):
def setUp(self):
self.labels = sorted(["pos", "neg"])
Expand Down