@@ -1939,18 +1939,26 @@ def test_task_text_classification(self, in_memory):
19391939 "labels" : ClassLabel (names = tuple (labels )),
19401940 }
19411941 )
1942+ task = TextClassification (text_column = "input_text" , label_column = "input_labels" , labels = labels )
19421943 info = DatasetInfo (
19431944 features = features_before_cast ,
1944- task_templates = TextClassification ( text_column = "input_text" , label_column = "input_labels" , labels = labels ) ,
1945+ task_templates = task ,
19451946 )
19461947 data = {"input_text" : ["i love transformers!" ], "input_labels" : [1 ]}
1948+ # Test we can load from task name
19471949 with tempfile .TemporaryDirectory () as tmp_dir , Dataset .from_dict (data , info = info ) as dset :
19481950 with self ._to (in_memory , tmp_dir , dset ) as dset :
19491951 self .assertSetEqual (set (["input_text" , "input_labels" ]), set (dset .column_names ))
19501952 self .assertDictEqual (features_before_cast , dset .features )
19511953 with dset .prepare_for_task (task = "text-classification" ) as dset :
19521954 self .assertSetEqual (set (["labels" , "text" ]), set (dset .column_names ))
19531955 self .assertDictEqual (features_after_cast , dset .features )
1956+ # Test we can load from TaskTemplate
1957+ info .task_templates = None
1958+ with tempfile .TemporaryDirectory () as tmp_dir , Dataset .from_dict (data , info = info ) as dset :
1959+ with dset .prepare_for_task (task = task ) as dset :
1960+ self .assertSetEqual (set (["labels" , "text" ]), set (dset .column_names ))
1961+ self .assertDictEqual (features_after_cast , dset .features )
19541962
19551963 def test_task_question_answering (self , in_memory ):
19561964 features_before_cast = Features (
@@ -1977,17 +1985,16 @@ def test_task_question_answering(self, in_memory):
19771985 ),
19781986 }
19791987 )
1980- info = DatasetInfo (
1981- features = features_before_cast ,
1982- task_templates = QuestionAnswering (
1983- context_column = "input_context" , question_column = "input_question" , answers_column = "input_answers"
1984- ),
1988+ task = QuestionAnswering (
1989+ context_column = "input_context" , question_column = "input_question" , answers_column = "input_answers"
19851990 )
1991+ info = DatasetInfo (features = features_before_cast , task_templates = task )
19861992 data = {
19871993 "input_context" : ["huggingface is going to the moon!" ],
19881994 "input_question" : ["where is huggingface going?" ],
19891995 "input_answers" : [{"text" : ["to the moon!" ], "answer_start" : [2 ]}],
19901996 }
1997+ # Test we can load from task name
19911998 with tempfile .TemporaryDirectory () as tmp_dir , Dataset .from_dict (data , info = info ) as dset :
19921999 with self ._to (in_memory , tmp_dir , dset ) as dset :
19932000 self .assertSetEqual (
@@ -2001,6 +2008,48 @@ def test_task_question_answering(self, in_memory):
20012008 set (dset .flatten ().column_names ),
20022009 )
20032010 self .assertDictEqual (features_after_cast , dset .features )
2011+ # Test we can load from TaskTemplate
2012+ info .task_templates = None
2013+ with tempfile .TemporaryDirectory () as tmp_dir , Dataset .from_dict (data , info = info ) as dset :
2014+ with dset .prepare_for_task (task = task ) as dset :
2015+ self .assertSetEqual (
2016+ set (["context" , "question" , "answers.text" , "answers.answer_start" ]),
2017+ set (dset .flatten ().column_names ),
2018+ )
2019+ self .assertDictEqual (features_after_cast , dset .features )
2020+
2021+ def test_task_with_no_template (self , in_memory ):
2022+ data = {"input_text" : ["i love transformers!" ], "input_labels" : [1 ]}
2023+ with tempfile .TemporaryDirectory () as tmp_dir , Dataset .from_dict (data ) as dset :
2024+ with self ._to (in_memory , tmp_dir , dset ) as dset :
2025+ with self .assertRaises (ValueError ):
2026+ dset .prepare_for_task ("text-classification" )
2027+
2028+ def test_task_with_incompatible_templates (self , in_memory ):
2029+ labels = sorted (["pos" , "neg" ])
2030+ features = Features (
2031+ {
2032+ "input_text" : Value ("string" ),
2033+ "input_labels" : Value ("int32" ),
2034+ }
2035+ )
2036+ task = TextClassification (text_column = "input_text" , label_column = "input_labels" , labels = labels )
2037+ info = DatasetInfo (
2038+ features = features ,
2039+ task_templates = task ,
2040+ )
2041+ data = {"input_text" : ["i love transformers!" ], "input_labels" : [1 ]}
2042+ with tempfile .TemporaryDirectory () as tmp_dir , Dataset .from_dict (data , info = info ) as dset :
2043+ with self ._to (in_memory , tmp_dir , dset ) as dset :
2044+ with self ._to (in_memory , tmp_dir , dset ) as dset :
2045+ with self .assertRaises (ValueError ):
2046+ # Invalid task name
2047+ dset .prepare_for_task ("this-task-does-not-exist" )
2048+ # Duplicate task templates
2049+ dset .info .task_templates = [task , task ]
2050+ dset .prepare_for_task ("text-classification" )
2051+ # Invalid task type
2052+ dset .prepare_for_task (1 )
20042053
20052054
20062055class MiscellaneousDatasetTest (TestCase ):
0 commit comments