diff --git a/dynalab/tasks/qa.py b/dynalab/tasks/qa.py index d0b1b0a..0f05a88 100644 --- a/dynalab/tasks/qa.py +++ b/dynalab/tasks/qa.py @@ -25,6 +25,18 @@ "context": " ".join([str(x) + "_" for x in range(513)]), "question": "Can you handle this length?", }, + { + "uid": str(uuid.uuid4()), + "context": "The sum of 3 and 2 is 5", + "question": "What is the total?", + "answer": "5", + }, + { + "uid": str(uuid.uuid4()), + "context": "The answer to this question is that everyone is doing great today", + "question": "What is the answer?", + "answer": "great today", + }, ] @@ -49,12 +61,16 @@ def verify_response(self, response, data): assert "answer" in response and response["answer"] in data["context"] assert response["signed"] == self.generate_response_signature(response, data) Nk = 3 + if "eval_exact" in response: + Nk += 1 + if "eval_f1" in response: + Nk += 1 if "conf" in response: assert ( response["conf"] >= 0 and response["conf"] <= 1 ), "Confidence score should be between 0 and 1" Nk += 1 - assert Nk == len(response), f"response should not contain other extra keys" + assert Nk >= len(response), f"response should not contain other extra keys" def parse_signature_input(self, response, data): task = "qa" diff --git a/examples/electra_style_qa/handler.py b/examples/electra_style_qa/handler.py index 3c743e2..44258b6 100644 --- a/examples/electra_style_qa/handler.py +++ b/examples/electra_style_qa/handler.py @@ -4,6 +4,7 @@ import torch from transformers import (AutoConfig, AutoTokenizer, AutoModelForQuestionAnswering, QuestionAnsweringPipeline) +from transformers.data.metrics.squad_metrics import compute_exact, compute_f1 from dynalab.handler.base_handler import BaseDynaHandler from dynalab.tasks.qa import TaskIO @@ -32,8 +33,8 @@ def preprocess(self, data): """ example = self._read_data(data) return { - 'context': example['context'], - 'question': example['question'] + 'context': example['context'].strip(), + 'question': example['question'].strip() } def inference(self, input_data): @@ -56,6 +57,12 @@ def postprocess(self, inference_output, data): response["id"] = example["uid"] response["answer"] = answer response["conf"] = conf + + if "answer" in example: + human_answer = str(example["answer"]).strip() + response["eval_f1"] = compute_f1(human_answer, answer) + response["eval_exact"] = compute_exact(human_answer, answer) + response = self.taskIO.sign_response(response, example) return [response]