Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions deployment/kustomizations/base/cm.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,8 @@ data:
document_store: qdrant
- name: sql_tables_extraction
llm: litellm_llm.default
- name: sql_diagnosis
llm: litellm_llm.default

---
settings:
Expand Down
2 changes: 2 additions & 0 deletions docker/config.example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@ pipes:
document_store: qdrant
- name: sql_tables_extraction
llm: litellm_llm.default
- name: sql_diagnosis
llm: litellm_llm.default

---
settings:
Expand Down
2 changes: 2 additions & 0 deletions wren-ai-service/docs/config_examples/config.anthropic.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,8 @@ pipes:
document_store: qdrant
- name: sql_tables_extraction
llm: litellm_llm.default
- name: sql_diagnosis
llm: litellm_llm.default

---
settings:
Expand Down
2 changes: 2 additions & 0 deletions wren-ai-service/docs/config_examples/config.azure.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,8 @@ pipes:
document_store: qdrant
- name: sql_tables_extraction
llm: litellm_llm.default
- name: sql_diagnosis
llm: litellm_llm.default

---
settings:
Expand Down
2 changes: 2 additions & 0 deletions wren-ai-service/docs/config_examples/config.bedrock.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,8 @@ pipes:
document_store: qdrant
- name: sql_tables_extraction
llm: litellm_llm.default
- name: sql_diagnosis
llm: litellm_llm.default

---
settings:
Expand Down
2 changes: 2 additions & 0 deletions wren-ai-service/docs/config_examples/config.deepseek.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,8 @@ pipes:
document_store: qdrant
- name: sql_tables_extraction
llm: litellm_llm.default
- name: sql_diagnosis
llm: litellm_llm.default

---
settings:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,8 @@ pipes:
document_store: qdrant
- name: sql_tables_extraction
llm: litellm_llm.default
- name: sql_diagnosis
llm: litellm_llm.default

---
settings:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,8 @@ pipes:
document_store: qdrant
- name: sql_tables_extraction
llm: litellm_llm.default
- name: sql_diagnosis
llm: litellm_llm.default

---
settings:
Expand Down
2 changes: 2 additions & 0 deletions wren-ai-service/docs/config_examples/config.grok.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ pipes:
document_store: qdrant
- name: sql_tables_extraction
llm: litellm_llm.default
- name: sql_diagnosis
llm: litellm_llm.default

---
settings:
Expand Down
2 changes: 2 additions & 0 deletions wren-ai-service/docs/config_examples/config.groq.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,8 @@ pipes:
document_store: qdrant
- name: sql_tables_extraction
llm: litellm_llm.default
- name: sql_diagnosis
llm: litellm_llm.default

---
settings:
Expand Down
2 changes: 2 additions & 0 deletions wren-ai-service/docs/config_examples/config.lm_studio.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ pipes:
document_store: qdrant
- name: sql_tables_extraction
llm: litellm_llm.default
- name: sql_diagnosis
llm: litellm_llm.default

---
settings:
Expand Down
2 changes: 2 additions & 0 deletions wren-ai-service/docs/config_examples/config.ollama.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ pipes:
document_store: qdrant
- name: sql_tables_extraction
llm: litellm_llm.default
- name: sql_diagnosis
llm: litellm_llm.default

---
settings:
Expand Down
2 changes: 2 additions & 0 deletions wren-ai-service/docs/config_examples/config.open_router.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ pipes:
document_store: qdrant
- name: sql_tables_extraction
llm: litellm_llm.default
- name: sql_diagnosis
llm: litellm_llm.default

---
settings:
Expand Down
2 changes: 2 additions & 0 deletions wren-ai-service/docs/config_examples/config.qwen3.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,8 @@ pipes:
document_store: qdrant
- name: sql_tables_extraction
llm: litellm_llm.default
- name: sql_diagnosis
llm: litellm_llm.default

---
settings:
Expand Down
2 changes: 2 additions & 0 deletions wren-ai-service/docs/config_examples/config.zhipu.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,8 @@ pipes:
document_store: qdrant
- name: sql_tables_extraction
llm: litellm_llm.default
- name: sql_diagnosis
llm: litellm_llm.default

