From 12eba7857789d4b1b32becfd03f41d71ca0addf0 Mon Sep 17 00:00:00 2001 From: ChihYu Yeh Date: Thu, 23 Jan 2025 22:04:14 +0800 Subject: [PATCH 1/4] wip --- wren-ai-service/src/web/v1/services/ask.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/wren-ai-service/src/web/v1/services/ask.py b/wren-ai-service/src/web/v1/services/ask.py index a6f2af29bd..03aacfe94b 100644 --- a/wren-ai-service/src/web/v1/services/ask.py +++ b/wren-ai-service/src/web/v1/services/ask.py @@ -92,6 +92,7 @@ class AskResultResponse(BaseModel): ] rephrased_question: Optional[str] = None intent_reasoning: Optional[str] = None + generation_reasoning: Optional[str] = None type: Optional[Literal["MISLEADING_QUERY", "GENERAL", "TEXT_TO_SQL"]] = None response: Optional[List[AskResult]] = None error: Optional[AskError] = None @@ -285,12 +286,21 @@ async def ask( .get("reasoning_plan") ) + self._ask_results[query_id] = AskResultResponse( + status="planning", + type="TEXT_TO_SQL", + rephrased_question=rephrased_question, + intent_reasoning=intent_reasoning, + generation_reasoning=sql_generation_reasoning, + ) + if not self._is_stopped(query_id) and not api_results: self._ask_results[query_id] = AskResultResponse( status="generating", type="TEXT_TO_SQL", rephrased_question=rephrased_question, intent_reasoning=intent_reasoning, + generation_reasoning=sql_generation_reasoning, ) sql_samples = ( @@ -358,6 +368,7 @@ async def ask( contexts=documents, invalid_generation_results=failed_dry_run_results, project_id=ask_request.project_id, + generation_reasoning=sql_generation_reasoning, ) if valid_generation_results := sql_correction_results[ @@ -381,6 +392,7 @@ async def ask( response=api_results, rephrased_question=rephrased_question, intent_reasoning=intent_reasoning, + generation_reasoning=sql_generation_reasoning, ) results["ask_result"] = api_results results["metadata"]["type"] = "TEXT_TO_SQL" @@ -396,6 +408,7 @@ async def ask( ), rephrased_question=rephrased_question, intent_reasoning=intent_reasoning, + generation_reasoning=sql_generation_reasoning, ) results["metadata"]["error_type"] = "NO_RELEVANT_SQL" results["metadata"]["type"] = "TEXT_TO_SQL" From e01c8b54f040794d1bb5f3f83996142a57c27bd9 Mon Sep 17 00:00:00 2001 From: ChihYu Yeh Date: Fri, 24 Jan 2025 07:24:07 +0800 Subject: [PATCH 2/4] output generation reasoning and refine usecase test output --- wren-ai-service/src/web/v1/services/ask.py | 2 -- wren-ai-service/tests/pytest/test_usecases.py | 4 ++-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/wren-ai-service/src/web/v1/services/ask.py b/wren-ai-service/src/web/v1/services/ask.py index 03aacfe94b..03632c85d2 100644 --- a/wren-ai-service/src/web/v1/services/ask.py +++ b/wren-ai-service/src/web/v1/services/ask.py @@ -424,8 +424,6 @@ async def ask( code="OTHERS", message=str(e), ), - rephrased_question=rephrased_question, - intent_reasoning=intent_reasoning, ) results["metadata"]["error_type"] = "OTHERS" diff --git a/wren-ai-service/tests/pytest/test_usecases.py b/wren-ai-service/tests/pytest/test_usecases.py index 9436d8276b..34b6d336be 100644 --- a/wren-ai-service/tests/pytest/test_usecases.py +++ b/wren-ai-service/tests/pytest/test_usecases.py @@ -212,7 +212,7 @@ async def ask_questions(questions: list[str], url: str, semantics_preperation_id os.makedirs("outputs") with open( - f"outputs/final_results_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.json", + f"outputs/usecases_final_results_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.yaml", "w", ) as f: - json.dump(final_results, f, indent=2) + yaml.safe_dump(final_results, f, sort_keys=False, allow_unicode=True) From dd684843fbbd016324d45e05731dd6f164274b58 Mon Sep 17 00:00:00 2001 From: ChihYu Yeh Date: Fri, 24 Jan 2025 07:47:25 +0800 Subject: [PATCH 3/4] allow multiline string output --- wren-ai-service/tests/pytest/test_usecases.py | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/wren-ai-service/tests/pytest/test_usecases.py b/wren-ai-service/tests/pytest/test_usecases.py index 34b6d336be..8ed1c46554 100644 --- a/wren-ai-service/tests/pytest/test_usecases.py +++ b/wren-ai-service/tests/pytest/test_usecases.py @@ -10,6 +10,7 @@ import aiohttp import orjson import requests +import sqlparse import yaml from demo.utils import ( @@ -144,6 +145,14 @@ async def ask_questions(questions: list[str], url: str, semantics_preperation_id return await asyncio.gather(*tasks) +def str_presenter(dumper, data): + """configures yaml for dumping multiline strings + Ref: https://stackoverflow.com/questions/8640959/how-can-i-control-what-scalar-form-pyyaml-uses-for-my-data""" + if len(data.splitlines()) > 1: # check for multiline string + return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|") + return dumper.represent_scalar("tag:yaml.org,2002:str", data) + + if __name__ == "__main__": usecase_to_dataset_type = { "hubspot": "bigquery", @@ -195,6 +204,17 @@ async def ask_questions(questions: list[str], url: str, semantics_preperation_id } # count the number of results that are failed for question, result in zip(data["questions"], results): + if ( + result.get("status") == "finished" + and not result.get("error") + and result.get("response", []) + ): + result["response"][0]["sql"] = sqlparse.format( + result["response"][0]["sql"], + reindent=True, + keyword_case="upper", + ) + final_results[usecase]["results"].append( { "question": question, @@ -215,4 +235,8 @@ async def ask_questions(questions: list[str], url: str, semantics_preperation_id f"outputs/usecases_final_results_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.yaml", "w", ) as f: + yaml.add_representer(str, str_presenter) + yaml.representer.SafeRepresenter.add_representer( + str, str_presenter + ) # to use with safe_dum yaml.safe_dump(final_results, f, sort_keys=False, allow_unicode=True) From a50396da4735049feb877896c4f8f6943a207b60 Mon Sep 17 00:00:00 2001 From: ChihYu Yeh Date: Fri, 24 Jan 2025 08:13:25 +0800 Subject: [PATCH 4/4] fix bug --- wren-ai-service/src/web/v1/services/ask.py | 1 - 1 file changed, 1 deletion(-) diff --git a/wren-ai-service/src/web/v1/services/ask.py b/wren-ai-service/src/web/v1/services/ask.py index 03632c85d2..a762959758 100644 --- a/wren-ai-service/src/web/v1/services/ask.py +++ b/wren-ai-service/src/web/v1/services/ask.py @@ -368,7 +368,6 @@ async def ask( contexts=documents, invalid_generation_results=failed_dry_run_results, project_id=ask_request.project_id, - generation_reasoning=sql_generation_reasoning, ) if valid_generation_results := sql_correction_results[