@@ -23,6 +23,17 @@ class QAPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
2323 "question" : "In what field is HuggingFace working ?" ,
2424 "context" : "HuggingFace is a startup based in New-York founded in Paris which is trying to solve NLP." ,
2525 },
26+ {
27+ "question" : ["In what field is HuggingFace working ?" , "In what field is HuggingFace working ?" ],
28+ "context" : "HuggingFace is a startup based in New-York founded in Paris which is trying to solve NLP." ,
29+ },
30+ {
31+ "question" : ["In what field is HuggingFace working ?" , "In what field is HuggingFace working ?" ],
32+ "context" : [
33+ "HuggingFace is a startup based in New-York founded in Paris which is trying to solve NLP." ,
34+ "HuggingFace is a startup based in New-York founded in Paris which is trying to solve NLP." ,
35+ ],
36+ },
2637 ]
2738
2839 def _test_pipeline (self , nlp : Pipeline ):
@@ -80,6 +91,11 @@ def test_argument_handler(self):
8091 self .assertEqual (len (normalized ), 1 )
8192 self .assertEqual ({type (el ) for el in normalized }, {SquadExample })
8293
94+ normalized = qa (question = [Q , Q ], context = C )
95+ self .assertEqual (type (normalized ), list )
96+ self .assertEqual (len (normalized ), 2 )
97+ self .assertEqual ({type (el ) for el in normalized }, {SquadExample })
98+
8399 normalized = qa ({"question" : Q , "context" : C })
84100 self .assertEqual (type (normalized ), list )
85101 self .assertEqual (len (normalized ), 1 )
@@ -159,6 +175,26 @@ def test_argument_handler_error_handling(self):
159175 with self .assertRaises (ValueError ):
160176 qa ([{"question" : Q , "context" : C }, {"question" : Q , "context" : "" }])
161177
178+ with self .assertRaises (ValueError ):
179+ qa (question = {"This" : "Is weird" }, context = "This is a context" )
180+
181+ with self .assertRaises (ValueError ):
182+ qa (question = [Q , Q ], context = [C , C , C ])
183+
184+ with self .assertRaises (ValueError ):
185+ qa (question = [Q , Q , Q ], context = [C , C ])
186+
187+ def test_argument_handler_old_format (self ):
188+ qa = QuestionAnsweringArgumentHandler ()
189+
190+ Q = "Where was HuggingFace founded ?"
191+ C = "HuggingFace was founded in Paris"
192+ # Backward compatibility for this
193+ normalized = qa (question = [Q , Q ], context = [C , C ])
194+ self .assertEqual (type (normalized ), list )
195+ self .assertEqual (len (normalized ), 2 )
196+ self .assertEqual ({type (el ) for el in normalized }, {SquadExample })
197+
162198 def test_argument_handler_error_handling_odd (self ):
163199 qa = QuestionAnsweringArgumentHandler ()
164200 with self .assertRaises (ValueError ):
0 commit comments