diff --git a/wren-ai-service/src/web/v1/services/ask.py b/wren-ai-service/src/web/v1/services/ask.py index a6f2af29bd..a762959758 100644 --- a/wren-ai-service/src/web/v1/services/ask.py +++ b/wren-ai-service/src/web/v1/services/ask.py @@ -92,6 +92,7 @@ class AskResultResponse(BaseModel): ] rephrased_question: Optional[str] = None intent_reasoning: Optional[str] = None + generation_reasoning: Optional[str] = None type: Optional[Literal["MISLEADING_QUERY", "GENERAL", "TEXT_TO_SQL"]] = None response: Optional[List[AskResult]] = None error: Optional[AskError] = None @@ -285,12 +286,21 @@ async def ask( .get("reasoning_plan") ) + self._ask_results[query_id] = AskResultResponse( + status="planning", + type="TEXT_TO_SQL", + rephrased_question=rephrased_question, + intent_reasoning=intent_reasoning, + generation_reasoning=sql_generation_reasoning, + ) + if not self._is_stopped(query_id) and not api_results: self._ask_results[query_id] = AskResultResponse( status="generating", type="TEXT_TO_SQL", rephrased_question=rephrased_question, intent_reasoning=intent_reasoning, + generation_reasoning=sql_generation_reasoning, ) sql_samples = ( @@ -381,6 +391,7 @@ async def ask( response=api_results, rephrased_question=rephrased_question, intent_reasoning=intent_reasoning, + generation_reasoning=sql_generation_reasoning, ) results["ask_result"] = api_results results["metadata"]["type"] = "TEXT_TO_SQL" @@ -396,6 +407,7 @@ async def ask( ), rephrased_question=rephrased_question, intent_reasoning=intent_reasoning, + generation_reasoning=sql_generation_reasoning, ) results["metadata"]["error_type"] = "NO_RELEVANT_SQL" results["metadata"]["type"] = "TEXT_TO_SQL" @@ -411,8 +423,6 @@ async def ask( code="OTHERS", message=str(e), ), - rephrased_question=rephrased_question, - intent_reasoning=intent_reasoning, ) results["metadata"]["error_type"] = "OTHERS" diff --git a/wren-ai-service/tests/pytest/test_usecases.py b/wren-ai-service/tests/pytest/test_usecases.py index 9436d8276b..8ed1c46554 
100644 --- a/wren-ai-service/tests/pytest/test_usecases.py +++ b/wren-ai-service/tests/pytest/test_usecases.py @@ -10,6 +10,7 @@ import aiohttp import orjson import requests +import sqlparse import yaml from demo.utils import ( @@ -144,6 +145,14 @@ async def ask_questions(questions: list[str], url: str, semantics_preperation_id return await asyncio.gather(*tasks) +def str_presenter(dumper, data): + """configures yaml for dumping multiline strings + Ref: https://stackoverflow.com/questions/8640959/how-can-i-control-what-scalar-form-pyyaml-uses-for-my-data""" + if len(data.splitlines()) > 1: # check for multiline string + return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|") + return dumper.represent_scalar("tag:yaml.org,2002:str", data) + + if __name__ == "__main__": usecase_to_dataset_type = { "hubspot": "bigquery", @@ -195,6 +204,17 @@ async def ask_questions(questions: list[str], url: str, semantics_preperation_id } # count the number of results that are failed for question, result in zip(data["questions"], results): + if ( + result.get("status") == "finished" + and not result.get("error") + and result.get("response", []) + ): + result["response"][0]["sql"] = sqlparse.format( + result["response"][0]["sql"], + reindent=True, + keyword_case="upper", + ) + final_results[usecase]["results"].append( { "question": question, @@ -212,7 +232,11 @@ async def ask_questions(questions: list[str], url: str, semantics_preperation_id os.makedirs("outputs") with open( - f"outputs/final_results_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.json", + f"outputs/usecases_final_results_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.yaml", "w", ) as f: - json.dump(final_results, f, indent=2) + yaml.add_representer(str, str_presenter) + yaml.representer.SafeRepresenter.add_representer( + str, str_presenter + ) # to use with safe_dump + yaml.safe_dump(final_results, f, sort_keys=False, allow_unicode=True)