Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions wren-ai-service/src/pipelines/generation/sql_diagnosis.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,14 @@
1. First, think hard about the error message, and analyze the invalid SQL query to figure out the root cause and which part is incorrect.
2. Then, map the incorrect part of the invalid SQL query to the corresponding part of the original SQL query.
3. Then, return the reasoning behind the diagnosis.(You should give me the part of the original SQL query that is incorrect and the reason why it is incorrect)
4. Also, return a boolean value to indicate whether the issue can be corrected with the new SQL query.
5. Reasoning should be in the language same as the language user provided in the INPUTS section.
6. Reasoning should be concise and to the point and within 50 words.
4. Reasoning should be in the language same as the language user provided in the INPUTS section.
5. Reasoning should be concise and to the point and within 50 words.

### FINAL ANSWER FORMAT ###
The final answer must be in JSON format:

{
"reasoning": <REASONING_STRING>,
"can_be_corrected": <CAN_BE_CORRECTED_BOOLEAN>
"reasoning": <REASONING_STRING>
}
"""

Expand Down Expand Up @@ -102,7 +100,6 @@ async def post_process(

class SqlDiagnosisResult(BaseModel):
reasoning: str
can_be_corrected: bool


SQL_DIAGNOSIS_MODEL_KWARGS = {
Expand Down
5 changes: 0 additions & 5 deletions wren-ai-service/src/web/v1/services/ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,11 +531,6 @@ async def ask(
sql_diagnosis_reasoning = sql_diagnosis_results[
"post_process"
].get("reasoning")
can_be_corrected = sql_diagnosis_results[
"post_process"
].get("can_be_corrected")
if not can_be_corrected:
break

sql_correction_results = await self._pipelines[
"sql_correction"
Expand Down
65 changes: 30 additions & 35 deletions wren-ai-service/src/web/v1/services/ask_feedback.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,6 @@ async def ask_feedback(
trace_id=trace_id,
)

can_be_corrected = True
if allow_sql_diagnosis:
sql_diagnosis_results = await self._pipelines[
"sql_diagnosis"
Expand All @@ -215,42 +214,38 @@ async def ask_feedback(
sql_diagnosis_reasoning = sql_diagnosis_results[
"post_process"
].get("reasoning")
can_be_corrected = sql_diagnosis_results[
"post_process"
].get("can_be_corrected")

if can_be_corrected:
sql_correction_results = await self._pipelines[
"sql_correction"
].run(
contexts=table_ddls,
instructions=instructions,
invalid_generation_result={
"sql": original_sql,
"error": sql_diagnosis_reasoning
if allow_sql_diagnosis
else error_message,
},
project_id=ask_feedback_request.project_id,
sql_functions=sql_functions,
)
sql_correction_results = await self._pipelines[
"sql_correction"
].run(
contexts=table_ddls,
instructions=instructions,
invalid_generation_result={
"sql": original_sql,
"error": sql_diagnosis_reasoning
if allow_sql_diagnosis
else error_message,
},
project_id=ask_feedback_request.project_id,
sql_functions=sql_functions,
)

if valid_generation_result := sql_correction_results[
"post_process"
]["valid_generation_result"]:
api_results = [
AskResult(
**{
"sql": valid_generation_result.get("sql"),
"type": "llm",
}
)
]
elif failed_dry_run_result := sql_correction_results[
"post_process"
]["invalid_generation_result"]:
invalid_sql = failed_dry_run_result["sql"]
error_message = failed_dry_run_result["error"]
if valid_generation_result := sql_correction_results[
"post_process"
]["valid_generation_result"]:
api_results = [
AskResult(
**{
"sql": valid_generation_result.get("sql"),
"type": "llm",
}
)
]
elif failed_dry_run_result := sql_correction_results[
"post_process"
]["invalid_generation_result"]:
invalid_sql = failed_dry_run_result["sql"]
error_message = failed_dry_run_result["error"]
else:
invalid_sql = failed_dry_run_result["sql"]
error_message = failed_dry_run_result["error"]
Expand Down
Loading