Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions deployment/kustomizations/base/cm.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,8 @@ data:
document_store: qdrant
- name: sql_tables_extraction
llm: litellm_llm.default
- name: sql_diagnosis
llm: litellm_llm.default
---
settings:
Expand Down
2 changes: 2 additions & 0 deletions docker/config.example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@ pipes:
document_store: qdrant
- name: sql_tables_extraction
llm: litellm_llm.default
- name: sql_diagnosis
llm: litellm_llm.default

---
settings:
Expand Down
2 changes: 2 additions & 0 deletions wren-ai-service/docs/config_examples/config.anthropic.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,8 @@ pipes:
document_store: qdrant
- name: sql_tables_extraction
llm: litellm_llm.default
- name: sql_diagnosis
llm: litellm_llm.default

---
settings:
Expand Down
2 changes: 2 additions & 0 deletions wren-ai-service/docs/config_examples/config.azure.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,8 @@ pipes:
document_store: qdrant
- name: sql_tables_extraction
llm: litellm_llm.default
- name: sql_diagnosis
llm: litellm_llm.default

---
settings:
Expand Down
2 changes: 2 additions & 0 deletions wren-ai-service/docs/config_examples/config.bedrock.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,8 @@ pipes:
document_store: qdrant
- name: sql_tables_extraction
llm: litellm_llm.default
- name: sql_diagnosis
llm: litellm_llm.default

---
settings:
Expand Down
2 changes: 2 additions & 0 deletions wren-ai-service/docs/config_examples/config.deepseek.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,8 @@ pipes:
document_store: qdrant
- name: sql_tables_extraction
llm: litellm_llm.default
- name: sql_diagnosis
llm: litellm_llm.default

---
settings:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,8 @@ pipes:
document_store: qdrant
- name: sql_tables_extraction
llm: litellm_llm.default
- name: sql_diagnosis
llm: litellm_llm.default

---
settings:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,8 @@ pipes:
document_store: qdrant
- name: sql_tables_extraction
llm: litellm_llm.default
- name: sql_diagnosis
llm: litellm_llm.default

---
settings:
Expand Down
2 changes: 2 additions & 0 deletions wren-ai-service/docs/config_examples/config.grok.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ pipes:
document_store: qdrant
- name: sql_tables_extraction
llm: litellm_llm.default
- name: sql_diagnosis
llm: litellm_llm.default

---
settings:
Expand Down
2 changes: 2 additions & 0 deletions wren-ai-service/docs/config_examples/config.groq.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,8 @@ pipes:
document_store: qdrant
- name: sql_tables_extraction
llm: litellm_llm.default
- name: sql_diagnosis
llm: litellm_llm.default

---
settings:
Expand Down
2 changes: 2 additions & 0 deletions wren-ai-service/docs/config_examples/config.lm_studio.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ pipes:
document_store: qdrant
- name: sql_tables_extraction
llm: litellm_llm.default
- name: sql_diagnosis
llm: litellm_llm.default

---
settings:
Expand Down
2 changes: 2 additions & 0 deletions wren-ai-service/docs/config_examples/config.ollama.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ pipes:
document_store: qdrant
- name: sql_tables_extraction
llm: litellm_llm.default
- name: sql_diagnosis
llm: litellm_llm.default

---
settings:
Expand Down
2 changes: 2 additions & 0 deletions wren-ai-service/docs/config_examples/config.open_router.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ pipes:
document_store: qdrant
- name: sql_tables_extraction
llm: litellm_llm.default
- name: sql_diagnosis
llm: litellm_llm.default

---
settings:
Expand Down
2 changes: 2 additions & 0 deletions wren-ai-service/docs/config_examples/config.qwen3.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,8 @@ pipes:
document_store: qdrant
- name: sql_tables_extraction
llm: litellm_llm.default
- name: sql_diagnosis
llm: litellm_llm.default

---
settings:
Expand Down
2 changes: 2 additions & 0 deletions wren-ai-service/docs/config_examples/config.zhipu.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,8 @@ pipes:
document_store: qdrant
- name: sql_tables_extraction
llm: litellm_llm.default
- name: sql_diagnosis
llm: litellm_llm.default

---
settings:
Expand Down
1 change: 1 addition & 0 deletions wren-ai-service/src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class Settings(BaseSettings):
allow_intent_classification: bool = Field(default=True)
allow_sql_generation_reasoning: bool = Field(default=True)
allow_sql_functions_retrieval: bool = Field(default=True)
allow_sql_diagnosis: bool = Field(default=False)
max_histories: int = Field(default=5)
max_sql_correction_retries: int = Field(default=3)

