Skip to content

Commit 138f45c

Browse files
LysandreJikNarsil
andauthored
Fix QA argument handler (#8765)
* Fix QA argument handler * Attempt to get a better fix for QA (#8768) Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
1 parent 4821ea5 commit 138f45c

2 files changed

Lines changed: 47 additions & 1 deletion

File tree

src/transformers/pipelines.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1624,7 +1624,17 @@ def __call__(self, *args, **kwargs):
16241624
elif "data" in kwargs:
16251625
inputs = kwargs["data"]
16261626
elif "question" in kwargs and "context" in kwargs:
1627-
inputs = [{"question": kwargs["question"], "context": kwargs["context"]}]
1627+
if isinstance(kwargs["question"], list) and isinstance(kwargs["context"], str):
1628+
inputs = [{"question": Q, "context": kwargs["context"]} for Q in kwargs["question"]]
1629+
elif isinstance(kwargs["question"], list) and isinstance(kwargs["context"], list):
1630+
if len(kwargs["question"]) != len(kwargs["context"]):
1631+
raise ValueError("Questions and contexts don't have the same lengths")
1632+
1633+
inputs = [{"question": Q, "context": C} for Q, C in zip(kwargs["question"], kwargs["context"])]
1634+
elif isinstance(kwargs["question"], str) and isinstance(kwargs["context"], str):
1635+
inputs = [{"question": kwargs["question"], "context": kwargs["context"]}]
1636+
else:
1637+
raise ValueError("Arguments can't be understood")
16281638
else:
16291639
raise ValueError("Unknown arguments {}".format(kwargs))
16301640

tests/test_pipelines_question_answering.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,17 @@ class QAPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
2323
"question": "In what field is HuggingFace working ?",
2424
"context": "HuggingFace is a startup based in New-York founded in Paris which is trying to solve NLP.",
2525
},
26+
{
27+
"question": ["In what field is HuggingFace working ?", "In what field is HuggingFace working ?"],
28+
"context": "HuggingFace is a startup based in New-York founded in Paris which is trying to solve NLP.",
29+
},
30+
{
31+
"question": ["In what field is HuggingFace working ?", "In what field is HuggingFace working ?"],
32+
"context": [
33+
"HuggingFace is a startup based in New-York founded in Paris which is trying to solve NLP.",
34+
"HuggingFace is a startup based in New-York founded in Paris which is trying to solve NLP.",
35+
],
36+
},
2637
]
2738

2839
def _test_pipeline(self, nlp: Pipeline):
@@ -80,6 +91,11 @@ def test_argument_handler(self):
8091
self.assertEqual(len(normalized), 1)
8192
self.assertEqual({type(el) for el in normalized}, {SquadExample})
8293

94+
normalized = qa(question=[Q, Q], context=C)
95+
self.assertEqual(type(normalized), list)
96+
self.assertEqual(len(normalized), 2)
97+
self.assertEqual({type(el) for el in normalized}, {SquadExample})
98+
8399
normalized = qa({"question": Q, "context": C})
84100
self.assertEqual(type(normalized), list)
85101
self.assertEqual(len(normalized), 1)
@@ -159,6 +175,26 @@ def test_argument_handler_error_handling(self):
159175
with self.assertRaises(ValueError):
160176
qa([{"question": Q, "context": C}, {"question": Q, "context": ""}])
161177

178+
with self.assertRaises(ValueError):
179+
qa(question={"This": "Is weird"}, context="This is a context")
180+
181+
with self.assertRaises(ValueError):
182+
qa(question=[Q, Q], context=[C, C, C])
183+
184+
with self.assertRaises(ValueError):
185+
qa(question=[Q, Q, Q], context=[C, C])
186+
187+
def test_argument_handler_old_format(self):
188+
qa = QuestionAnsweringArgumentHandler()
189+
190+
Q = "Where was HuggingFace founded ?"
191+
C = "HuggingFace was founded in Paris"
192+
# Backward compatibility for this
193+
normalized = qa(question=[Q, Q], context=[C, C])
194+
self.assertEqual(type(normalized), list)
195+
self.assertEqual(len(normalized), 2)
196+
self.assertEqual({type(el) for el in normalized}, {SquadExample})
197+
162198
def test_argument_handler_error_handling_odd(self):
163199
qa = QuestionAnsweringArgumentHandler()
164200
with self.assertRaises(ValueError):

0 commit comments

Comments
 (0)