Skip to content

Commit 1274aa5

Browse files
committed
Add tests for invalid templates
1 parent 56635e7 commit 1274aa5

File tree

1 file changed

+55
-6
lines changed

1 file changed

+55
-6
lines changed

tests/test_arrow_dataset.py

Lines changed: 55 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1939,18 +1939,26 @@ def test_task_text_classification(self, in_memory):
19391939
"labels": ClassLabel(names=tuple(labels)),
19401940
}
19411941
)
1942+
task = TextClassification(text_column="input_text", label_column="input_labels", labels=labels)
19421943
info = DatasetInfo(
19431944
features=features_before_cast,
1944-
task_templates=TextClassification(text_column="input_text", label_column="input_labels", labels=labels),
1945+
task_templates=task,
19451946
)
19461947
data = {"input_text": ["i love transformers!"], "input_labels": [1]}
1948+
# Test we can load from task name
19471949
with tempfile.TemporaryDirectory() as tmp_dir, Dataset.from_dict(data, info=info) as dset:
19481950
with self._to(in_memory, tmp_dir, dset) as dset:
19491951
self.assertSetEqual(set(["input_text", "input_labels"]), set(dset.column_names))
19501952
self.assertDictEqual(features_before_cast, dset.features)
19511953
with dset.prepare_for_task(task="text-classification") as dset:
19521954
self.assertSetEqual(set(["labels", "text"]), set(dset.column_names))
19531955
self.assertDictEqual(features_after_cast, dset.features)
1956+
# Test we can load from TaskTemplate
1957+
info.task_templates = None
1958+
with tempfile.TemporaryDirectory() as tmp_dir, Dataset.from_dict(data, info=info) as dset:
1959+
with dset.prepare_for_task(task=task) as dset:
1960+
self.assertSetEqual(set(["labels", "text"]), set(dset.column_names))
1961+
self.assertDictEqual(features_after_cast, dset.features)
19541962

19551963
def test_task_question_answering(self, in_memory):
19561964
features_before_cast = Features(
@@ -1977,17 +1985,16 @@ def test_task_question_answering(self, in_memory):
19771985
),
19781986
}
19791987
)
1980-
info = DatasetInfo(
1981-
features=features_before_cast,
1982-
task_templates=QuestionAnswering(
1983-
context_column="input_context", question_column="input_question", answers_column="input_answers"
1984-
),
1988+
task = QuestionAnswering(
1989+
context_column="input_context", question_column="input_question", answers_column="input_answers"
19851990
)
1991+
info = DatasetInfo(features=features_before_cast, task_templates=task)
19861992
data = {
19871993
"input_context": ["huggingface is going to the moon!"],
19881994
"input_question": ["where is huggingface going?"],
19891995
"input_answers": [{"text": ["to the moon!"], "answer_start": [2]}],
19901996
}
1997+
# Test we can load from task name
19911998
with tempfile.TemporaryDirectory() as tmp_dir, Dataset.from_dict(data, info=info) as dset:
19921999
with self._to(in_memory, tmp_dir, dset) as dset:
19932000
self.assertSetEqual(
@@ -2001,6 +2008,48 @@ def test_task_question_answering(self, in_memory):
20012008
set(dset.flatten().column_names),
20022009
)
20032010
self.assertDictEqual(features_after_cast, dset.features)
2011+
# Test we can load from TaskTemplate
2012+
info.task_templates = None
2013+
with tempfile.TemporaryDirectory() as tmp_dir, Dataset.from_dict(data, info=info) as dset:
2014+
with dset.prepare_for_task(task=task) as dset:
2015+
self.assertSetEqual(
2016+
set(["context", "question", "answers.text", "answers.answer_start"]),
2017+
set(dset.flatten().column_names),
2018+
)
2019+
self.assertDictEqual(features_after_cast, dset.features)
2020+
2021+
def test_task_with_no_template(self, in_memory):
    """Preparing a dataset for a task must fail when it was built without any task template.

    The dataset is created from a plain dict (no ``DatasetInfo`` is passed),
    so requesting the ``"text-classification"`` task is expected to raise
    ``ValueError``.
    """
    columns = {"input_text": ["i love transformers!"], "input_labels": [1]}
    with tempfile.TemporaryDirectory() as tmp_dir, Dataset.from_dict(columns) as dset:
        # `_to` is entered first, so only the `prepare_for_task` call is
        # guarded by `assertRaises`.
        with self._to(in_memory, tmp_dir, dset) as dset, self.assertRaises(ValueError):
            dset.prepare_for_task("text-classification")
2027+
2028+
def test_task_with_incompatible_templates(self, in_memory):
    """Check that ``prepare_for_task`` rejects every kind of invalid task request.

    Three independent failure scenarios are exercised, each expected to raise
    ``ValueError``:

    1. a task name that matches no template,
    2. a task argument of an unsupported type (an ``int``),
    3. a dataset whose info lists the same task template twice.

    Each scenario gets its own ``assertRaises`` context. With a single shared
    context, the first raising call ends the block and the remaining calls are
    never executed, so they would not be tested at all.
    """
    labels = sorted(["pos", "neg"])
    features = Features(
        {
            "input_text": Value("string"),
            "input_labels": Value("int32"),
        }
    )
    task = TextClassification(text_column="input_text", label_column="input_labels", labels=labels)
    info = DatasetInfo(
        features=features,
        task_templates=task,
    )
    data = {"input_text": ["i love transformers!"], "input_labels": [1]}
    with tempfile.TemporaryDirectory() as tmp_dir, Dataset.from_dict(data, info=info) as dset:
        # Wrap once, like the sibling task tests do.
        with self._to(in_memory, tmp_dir, dset) as dset:
            # Invalid task name
            with self.assertRaises(ValueError):
                dset.prepare_for_task("this-task-does-not-exist")
            # Invalid task type; checked before the templates are mutated so
            # that the failure can only come from the argument type.
            with self.assertRaises(ValueError):
                dset.prepare_for_task(1)
            # Duplicate task templates
            dset.info.task_templates = [task, task]
            with self.assertRaises(ValueError):
                dset.prepare_for_task("text-classification")
20042053

20052054

20062055
class MiscellaneousDatasetTest(TestCase):

0 commit comments

Comments
 (0)