Expand Down
7 changes: 7 additions & 0 deletions wren-ai-service/src/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ def create_service_container(
_sql_executor_pipeline = retrieval.SQLExecutor(
**pipe_components["sql_executor"],
)
_sql_diagnosis_pipeline = generation.SQLDiagnosis(
**pipe_components["sql_diagnosis"],
)

return ServiceContainer(
semantics_description=services.SemanticsDescription(
Expand Down Expand Up @@ -146,10 +149,12 @@ def create_service_container(
**pipe_components["followup_sql_generation"],
),
"sql_functions_retrieval": _sql_functions_retrieval_pipeline,
"sql_diagnosis": _sql_diagnosis_pipeline,
},
allow_intent_classification=settings.allow_intent_classification,
allow_sql_generation_reasoning=settings.allow_sql_generation_reasoning,
allow_sql_functions_retrieval=settings.allow_sql_functions_retrieval,
allow_sql_diagnosis=settings.allow_sql_diagnosis,
max_histories=settings.max_histories,
enable_column_pruning=settings.enable_column_pruning,
max_sql_correction_retries=settings.max_sql_correction_retries,
Expand All @@ -165,8 +170,10 @@ def create_service_container(
**pipe_components["sql_regeneration"],
),
"sql_correction": _sql_correction_pipeline,
"sql_diagnosis": _sql_diagnosis_pipeline,
},
allow_sql_functions_retrieval=settings.allow_sql_functions_retrieval,
allow_sql_diagnosis=settings.allow_sql_diagnosis,
**query_cache,
),
chart_service=services.ChartService(
Expand Down
2 changes: 2 additions & 0 deletions wren-ai-service/src/pipelines/generation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .semantics_description import SemanticsDescription
from .sql_answer import SQLAnswer
from .sql_correction import SQLCorrection
from .sql_diagnosis import SQLDiagnosis
from .sql_generation import SQLGeneration
from .sql_generation_reasoning import SQLGenerationReasoning
from .sql_question import SQLQuestion
Expand All @@ -28,6 +29,7 @@
"SemanticsDescription",
"SQLAnswer",
"SQLCorrection",
"SQLDiagnosis",
"SQLGeneration",
"SQLGenerationReasoning",
"UserGuideAssistance",
Expand Down
26 changes: 20 additions & 6 deletions wren-ai-service/src/pipelines/generation/sql_correction.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,26 +18,29 @@
SQLGenPostProcessor,
construct_instructions,
)
from src.pipelines.retrieval.sql_functions import SqlFunction
from src.utils import trace_cost

logger = logging.getLogger("wren-ai-service")


sql_correction_system_prompt = f"""
### TASK ###
You are an ANSI SQL expert with exceptional logical thinking skills and debugging skills.

Now you are given syntactically incorrect ANSI SQL query and related error message, please generate the syntactically correct ANSI SQL query without changing original semantics.
You are an ANSI SQL expert with exceptional logical thinking skills and debugging skills, you need to fix the syntactically incorrect ANSI SQL query.

### SQL CORRECTION INSTRUCTIONS ###

1. Make sure you follow the SQL Rules strictly.
2. Make sure you check the SQL CORRECTION EXAMPLES for reference.
1. First, think hard about the error message, and firgure out the root cause first.
2. Then, generate the syntactically correct ANSI SQL query to correct the error.
3. You could try to use other methods(new functions, etc.) to rewrite the SQL query to correct the error, but you should not change the original semantics of the SQL query.

### SQL RULES ###
Make sure you follow the SQL Rules strictly.

{TEXT_TO_SQL_RULES}

### FINAL ANSWER FORMAT ###
The final answer must be a corrected SQL query in JSON format:
The final answer must be in JSON format:

{{
"sql": <CORRECTED_SQL_QUERY_STRING>
Expand All @@ -52,6 +55,13 @@
{% endfor %}
{% endif %}

{% if sql_functions %}
### SQL FUNCTIONS ###
{% for function in sql_functions %}
{{ function }}
{% endfor %}
{% endif %}

{% if instructions %}
### USER INSTRUCTIONS ###
{% for instruction in instructions %}
Expand All @@ -74,13 +84,15 @@ def prompt(
invalid_generation_result: Dict,
prompt_builder: PromptBuilder,
instructions: list[dict] | None = None,
sql_functions: list[SqlFunction] | None = None,
) -> dict:
_prompt = prompt_builder.run(
documents=documents,
invalid_generation_result=invalid_generation_result,
instructions=construct_instructions(
instructions=instructions,
),
sql_functions=sql_functions,
)
return {"prompt": clean_up_new_lines(_prompt.get("prompt"))}

Expand Down Expand Up @@ -148,6 +160,7 @@ async def run(
contexts: List[Document],
invalid_generation_result: Dict[str, str],
instructions: list[dict] | None = None,
sql_functions: list[SqlFunction] | None = None,
project_id: str | None = None,
use_dry_plan: bool = False,
allow_dry_plan_fallback: bool = True,
Expand All @@ -165,6 +178,7 @@ async def run(
"invalid_generation_result": invalid_generation_result,
"documents": contexts,
"instructions": instructions,
"sql_functions": sql_functions,
"project_id": project_id,
"use_dry_plan": use_dry_plan,
"allow_dry_plan_fallback": allow_dry_plan_fallback,
Expand Down
Loading
Loading