---
settings:
Expand Down
1 change: 1 addition & 0 deletions wren-ai-service/src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class Settings(BaseSettings):
allow_intent_classification: bool = Field(default=True)
allow_sql_generation_reasoning: bool = Field(default=True)
allow_sql_functions_retrieval: bool = Field(default=True)
allow_sql_diagnosis: bool = Field(default=False)
max_histories: int = Field(default=5)
max_sql_correction_retries: int = Field(default=3)

Expand Down
7 changes: 7 additions & 0 deletions wren-ai-service/src/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ def create_service_container(
_sql_executor_pipeline = retrieval.SQLExecutor(
**pipe_components["sql_executor"],
)
_sql_diagnosis_pipeline = generation.SQLDiagnosis(
**pipe_components["sql_diagnosis"],
)

return ServiceContainer(
semantics_description=services.SemanticsDescription(
Expand Down Expand Up @@ -146,10 +149,12 @@ def create_service_container(
**pipe_components["followup_sql_generation"],
),
"sql_functions_retrieval": _sql_functions_retrieval_pipeline,
"sql_diagnosis": _sql_diagnosis_pipeline,
},
allow_intent_classification=settings.allow_intent_classification,
allow_sql_generation_reasoning=settings.allow_sql_generation_reasoning,
allow_sql_functions_retrieval=settings.allow_sql_functions_retrieval,
allow_sql_diagnosis=settings.allow_sql_diagnosis,
max_histories=settings.max_histories,
enable_column_pruning=settings.enable_column_pruning,
max_sql_correction_retries=settings.max_sql_correction_retries,
Expand All @@ -165,8 +170,10 @@ def create_service_container(
**pipe_components["sql_regeneration"],
),
"sql_correction": _sql_correction_pipeline,
"sql_diagnosis": _sql_diagnosis_pipeline,
},
allow_sql_functions_retrieval=settings.allow_sql_functions_retrieval,
allow_sql_diagnosis=settings.allow_sql_diagnosis,
**query_cache,
),
chart_service=services.ChartService(
Expand Down
2 changes: 2 additions & 0 deletions wren-ai-service/src/pipelines/generation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .semantics_description import SemanticsDescription
from .sql_answer import SQLAnswer
from .sql_correction import SQLCorrection
from .sql_diagnosis import SQLDiagnosis
from .sql_generation import SQLGeneration
from .sql_generation_reasoning import SQLGenerationReasoning
from .sql_question import SQLQuestion
Expand All @@ -28,6 +29,7 @@
"SemanticsDescription",
"SQLAnswer",
"SQLCorrection",
"SQLDiagnosis",
"SQLGeneration",
"SQLGenerationReasoning",
"UserGuideAssistance",
Expand Down
48 changes: 40 additions & 8 deletions wren-ai-service/src/pipelines/generation/sql_correction.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,39 +7,44 @@
from haystack import Document
from haystack.components.builders.prompt_builder import PromptBuilder
from langfuse.decorators import observe
from pydantic import BaseModel

from src.core.engine import Engine
from src.core.pipeline import BasicPipeline
from src.core.provider import DocumentStoreProvider, LLMProvider
from src.pipelines.common import clean_up_new_lines, retrieve_metadata
from src.pipelines.generation.utils.sql import (
SQL_GENERATION_MODEL_KWARGS,
TEXT_TO_SQL_RULES,
SQLGenPostProcessor,
construct_instructions,
)
from src.pipelines.retrieval.sql_functions import SqlFunction
from src.utils import trace_cost

logger = logging.getLogger("wren-ai-service")


sql_correction_system_prompt = f"""
### TASK ###
You are an ANSI SQL expert with exceptional logical thinking skills and debugging skills.

Now you are given syntactically incorrect ANSI SQL query and related error message, please generate the syntactically correct ANSI SQL query without changing original semantics.
You are an ANSI SQL expert with exceptional logical thinking skills and debugging skills, you need to fix the syntactically incorrect ANSI SQL query.

### SQL CORRECTION INSTRUCTIONS ###

1. Make sure you follow the SQL Rules strictly.
2. Make sure you check the SQL CORRECTION EXAMPLES for reference.
1. First, think hard about the error message, and figure out the root cause first.
2. Then, generate the reasoning behind the correction.
3. Finally, generate the syntactically correct ANSI SQL query based on the reasoning to correct the error.
4. You could try to use other methods(new functions, etc.) to rewrite the SQL query to correct the error, but you should not change the original semantics of the SQL query.

### SQL RULES ###
Make sure you follow the SQL Rules strictly.

{TEXT_TO_SQL_RULES}

### FINAL ANSWER FORMAT ###
The final answer must be a corrected SQL query in JSON format:
The final answer must be in JSON format:

{{
"reasoning": <REASONING_STRING>,
"sql": <CORRECTED_SQL_QUERY_STRING>
}}
"""
Expand All @@ -52,6 +57,13 @@
{% endfor %}
{% endif %}

{% if sql_functions %}
### SQL FUNCTIONS ###
{% for function in sql_functions %}
{{ function }}
{% endfor %}
{% endif %}

{% if instructions %}
### USER INSTRUCTIONS ###
{% for instruction in instructions %}
Expand All @@ -74,13 +86,15 @@ def prompt(
invalid_generation_result: Dict,
prompt_builder: PromptBuilder,
instructions: list[dict] | None = None,
sql_functions: list[SqlFunction] | None = None,
) -> dict:
_prompt = prompt_builder.run(
documents=documents,
invalid_generation_result=invalid_generation_result,
instructions=construct_instructions(
instructions=instructions,
),
sql_functions=sql_functions,
)
return {"prompt": clean_up_new_lines(_prompt.get("prompt"))}

Expand Down Expand Up @@ -114,6 +128,22 @@ async def post_process(
## End of Pipeline


# Shape of the JSON object the SQL-correction LLM call is asked to return.
# The two fields mirror the "reasoning" / "sql" keys spelled out in the
# FINAL ANSWER FORMAT section of sql_correction_system_prompt.
#
# NOTE: documented with comments instead of a class docstring on purpose:
# pydantic copies a model's docstring into the "description" of
# model_json_schema(), and that schema is forwarded to the LLM through
# SQL_CORRECTION_MODEL_KWARGS, so a docstring would alter the request payload.
class SqlCorrectionResult(BaseModel):
    # Explanation of the root cause and of the correction that was applied.
    reasoning: str
    # The corrected SQL query string.
    sql: str


# Generation kwargs handed to the LLM generator used by SQLCorrection.
# They request a structured ("json_schema") response whose schema is derived
# from SqlCorrectionResult, i.e. an object carrying "reasoning" and "sql".
SQL_CORRECTION_MODEL_KWARGS = {
    "response_format": {
        "type": "json_schema",
        "json_schema": {
            "name": "sql_correction_result",
            "schema": SqlCorrectionResult.model_json_schema(),
        },
    }
}


class SQLCorrection(BasicPipeline):
def __init__(
self,
Expand All @@ -129,7 +159,7 @@ def __init__(
self._components = {
"generator": llm_provider.get_generator(
system_prompt=sql_correction_system_prompt,
generation_kwargs=SQL_GENERATION_MODEL_KWARGS,
generation_kwargs=SQL_CORRECTION_MODEL_KWARGS,
),
"generator_name": llm_provider.get_model(),
"prompt_builder": PromptBuilder(
Expand All @@ -148,6 +178,7 @@ async def run(
contexts: List[Document],
invalid_generation_result: Dict[str, str],
instructions: list[dict] | None = None,
sql_functions: list[SqlFunction] | None = None,
project_id: str | None = None,
use_dry_plan: bool = False,
allow_dry_plan_fallback: bool = True,
Expand All @@ -165,6 +196,7 @@ async def run(
"invalid_generation_result": invalid_generation_result,
"documents": contexts,
"instructions": instructions,
"sql_functions": sql_functions,
"project_id": project_id,
"use_dry_plan": use_dry_plan,
"allow_dry_plan_fallback": allow_dry_plan_fallback,
Expand Down
Loading
Loading