Merged

Showing changes from 14 of 19 commits.

Commits
9f9a5da  Refactor SQL generation prompts and rules in the AI service. Moved th…  (yichieh-lu, Nov 17, 2025)
3917330  Update environment versions for WREN engine, IBIS server, and WREN UI…  (yichieh-lu, Nov 17, 2025)
23b25c7  Add async method to retrieve SQL knowledge from WREN API  (yichieh-lu, Nov 17, 2025)
c09d4b3  Refactor SQL generation system prompt for improved clarity and consis…  (yichieh-lu, Nov 17, 2025)
22accf1  Update SQL generation reasoning system prompt for enhanced clarity an…  (yichieh-lu, Nov 17, 2025)
96e1fff  Add SQL knowledge retrieval to service container  (yichieh-lu, Nov 17, 2025)
afe2166  Refactor SQL generation components to utilize getter functions for in…  (yichieh-lu, Nov 17, 2025)
7abf4b0  Add SQL knowledge retrieval configuration to settings and service con…  (yichieh-lu, Nov 17, 2025)
b5029fb  Add support for SQL knowledge retrieval in multiple services  (yichieh-lu, Nov 17, 2025)
5f548db  Implement SQL knowledge retrieval pipeline and refactor SQL generatio…  (yichieh-lu, Nov 17, 2025)
25b6bd0  Remove generator initialization from SQL generation classes to stream…  (yichieh-lu, Nov 17, 2025)
a129c8f  Enhance SQL generation and correction pipelines with SqlKnowledge int…  (yichieh-lu, Nov 17, 2025)
fceeb3e  Refactor SQLGeneration class to remove print statement and enhance co…  (yichieh-lu, Nov 17, 2025)
aa1b5cc  update  (cyyeh, Nov 18, 2025)
844a6b5  update  (cyyeh, Nov 18, 2025)
af3ec91  fix  (cyyeh, Nov 19, 2025)
81c1967  update  (cyyeh, Nov 19, 2025)
49b93a7  Merge branch 'main' into feat/wren-ai-service/get-knowledge-from-ibis  (cyyeh, Dec 10, 2025)
4fa6c11  update configs  (cyyeh, Dec 10, 2025)
wren-ai-service/src/config.py: 1 change (1 addition, 0 deletions)
@@ -41,6 +41,7 @@ class Settings(BaseSettings):
allow_sql_generation_reasoning: bool = Field(default=True)
allow_sql_functions_retrieval: bool = Field(default=True)
allow_sql_diagnosis: bool = Field(default=True)
allow_sql_knowledge_retrieval: bool = Field(default=True)
max_histories: int = Field(default=5)
max_sql_correction_retries: int = Field(default=3)

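The new flag follows the existing Pydantic Settings pattern, so it should be controllable through configuration without code changes. A minimal sketch under that assumption; the field names are copied from the diff, and the environment-variable mapping is standard pydantic-settings behavior rather than anything this PR adds:

# Minimal sketch of the affected Settings fields, assuming pydantic-settings
# maps each field to an environment variable of the same (case-insensitive) name.
from pydantic import Field
from pydantic_settings import BaseSettings


class Settings(BaseSettings):
    allow_sql_generation_reasoning: bool = Field(default=True)
    allow_sql_functions_retrieval: bool = Field(default=True)
    allow_sql_diagnosis: bool = Field(default=True)
    allow_sql_knowledge_retrieval: bool = Field(default=True)  # added in this PR
    max_histories: int = Field(default=5)
    max_sql_correction_retries: int = Field(default=3)


# Usage sketch: export ALLOW_SQL_KNOWLEDGE_RETRIEVAL=false before startup
# to disable the new retrieval step without touching code.
settings = Settings()
print(settings.allow_sql_knowledge_retrieval)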
wren-ai-service/src/globals.py: 16 changes (16 additions, 0 deletions)
@@ -150,11 +150,15 @@ def create_service_container(
),
"sql_functions_retrieval": _sql_functions_retrieval_pipeline,
"sql_diagnosis": _sql_diagnosis_pipeline,
"sql_knowledge_retrieval": retrieval.SqlKnowledges(
**pipe_components["sql_knowledge_retrieval"],
),
},
allow_intent_classification=settings.allow_intent_classification,
allow_sql_generation_reasoning=settings.allow_sql_generation_reasoning,
allow_sql_functions_retrieval=settings.allow_sql_functions_retrieval,
allow_sql_diagnosis=settings.allow_sql_diagnosis,
allow_sql_knowledge_retrieval=settings.allow_sql_knowledge_retrieval,
max_histories=settings.max_histories,
enable_column_pruning=settings.enable_column_pruning,
max_sql_correction_retries=settings.max_sql_correction_retries,
@@ -171,9 +175,13 @@
),
"sql_correction": _sql_correction_pipeline,
"sql_diagnosis": _sql_diagnosis_pipeline,
"sql_knowledge_retrieval": retrieval.SqlKnowledges(
**pipe_components["sql_knowledge_retrieval"],
),
},
allow_sql_functions_retrieval=settings.allow_sql_functions_retrieval,
allow_sql_diagnosis=settings.allow_sql_diagnosis,
allow_sql_knowledge_retrieval=settings.allow_sql_knowledge_retrieval,
**query_cache,
),
chart_service=services.ChartService(
@@ -225,8 +233,12 @@ def create_service_container(
"sql_pairs_retrieval": _sql_pair_retrieval_pipeline,
"instructions_retrieval": _instructions_retrieval_pipeline,
"sql_functions_retrieval": _sql_functions_retrieval_pipeline,
"sql_knowledge_retrieval": retrieval.SqlKnowledges(
**pipe_components["sql_knowledge_retrieval"],
),
},
allow_sql_functions_retrieval=settings.allow_sql_functions_retrieval,
allow_sql_knowledge_retrieval=settings.allow_sql_knowledge_retrieval,
**query_cache,
),
sql_pairs_service=services.SqlPairsService(
@@ -256,7 +268,11 @@
),
"db_schema_retrieval": _db_schema_retrieval_pipeline,
"sql_correction": _sql_correction_pipeline,
"sql_knowledge_retrieval": retrieval.SqlKnowledges(
**pipe_components["sql_knowledge_retrieval"],
),
},
allow_sql_knowledge_retrieval=settings.allow_sql_knowledge_retrieval,
**query_cache,
),
)
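Each of the updated services now receives its own retrieval.SqlKnowledges pipeline instance together with the allow_sql_knowledge_retrieval flag. A hypothetical sketch of how a service might gate the extra retrieval step on that flag; only the pipeline key and the flag name come from the diff, while the constructor and run() signature are assumptions:

# Hypothetical gating pattern inside a service; the real service classes live
# under src/web/v1/services and may differ.
from typing import Any, Optional


class SomeService:
    def __init__(
        self,
        pipelines: dict[str, Any],
        allow_sql_knowledge_retrieval: bool = True,
        **kwargs,
    ):
        self._pipelines = pipelines
        self._allow_sql_knowledge_retrieval = allow_sql_knowledge_retrieval

    async def _retrieve_sql_knowledge(self, project_id: Optional[str]):
        # Skip the extra round-trip entirely when the feature is disabled.
        if not self._allow_sql_knowledge_retrieval:
            return None
        return await self._pipelines["sql_knowledge_retrieval"].run(
            project_id=project_id,
        )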
wren-ai-service/src/pipelines/generation/followup_sql_generation.py: 35 changes (24 additions, 11 deletions)
@@ -14,14 +14,15 @@
from src.pipelines.generation.utils.sql import (
SQL_GENERATION_MODEL_KWARGS,
SQLGenPostProcessor,
calculated_field_instructions,
construct_ask_history_messages,
construct_instructions,
json_field_instructions,
metric_instructions,
sql_generation_system_prompt,
get_calculated_field_instructions,
get_json_field_instructions,
get_metric_instructions,
get_sql_generation_system_prompt,
)
from src.pipelines.retrieval.sql_functions import SqlFunction
from src.pipelines.retrieval.sql_knowledge import SqlKnowledge
from src.utils import trace_cost
from src.web.v1.services.ask import AskHistory

@@ -97,6 +98,7 @@ def prompt(
has_metric: bool = False,
has_json_field: bool = False,
sql_functions: list[SqlFunction] | None = None,
sql_knowledge: SqlKnowledge | None = None,
) -> dict:
_prompt = prompt_builder.run(
query=query,
@@ -106,10 +108,16 @@ def prompt(
instructions=instructions,
),
calculated_field_instructions=(
calculated_field_instructions if has_calculated_field else ""
get_calculated_field_instructions(sql_knowledge)
if has_calculated_field
else ""
),
metric_instructions=(
get_metric_instructions(sql_knowledge) if has_metric else ""
),
json_field_instructions=(
get_json_field_instructions(sql_knowledge) if has_json_field else ""
),
metric_instructions=(metric_instructions if has_metric else ""),
json_field_instructions=(json_field_instructions if has_json_field else ""),
sql_samples=sql_samples,
sql_functions=sql_functions,
)
@@ -160,11 +168,9 @@ def __init__(
document_store_provider.get_store("project_meta")
)

self._llm_provider = llm_provider

self._components = {
"generator": llm_provider.get_generator(
system_prompt=sql_generation_system_prompt,
generation_kwargs=SQL_GENERATION_MODEL_KWARGS,
),
"generator_name": llm_provider.get_model(),
"prompt_builder": PromptBuilder(
template=text_to_sql_with_followup_user_prompt_template
@@ -192,6 +198,7 @@ async def run(
sql_functions: list[SqlFunction] | None = None,
use_dry_plan: bool = False,
allow_dry_plan_fallback: bool = True,
sql_knowledge: SqlKnowledge | None = None,
):
logger.info("Follow-Up SQL Generation pipeline is running...")

@@ -200,6 +207,11 @@ async def run(
else:
metadata = {}

self._components["generator"] = self._llm_provider.get_generator(
system_prompt=get_sql_generation_system_prompt(sql_knowledge),
generation_kwargs=SQL_GENERATION_MODEL_KWARGS,
)

return await self._pipe.execute(
["post_process"],
inputs={
Expand All @@ -217,6 +229,7 @@ async def run(
"use_dry_plan": use_dry_plan,
"allow_dry_plan_fallback": allow_dry_plan_fallback,
"data_source": metadata.get("data_source", "local_file"),
"sql_knowledge": sql_knowledge,
**self._components,
},
)
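The recurring change in this and the remaining pipeline files is that the generator is no longer built once in __init__ with a fixed system prompt; the pipeline keeps the llm_provider and constructs the generator inside run(), where the prompt can incorporate the per-request sql_knowledge. A condensed sketch of that pattern, with the getter and provider calls taken from the diff and the rest abbreviated:

# Condensed sketch of the lazy-generator pattern introduced in this PR.
from src.pipelines.generation.utils.sql import (
    SQL_GENERATION_MODEL_KWARGS,
    get_sql_generation_system_prompt,
)
from src.pipelines.retrieval.sql_knowledge import SqlKnowledge


class FollowUpSQLGenerationSketch:
    def __init__(self, llm_provider):
        # Keep the provider so the generator can be rebuilt per request.
        self._llm_provider = llm_provider
        self._components = {"generator_name": llm_provider.get_model()}

    async def run(self, query: str, sql_knowledge: SqlKnowledge | None = None):
        # The system prompt can now embed rules derived from sql_knowledge.
        self._components["generator"] = self._llm_provider.get_generator(
            system_prompt=get_sql_generation_system_prompt(sql_knowledge),
            generation_kwargs=SQL_GENERATION_MODEL_KWARGS,
        )
        ...  # execute the pipeline with self._components, as before

Rebuilding the generator on each call trades a little per-request setup for the ability to vary the system prompt with the retrieved knowledge.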
wren-ai-service/src/pipelines/generation/sql_correction.py: 22 changes (15 additions, 7 deletions)
@@ -14,17 +14,21 @@
from src.pipelines.common import clean_up_new_lines, retrieve_metadata
from src.pipelines.generation.utils.sql import (
SQL_GENERATION_MODEL_KWARGS,
TEXT_TO_SQL_RULES,
SQLGenPostProcessor,
construct_instructions,
get_text_to_sql_rules,
)
from src.pipelines.retrieval.sql_functions import SqlFunction
from src.pipelines.retrieval.sql_knowledge import SqlKnowledge
from src.utils import trace_cost

logger = logging.getLogger("wren-ai-service")


sql_correction_system_prompt = f"""
def get_sql_correction_system_prompt(sql_knowledge: SqlKnowledge | None = None) -> str:
text_to_sql_rules = get_text_to_sql_rules(sql_knowledge)

return f"""
### TASK ###
You are an ANSI SQL expert with exceptional logical thinking skills and debugging skills, you need to fix the syntactically incorrect ANSI SQL query.

@@ -36,7 +40,7 @@
### SQL RULES ###
Make sure you follow the SQL Rules strictly.

{TEXT_TO_SQL_RULES}
{text_to_sql_rules}

### FINAL ANSWER FORMAT ###
The final answer must be in JSON format:
@@ -46,6 +50,7 @@
}}
"""


sql_correction_user_prompt_template = """
{% if documents %}
### DATABASE SCHEMA ###
@@ -136,12 +141,9 @@ def __init__(
self._retriever = document_store_provider.get_retriever(
document_store_provider.get_store("project_meta")
)
self._llm_provider = llm_provider

self._components = {
"generator": llm_provider.get_generator(
system_prompt=sql_correction_system_prompt,
generation_kwargs=SQL_GENERATION_MODEL_KWARGS,
),
"generator_name": llm_provider.get_model(),
"prompt_builder": PromptBuilder(
template=sql_correction_user_prompt_template
@@ -163,9 +165,15 @@ async def run(
project_id: str | None = None,
use_dry_plan: bool = False,
allow_dry_plan_fallback: bool = True,
sql_knowledge: SqlKnowledge | None = None,
):
logger.info("SQLCorrection pipeline is running...")

self._components["generator"] = self._llm_provider.get_generator(
system_prompt=get_sql_correction_system_prompt(sql_knowledge),
generation_kwargs=SQL_GENERATION_MODEL_KWARGS,
)

if use_dry_plan:
metadata = await retrieve_metadata(project_id or "", self._retriever)
else:
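The same refactor replaces the module-level TEXT_TO_SQL_RULES constant with get_text_to_sql_rules(sql_knowledge), which get_sql_correction_system_prompt interpolates into the prompt. The diff does not show the getter's body or the SqlKnowledge schema, so the sketch below is purely illustrative:

# Purely hypothetical sketch of a knowledge-aware rules getter; the real
# implementation lives in src/pipelines/generation/utils/sql.py and the
# SqlKnowledge fields shown here are invented for illustration.
from dataclasses import dataclass

BASE_TEXT_TO_SQL_RULES = """
- Only reference tables and columns that exist in the provided schema.
- Produce ANSI SQL unless dialect-specific rules below say otherwise.
"""


@dataclass
class SqlKnowledge:
    dialect: str = ""
    dialect_rules: str = ""


def get_text_to_sql_rules(sql_knowledge: SqlKnowledge | None = None) -> str:
    # Fall back to the generic rules when nothing was retrieved.
    if sql_knowledge is None or not sql_knowledge.dialect_rules:
        return BASE_TEXT_TO_SQL_RULES
    return (
        f"{BASE_TEXT_TO_SQL_RULES}\n"
        f"### {sql_knowledge.dialect.upper()} RULES ###\n"
        f"{sql_knowledge.dialect_rules}"
    )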
wren-ai-service/src/pipelines/generation/sql_generation.py: 35 changes (24 additions, 11 deletions)
@@ -14,13 +14,14 @@
from src.pipelines.generation.utils.sql import (
SQL_GENERATION_MODEL_KWARGS,
SQLGenPostProcessor,
calculated_field_instructions,
construct_instructions,
json_field_instructions,
metric_instructions,
sql_generation_system_prompt,
get_calculated_field_instructions,
get_json_field_instructions,
get_metric_instructions,
get_sql_generation_system_prompt,
)
from src.pipelines.retrieval.sql_functions import SqlFunction
from src.pipelines.retrieval.sql_knowledge import SqlKnowledge
from src.utils import trace_cost

logger = logging.getLogger("wren-ai-service")
@@ -93,6 +94,7 @@ def prompt(
has_metric: bool = False,
has_json_field: bool = False,
sql_functions: list[SqlFunction] | None = None,
sql_knowledge: SqlKnowledge | None = None,
) -> dict:
_prompt = prompt_builder.run(
query=query,
@@ -102,10 +104,16 @@ def prompt(
instructions=instructions,
),
calculated_field_instructions=(
calculated_field_instructions if has_calculated_field else ""
get_calculated_field_instructions(sql_knowledge)
if has_calculated_field
else ""
),
metric_instructions=(
get_metric_instructions(sql_knowledge) if has_metric else ""
),
json_field_instructions=(
get_json_field_instructions(sql_knowledge) if has_json_field else ""
),
metric_instructions=(metric_instructions if has_metric else ""),
json_field_instructions=(json_field_instructions if has_json_field else ""),
sql_samples=sql_samples,
sql_functions=sql_functions,
)
@@ -157,11 +165,9 @@ def __init__(
document_store_provider.get_store("project_meta")
)

self._llm_provider = llm_provider

self._components = {
"generator": llm_provider.get_generator(
system_prompt=sql_generation_system_prompt,
generation_kwargs=SQL_GENERATION_MODEL_KWARGS,
),
"generator_name": llm_provider.get_model(),
"prompt_builder": PromptBuilder(
template=sql_generation_user_prompt_template
@@ -189,6 +195,7 @@ async def run(
use_dry_plan: bool = False,
allow_dry_plan_fallback: bool = True,
allow_data_preview: bool = False,
sql_knowledge: SqlKnowledge | None = None,
):
logger.info("SQL Generation pipeline is running...")

@@ -197,6 +204,11 @@ async def run(
else:
metadata = {}

self._components["generator"] = self._llm_provider.get_generator(
system_prompt=get_sql_generation_system_prompt(sql_knowledge),
generation_kwargs=SQL_GENERATION_MODEL_KWARGS,
)

return await self._pipe.execute(
["post_process"],
inputs={
Expand All @@ -214,6 +226,7 @@ async def run(
"allow_dry_plan_fallback": allow_dry_plan_fallback,
"data_source": metadata.get("data_source", "local_file"),
"allow_data_preview": allow_data_preview,
"sql_knowledge": sql_knowledge,
**self._components,
},
)
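On the caller side, the intended flow is presumably to run the new sql_knowledge_retrieval pipeline once per request and thread the result through the generation, correction, and follow-up pipelines. A usage sketch under that assumption; apart from the sql_knowledge_retrieval key and the sql_knowledge keyword argument, the pipeline keys and run() signatures here are guesses:

# Hypothetical caller-side flow; only the "sql_knowledge_retrieval" key and
# the sql_knowledge keyword argument are confirmed by this diff.
async def generate_sql(pipelines, query: str, contexts: list, project_id: str | None = None):
    sql_knowledge = await pipelines["sql_knowledge_retrieval"].run(
        project_id=project_id,
    )
    return await pipelines["sql_generation"].run(
        query=query,
        contexts=contexts,
        project_id=project_id,
        sql_knowledge=sql_knowledge,
    )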