|
5 | 5 | from datasets.features import Audio, ClassLabel, Features, Image, Sequence, Value |
6 | 6 | from datasets.info import DatasetInfo |
7 | 7 | from datasets.tasks import ( |
| 8 | + AudioClassification, |
8 | 9 | AutomaticSpeechRecognition, |
9 | 10 | ImageClassification, |
10 | 11 | LanguageModeling, |
@@ -126,6 +127,33 @@ def test_from_dict(self): |
126 | 127 | self.assertEqual(label_schema, task.label_schema) |
127 | 128 |
|
128 | 129 |
|
| 130 | +class AudioClassificationTest(TestCase): |
| 131 | + def setUp(self): |
| 132 | + self.labels = sorted(["pos", "neg"]) |
| 133 | + |
| 134 | + def test_column_mapping(self): |
| 135 | + task = AudioClassification(audio_column="input_audio", label_column="input_label") |
| 136 | + self.assertDictEqual({"input_audio": "audio", "input_label": "labels"}, task.column_mapping) |
| 137 | + |
| 138 | + def test_from_dict(self): |
| 139 | + input_schema = Features({"audio": Audio()}) |
| 140 | + label_schema = Features({"labels": ClassLabel}) |
| 141 | + template_dict = { |
| 142 | + "audio_column": "input_image", |
| 143 | + "label_column": "input_label", |
| 144 | + } |
| 145 | + task = AudioClassification.from_dict(template_dict) |
| 146 | + self.assertEqual("audio-classification", task.task) |
| 147 | + self.assertEqual(input_schema, task.input_schema) |
| 148 | + self.assertEqual(label_schema, task.label_schema) |
| 149 | + |
| 150 | + def test_align_with_features(self): |
| 151 | + task = AudioClassification(audio_column="input_audio", label_column="input_label") |
| 152 | + self.assertEqual(task.label_schema["labels"], ClassLabel) |
| 153 | + task = task.align_with_features(Features({"input_label": ClassLabel(names=self.labels)})) |
| 154 | + self.assertEqual(task.label_schema["labels"], ClassLabel(names=self.labels)) |
| 155 | + |
| 156 | + |
129 | 157 | class ImageClassificationTest(TestCase): |
130 | 158 | def setUp(self): |
131 | 159 | self.labels = sorted(["pos", "neg"]) |
|
0 commit comments