Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions wren-ai-service/src/pipelines/generation/utils/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,15 +108,15 @@ async def _classify_generation_result(
"correlation_id": "",
}
elif use_dry_run:
status, _, addition = await self._engine.execute_sql(
has_data, _, addition = await self._engine.execute_sql(
quoted_sql,
session,
project_id=project_id,
limit=1,
dry_run=True,
)

if status:
if has_data:
valid_generation_result = {
"sql": quoted_sql,
"correlation_id": addition.get("correlation_id", ""),
Expand All @@ -132,15 +132,15 @@ async def _classify_generation_result(
"correlation_id": addition.get("correlation_id", ""),
}
else:
status, _, addition = await self._engine.execute_sql(
has_data, _, addition = await self._engine.execute_sql(
quoted_sql,
session,
project_id=project_id,
limit=1,
dry_run=False,
)

if status:
if has_data:
valid_generation_result = {
"sql": quoted_sql,
"correlation_id": addition.get("correlation_id", ""),
Expand Down
4 changes: 3 additions & 1 deletion wren-ai-service/src/pipelines/retrieval/sql_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,16 @@ async def run(
limit: int = 500,
):
async with aiohttp.ClientSession() as session:
_, data, _ = await self._engine.execute_sql(
_, data, addition = await self._engine.execute_sql(
sql,
session,
project_id=project_id,
dry_run=False,
limit=limit,
)

if addition.get("error_message"):
return {"results": data, "error_message": addition.get("error_message")}
return {"results": data}


Expand Down
20 changes: 18 additions & 2 deletions wren-ai-service/src/web/v1/services/chart.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,15 +105,31 @@ async def chart(
trace_id=trace_id,
)

sql_data = (
execute_sql_result = (
await self._pipelines["sql_executor"].run(
sql=chart_request.sql,
project_id=chart_request.project_id,
)
)["execute_sql"]["results"]
)["execute_sql"]

sql_data = execute_sql_result["results"]
error_message = execute_sql_result.get("error_message", None)
else:
sql_data = chart_request.data

if error_message:
self._chart_results[query_id] = ChartResultResponse(
status="failed",
error=ChartError(
code="OTHERS",
message=error_message,
),
trace_id=trace_id,
)
results["metadata"]["error_type"] = "OTHERS"
results["metadata"]["error_message"] = error_message
return results

self._chart_results[query_id] = ChartResultResponse(
status="generating",
trace_id=trace_id,
Expand Down
21 changes: 19 additions & 2 deletions wren-ai-service/src/web/v1/services/chart_adjustment.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,29 @@ async def chart_adjustment(
trace_id=trace_id,
)

sql_data = (
execute_sql_result = (
await self._pipelines["sql_executor"].run(
sql=chart_adjustment_request.sql,
project_id=chart_adjustment_request.project_id,
)
)["execute_sql"]["results"]
)["execute_sql"]

sql_data = execute_sql_result["results"]
error_message = execute_sql_result.get("error_message", None)

if error_message:
self._chart_adjustment_results[
query_id
] = ChartAdjustmentResultResponse(
status="failed",
error=ChartAdjustmentError(
code="OTHERS",
message=error_message,
),
)
results["metadata"]["error_type"] = "OTHERS"
results["metadata"]["error_message"] = error_message
return results

self._chart_adjustment_results[query_id] = ChartAdjustmentResultResponse(
status="generating",
Expand Down
Loading