Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/datasets/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from ..utils.logging import get_logger
from .automatic_speech_recognition import AutomaticSpeechRecognition
from .base import TaskTemplate
from .image_classification import ImageClassification
from .question_answering import QuestionAnsweringExtractive
from .summarization import Summarization
from .text_classification import TextClassification
Expand All @@ -14,6 +15,7 @@
"TextClassification",
"Summarization",
"AutomaticSpeechRecognition",
"ImageClassification",
]

logger = get_logger(__name__)
Expand All @@ -24,6 +26,7 @@
TextClassification.task: TextClassification,
AutomaticSpeechRecognition.task: AutomaticSpeechRecognition,
Summarization.task: Summarization,
ImageClassification.task: ImageClassification,
}


Expand Down
31 changes: 31 additions & 0 deletions src/datasets/tasks/image_classification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from dataclasses import dataclass
from typing import ClassVar, Dict, Optional, Tuple

from ..features import ClassLabel, Features, Value
from .base import TaskTemplate


@dataclass(frozen=True)
class ImageClassification(TaskTemplate):
task: str = "image-classification"
input_schema: ClassVar[Features] = Features({"image_file_path": Value("string")})
# TODO(lewtun): Find a more elegant approach without descriptors.
label_schema: ClassVar[Features] = Features({"labels": ClassLabel})
image_file_path_column: str = "image_file_path"
label_column: str = "labels"
labels: Optional[Tuple[str]] = None

def __post_init__(self):
if self.labels:
assert len(self.labels) == len(set(self.labels)), "Labels must be unique"
# Cast labels to tuple to allow hashing
self.__dict__["labels"] = tuple(sorted(self.labels))
self.__dict__["label_schema"] = self.label_schema.copy()
self.label_schema["labels"] = ClassLabel(names=self.labels)

@property
def column_mapping(self) -> Dict[str, str]:
return {
self.image_file_path_column: "image_file_path",
self.label_column: "labels",
}
30 changes: 29 additions & 1 deletion tests/test_tasks.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
from unittest.case import TestCase

from datasets.features import ClassLabel, Features, Sequence, Value
from datasets.tasks import AutomaticSpeechRecognition, QuestionAnsweringExtractive, Summarization, TextClassification
from datasets.tasks import (
AutomaticSpeechRecognition,
ImageClassification,
QuestionAnsweringExtractive,
Summarization,
TextClassification,
)


class TextClassificationTest(TestCase):
Expand Down Expand Up @@ -90,3 +96,25 @@ def test_from_dict(self):
self.assertEqual("automatic-speech-recognition", task.task)
self.assertEqual(input_schema, task.input_schema)
self.assertEqual(label_schema, task.label_schema)


class ImageClassificationTest(TestCase):
def setUp(self):
self.labels = sorted(["pos", "neg"])

def test_column_mapping(self):
task = ImageClassification(image_file_path_column="file_paths", label_column="input_label")
self.assertDictEqual({"file_paths": "image_file_path", "input_label": "labels"}, task.column_mapping)

def test_from_dict(self):
input_schema = Features({"image_file_path": Value("string")})
label_schema = Features({"labels": ClassLabel(names=tuple(self.labels))})
template_dict = {
"image_file_path_column": "input_image_file_path",
"label_column": "input_label",
"labels": self.labels,
}
task = ImageClassification.from_dict(template_dict)
self.assertEqual("image-classification", task.task)
self.assertEqual(input_schema, task.input_schema)
self.assertEqual(label_schema, task.label_schema)