Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions wren-ai-service/src/web/v1/services/ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ class AskResultResponse(BaseModel):
]
rephrased_question: Optional[str] = None
intent_reasoning: Optional[str] = None
generation_reasoning: Optional[str] = None
type: Optional[Literal["MISLEADING_QUERY", "GENERAL", "TEXT_TO_SQL"]] = None
response: Optional[List[AskResult]] = None
error: Optional[AskError] = None
Expand Down Expand Up @@ -285,12 +286,21 @@ async def ask(
.get("reasoning_plan")
)

self._ask_results[query_id] = AskResultResponse(
status="planning",
type="TEXT_TO_SQL",
rephrased_question=rephrased_question,
intent_reasoning=intent_reasoning,
generation_reasoning=sql_generation_reasoning,
)

if not self._is_stopped(query_id) and not api_results:
self._ask_results[query_id] = AskResultResponse(
status="generating",
type="TEXT_TO_SQL",
rephrased_question=rephrased_question,
intent_reasoning=intent_reasoning,
generation_reasoning=sql_generation_reasoning,
)

sql_samples = (
Expand Down Expand Up @@ -358,6 +368,7 @@ async def ask(
contexts=documents,
invalid_generation_results=failed_dry_run_results,
project_id=ask_request.project_id,
generation_reasoning=sql_generation_reasoning,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codebase verification

Pipeline incompatibility detected: generation_reasoning parameter is not supported

The SQLCorrection pipeline's run method does not accept the generation_reasoning parameter that's being passed. The pipeline needs to be updated to handle this parameter or the parameter should be removed from the caller.

  • wren-ai-service/src/pipelines/generation/sql_correction.py: Update the run method signature to include generation_reasoning parameter if the reasoning context is needed for SQL correction
  • wren-ai-service/src/web/v1/services/ask.py: Remove the parameter if it's not required for SQL correction
🔗 Analysis chain

Verify sql_correction pipeline compatibility.

The generation_reasoning parameter has been added to the SQL correction pipeline. Ensure that the pipeline is updated to handle this new parameter.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Check if the sql_correction pipeline accepts and uses the generation_reasoning parameter
ast-grep --pattern 'class $_Pipeline:
  $$$
  def run(self, $$$, generation_reasoning: $_Type, $$$):
    $$$
'

Length of output: 112


Script:

#!/bin/bash
# Search for SQL correction pipeline class and related files
rg -l "sql.*correction" --type py

# Then look for the specific implementation
rg "class.*(?i)sql.*correction" --type py

Length of output: 788


Script:

#!/bin/bash
# Check the SQLCorrection class implementation
cat wren-ai-service/src/pipelines/generation/sql_correction.py

Length of output: 4232

)

if valid_generation_results := sql_correction_results[
Expand All @@ -381,6 +392,7 @@ async def ask(
response=api_results,
rephrased_question=rephrased_question,
intent_reasoning=intent_reasoning,
generation_reasoning=sql_generation_reasoning,
)
results["ask_result"] = api_results
results["metadata"]["type"] = "TEXT_TO_SQL"
Expand All @@ -396,6 +408,7 @@ async def ask(
),
rephrased_question=rephrased_question,
intent_reasoning=intent_reasoning,
generation_reasoning=sql_generation_reasoning,
)
results["metadata"]["error_type"] = "NO_RELEVANT_SQL"
results["metadata"]["type"] = "TEXT_TO_SQL"
Expand All @@ -411,8 +424,6 @@ async def ask(
code="OTHERS",
message=str(e),
),
rephrased_question=rephrased_question,
intent_reasoning=intent_reasoning,
)

results["metadata"]["error_type"] = "OTHERS"
Expand Down
28 changes: 26 additions & 2 deletions wren-ai-service/tests/pytest/test_usecases.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import aiohttp
import orjson
import requests
import sqlparse
import yaml

from demo.utils import (
Expand Down Expand Up @@ -144,6 +145,14 @@ async def ask_questions(questions: list[str], url: str, semantics_preperation_id
return await asyncio.gather(*tasks)


def str_presenter(dumper, data):
    """PyYAML representer that emits multiline strings in literal block style.

    Strings containing more than one line are dumped with the ``|`` (literal)
    scalar style for readability; everything else uses the default style.
    Ref: https://stackoverflow.com/questions/8640959/how-can-i-control-what-scalar-form-pyyaml-uses-for-my-data
    """
    tag = "tag:yaml.org,2002:str"
    # splitlines() ignores a single trailing newline, so "foo\n" still
    # counts as single-line here — only genuinely multiline text gets "|".
    is_multiline = len(data.splitlines()) > 1
    if is_multiline:
        return dumper.represent_scalar(tag, data, style="|")
    return dumper.represent_scalar(tag, data)


if __name__ == "__main__":
usecase_to_dataset_type = {
"hubspot": "bigquery",
Expand Down Expand Up @@ -195,6 +204,17 @@ async def ask_questions(questions: list[str], url: str, semantics_preperation_id
}
# count the number of results that are failed
for question, result in zip(data["questions"], results):
if (
result.get("status") == "finished"
and not result.get("error")
and result.get("response", [])
):
result["response"][0]["sql"] = sqlparse.format(
result["response"][0]["sql"],
reindent=True,
keyword_case="upper",
)

final_results[usecase]["results"].append(
{
"question": question,
Expand All @@ -212,7 +232,11 @@ async def ask_questions(questions: list[str], url: str, semantics_preperation_id
os.makedirs("outputs")

with open(
f"outputs/final_results_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.json",
f"outputs/usecases_final_results_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.yaml",
"w",
) as f:
json.dump(final_results, f, indent=2)
yaml.add_representer(str, str_presenter)
yaml.representer.SafeRepresenter.add_representer(
str, str_presenter
) # to use with safe_dump
yaml.safe_dump(final_results, f, sort_keys=False, allow_unicode=True)
Loading