import re

# Parenthesised or bracketed notes that contain an explicit instruction
# keyword.  Keywords are matched as substrings on purpose so that inflected
# forms ("acceptable", "prompts", "pronounced") are still recognised.
_EXPLICIT_PROMPT_PATTERNS = (
    re.compile(r"(.+)\((.*(accept|prompt|pronounce).*)\)", flags=re.IGNORECASE),
    re.compile(r"(.+)\[(.*(accept|prompt|pronounce).*)\]", flags=re.IGNORECASE),
)

# Parenthesised or bracketed notes that list an alternative with the word
# "or".  The word boundaries keep "or" from matching inside other words
# (e.g. "Norse", "author").
_ALTERNATE_ANSWER_PATTERNS = (
    re.compile(r"(.+)\((.*\bor\b.*)\)", flags=re.IGNORECASE),
    re.compile(r"(.+)\[(.*\bor\b.*)\]", flags=re.IGNORECASE),
)


def extract_prompt(ans):
    """Return the parenthesised/bracketed note embedded in an answer string.

    The note is the text inside ``(...)`` or ``[...]`` that follows at least
    one leading character and either contains "accept", "prompt" or
    "pronounce" (case-insensitive), or contains the standalone word "or".
    Explicit keyword notes take priority over "or" notes, and parentheses
    take priority over square brackets.

    Args:
        ans: The raw answer string.  ``None`` and ``""`` are tolerated.

    Returns:
        The note with surrounding whitespace stripped, or ``""`` when the
        answer carries no such note.
    """
    if not ans:
        return ""

    # Every pattern needs an opening delimiter, so skip the regex work for
    # plain answers.
    if "(" not in ans and "[" not in ans:
        return ""

    # Try all patterns in priority order.  The "or" patterns are attempted
    # even when a keyword occurs elsewhere in the answer (for example in the
    # answer text itself), so such answers do not lose their note.
    for pattern in _EXPLICIT_PROMPT_PATTERNS + _ALTERNATE_ANSWER_PATTERNS:
        found = pattern.match(ans)
        if found is not None:
            return found.group(2).strip()

    return ""


def add_answer_prompts_(questions):
    """Set ``q["answer_prompt"]`` on every question dict, in place.

    Args:
        questions: Iterable of dicts, each with an ``"answer"`` key.

    Returns:
        None.  The trailing underscore marks the in-place mutation.
    """
    for q in questions:
        q["answer_prompt"] = extract_prompt(q["answer"])
questions_to_sqlite(qanta_questions, db_path): conn = sqlite3.connect(db_path)