From 9f9a5da79f69a889785e65ae7a8c5ad1a67fdef8 Mon Sep 17 00:00:00 2001 From: Steven Lu Date: Mon, 17 Nov 2025 14:33:00 +0800 Subject: [PATCH 01/18] Refactor SQL generation prompts and rules in the AI service. Moved the SQL generation system prompt to a new location and updated its format. Added TEXT_TO_SQL_RULES for better clarity in SQL query generation. Cleaned up the construct_instructions function for improved readability. --- .../pipelines/generation/sql_generation.py | 25 +++++++++- .../src/pipelines/generation/utils/sql.py | 46 +++++-------------- 2 files changed, 36 insertions(+), 35 deletions(-) diff --git a/wren-ai-service/src/pipelines/generation/sql_generation.py b/wren-ai-service/src/pipelines/generation/sql_generation.py index 27d3bc5eab..4793e4d3e7 100644 --- a/wren-ai-service/src/pipelines/generation/sql_generation.py +++ b/wren-ai-service/src/pipelines/generation/sql_generation.py @@ -13,18 +13,41 @@ from src.pipelines.common import clean_up_new_lines, retrieve_metadata from src.pipelines.generation.utils.sql import ( SQL_GENERATION_MODEL_KWARGS, + TEXT_TO_SQL_RULES, SQLGenPostProcessor, calculated_field_instructions, construct_instructions, json_field_instructions, metric_instructions, - sql_generation_system_prompt, ) from src.pipelines.retrieval.sql_functions import SqlFunction from src.utils import trace_cost logger = logging.getLogger("wren-ai-service") +sql_generation_system_prompt = f""" +You are a helpful assistant that converts natural language queries into ANSI SQL queries. + +Given user's question, database schema, etc., you should think deeply and carefully and generate the SQL query based on the given reasoning plan step by step. + +### GENERAL RULES ### + +1. YOU MUST FOLLOW the instructions strictly to generate the SQL query if the section of USER INSTRUCTIONS is available in user's input. +2. YOU MUST ONLY CHOOSE the appropriate functions from the sql functions list and use them in the SQL query if the section of SQL FUNCTIONS is available in user's input. +3. YOU MUST REFER to the sql samples and learn the usage of the schema structures and how SQL is written based on them if the section of SQL SAMPLES is available in user's input. +4. YOU MUST FOLLOW the reasoning plan step by step strictly to generate the SQL query if the section of REASONING PLAN is available in user's input. +5. YOU MUST FOLLOW SQL Rules if they are not contradicted with instructions. + +{TEXT_TO_SQL_RULES} + +### FINAL ANSWER FORMAT ### +The final answer must be a ANSI SQL query in JSON format: + +{{ + "sql": +}} +""" + sql_generation_user_prompt_template = """ ### DATABASE SCHEMA ### diff --git a/wren-ai-service/src/pipelines/generation/utils/sql.py b/wren-ai-service/src/pipelines/generation/utils/sql.py index 88dc448a0f..af91d95b8d 100644 --- a/wren-ai-service/src/pipelines/generation/utils/sql.py +++ b/wren-ai-service/src/pipelines/generation/utils/sql.py @@ -248,28 +248,6 @@ async def _classify_generation_result( - For the ranking problem, you must add the ranking column to the final SELECT clause. """ -sql_generation_system_prompt = f""" -You are a helpful assistant that converts natural language queries into ANSI SQL queries. - -Given user's question, database schema, etc., you should think deeply and carefully and generate the SQL query based on the given reasoning plan step by step. - -### GENERAL RULES ### - -1. YOU MUST FOLLOW the instructions strictly to generate the SQL query if the section of USER INSTRUCTIONS is available in user's input. -2. YOU MUST ONLY CHOOSE the appropriate functions from the sql functions list and use them in the SQL query if the section of SQL FUNCTIONS is available in user's input. -3. YOU MUST REFER to the sql samples and learn the usage of the schema structures and how SQL is written based on them if the section of SQL SAMPLES is available in user's input. -4. YOU MUST FOLLOW the reasoning plan step by step strictly to generate the SQL query if the section of REASONING PLAN is available in user's input. -5. YOU MUST FOLLOW SQL Rules if they are not contradicted with instructions. - -{TEXT_TO_SQL_RULES} - -### FINAL ANSWER FORMAT ### -The final answer must be a ANSI SQL query in JSON format: - -{{ - "sql": -}} -""" calculated_field_instructions = """ #### Instructions for Calculated Field #### @@ -472,18 +450,6 @@ async def _classify_generation_result( """ -def construct_instructions( - instructions: list[dict] | None = None, -): - _instructions = [] - if instructions: - _instructions += [ - instruction.get("instruction") for instruction in instructions - ] - - return _instructions - - class SqlGenerationResult(BaseModel): sql: str @@ -499,6 +465,18 @@ class SqlGenerationResult(BaseModel): } +def construct_instructions( + instructions: list[dict] | None = None, +): + _instructions = [] + if instructions: + _instructions += [ + instruction.get("instruction") for instruction in instructions + ] + + return _instructions + + def construct_ask_history_messages( histories: list[AskHistory] | list[dict], ) -> list[ChatMessage]: From 39173306942cd8a5faba789fa089a44a4f2e9aa0 Mon Sep 17 00:00:00 2001 From: Steven Lu Date: Mon, 17 Nov 2025 14:33:27 +0800 Subject: [PATCH 02/18] Update environment versions for WREN engine, IBIS server, and WREN UI to 0.21.3, 0.21.3, and 0.31.3 respectively. --- wren-ai-service/tools/dev/.env | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/wren-ai-service/tools/dev/.env b/wren-ai-service/tools/dev/.env index 7cc3c1573e..0b6825a77f 100644 --- a/wren-ai-service/tools/dev/.env +++ b/wren-ai-service/tools/dev/.env @@ -11,10 +11,10 @@ IBIS_SERVER_PORT=8000 # version # CHANGE THIS TO THE LATEST VERSION WREN_PRODUCT_VERSION=development -WREN_ENGINE_VERSION=0.20.2 +WREN_ENGINE_VERSION=0.21.3 WREN_AI_SERVICE_VERSION=0.27.14 -IBIS_SERVER_VERSION=0.20.2 -WREN_UI_VERSION=0.31.2 +IBIS_SERVER_VERSION=0.21.3 +WREN_UI_VERSION=0.31.3 WREN_BOOTSTRAP_VERSION=0.1.5 LAUNCH_CLI_PATH=./launch-cli.sh From 23b25c7c869b4cd20c88725bc642dc722543c5bf Mon Sep 17 00:00:00 2001 From: Steven Lu Date: Mon, 17 Nov 2025 14:34:51 +0800 Subject: [PATCH 03/18] Add async method to retrieve SQL knowledge from WREN API Implemented the `get_sql_knowledge` method in the WrenIbis class to fetch SQL knowledge from the specified data source. This method handles API requests, manages timeouts, and logs errors appropriately. --- wren-ai-service/src/providers/engine/wren.py | 22 ++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/wren-ai-service/src/providers/engine/wren.py b/wren-ai-service/src/providers/engine/wren.py index 05bc4b0883..3a92853e04 100644 --- a/wren-ai-service/src/providers/engine/wren.py +++ b/wren-ai-service/src/providers/engine/wren.py @@ -263,6 +263,28 @@ async def get_func_list( logger.exception(f"Unexpected error during get_func_list: {str(e)}") return [] + async def get_sql_knowledge( + self, + session: aiohttp.ClientSession, + data_source: str, + timeout: float = settings.engine_timeout, + ) -> Optional[Dict[str, Any]]: + api_endpoint = f"{self._endpoint}/v3/connector/{data_source}/knowledge" + try: + async with session.get(api_endpoint, timeout=timeout) as response: + res = await response.json() + + if response.status != 200: + raise Exception(f"Request failed with message: {res}") + + return res + except asyncio.TimeoutError: + logger.error(f"Request timed out: {timeout} seconds") + return None + except Exception as e: + logger.exception(f"Unexpected error during get_sql_knowledge: {str(e)}") + return None + @provider("wren_engine") class WrenEngine(Engine): From c09d4b3d3d1d7e3d07fd8de8961e954a0fe97290 Mon Sep 17 00:00:00 2001 From: Steven Lu Date: Mon, 17 Nov 2025 15:01:33 +0800 Subject: [PATCH 04/18] Refactor SQL generation system prompt for improved clarity and consistency. Moved the prompt definition to a centralized location and removed redundant code. This update enhances the structure of SQL query generation in the AI service. --- .../pipelines/generation/sql_generation.py | 25 +------------------ .../src/pipelines/generation/utils/sql.py | 23 +++++++++++++++++ 2 files changed, 24 insertions(+), 24 deletions(-) diff --git a/wren-ai-service/src/pipelines/generation/sql_generation.py b/wren-ai-service/src/pipelines/generation/sql_generation.py index 4793e4d3e7..27d3bc5eab 100644 --- a/wren-ai-service/src/pipelines/generation/sql_generation.py +++ b/wren-ai-service/src/pipelines/generation/sql_generation.py @@ -13,41 +13,18 @@ from src.pipelines.common import clean_up_new_lines, retrieve_metadata from src.pipelines.generation.utils.sql import ( SQL_GENERATION_MODEL_KWARGS, - TEXT_TO_SQL_RULES, SQLGenPostProcessor, calculated_field_instructions, construct_instructions, json_field_instructions, metric_instructions, + sql_generation_system_prompt, ) from src.pipelines.retrieval.sql_functions import SqlFunction from src.utils import trace_cost logger = logging.getLogger("wren-ai-service") -sql_generation_system_prompt = f""" -You are a helpful assistant that converts natural language queries into ANSI SQL queries. - -Given user's question, database schema, etc., you should think deeply and carefully and generate the SQL query based on the given reasoning plan step by step. - -### GENERAL RULES ### - -1. YOU MUST FOLLOW the instructions strictly to generate the SQL query if the section of USER INSTRUCTIONS is available in user's input. -2. YOU MUST ONLY CHOOSE the appropriate functions from the sql functions list and use them in the SQL query if the section of SQL FUNCTIONS is available in user's input. -3. YOU MUST REFER to the sql samples and learn the usage of the schema structures and how SQL is written based on them if the section of SQL SAMPLES is available in user's input. -4. YOU MUST FOLLOW the reasoning plan step by step strictly to generate the SQL query if the section of REASONING PLAN is available in user's input. -5. YOU MUST FOLLOW SQL Rules if they are not contradicted with instructions. - -{TEXT_TO_SQL_RULES} - -### FINAL ANSWER FORMAT ### -The final answer must be a ANSI SQL query in JSON format: - -{{ - "sql": -}} -""" - sql_generation_user_prompt_template = """ ### DATABASE SCHEMA ### diff --git a/wren-ai-service/src/pipelines/generation/utils/sql.py b/wren-ai-service/src/pipelines/generation/utils/sql.py index af91d95b8d..8296d125f9 100644 --- a/wren-ai-service/src/pipelines/generation/utils/sql.py +++ b/wren-ai-service/src/pipelines/generation/utils/sql.py @@ -449,6 +449,29 @@ async def _classify_generation_result( Learn about the usage of the schema structures and generate SQL based on them. """ +sql_generation_system_prompt = f""" +You are a helpful assistant that converts natural language queries into ANSI SQL queries. + +Given user's question, database schema, etc., you should think deeply and carefully and generate the SQL query based on the given reasoning plan step by step. + +### GENERAL RULES ### + +1. YOU MUST FOLLOW the instructions strictly to generate the SQL query if the section of USER INSTRUCTIONS is available in user's input. +2. YOU MUST ONLY CHOOSE the appropriate functions from the sql functions list and use them in the SQL query if the section of SQL FUNCTIONS is available in user's input. +3. YOU MUST REFER to the sql samples and learn the usage of the schema structures and how SQL is written based on them if the section of SQL SAMPLES is available in user's input. +4. YOU MUST FOLLOW the reasoning plan step by step strictly to generate the SQL query if the section of REASONING PLAN is available in user's input. +5. YOU MUST FOLLOW SQL Rules if they are not contradicted with instructions. + +{TEXT_TO_SQL_RULES} + +### FINAL ANSWER FORMAT ### +The final answer must be a ANSI SQL query in JSON format: + +{{ + "sql": +}} +""" + class SqlGenerationResult(BaseModel): sql: str From 22accf10b8c616af8377bf703682455dd19c9b00 Mon Sep 17 00:00:00 2001 From: Steven Lu Date: Mon, 17 Nov 2025 15:04:15 +0800 Subject: [PATCH 05/18] Update SQL generation reasoning system prompt for enhanced clarity and structure. The prompt has been redefined to improve the step-by-step instructions for data analysts, ensuring consistency in SQL query generation. --- .../src/pipelines/generation/utils/sql.py | 54 +++++++++---------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/wren-ai-service/src/pipelines/generation/utils/sql.py b/wren-ai-service/src/pipelines/generation/utils/sql.py index 8296d125f9..0f1ddfc93c 100644 --- a/wren-ai-service/src/pipelines/generation/utils/sql.py +++ b/wren-ai-service/src/pipelines/generation/utils/sql.py @@ -162,33 +162,6 @@ async def _classify_generation_result( return valid_generation_result, invalid_generation_result -sql_generation_reasoning_system_prompt = """ -### TASK ### -You are a helpful data analyst who is great at thinking deeply and reasoning about the user's question and the database schema, and you provide a step-by-step reasoning plan in order to answer the user's question. - -### INSTRUCTIONS ### -1. Think deeply and reason about the user's question, the database schema, and the user's query history if provided. -2. Explicitly state the following information in the reasoning plan: -if the user puts any specific timeframe(e.g. YYYY-MM-DD) in the user's question(excluding the value of the current time), you will put the absolute time frame in the SQL query; -otherwise, you will put the relative timeframe in the SQL query. -3. For the ranking problem(e.g. "top x", "bottom x", "first x", "last x"), you must use the ranking function, `DENSE_RANK()` to rank the results and then use `WHERE` clause to filter the results. -4. For the ranking problem(e.g. "top x", "bottom x", "first x", "last x"), you must add the ranking column to the final SELECT clause. -5. If USER INSTRUCTIONS section is provided, make sure to consider them in the reasoning plan. -6. If SQL SAMPLES section is provided, make sure to consider them in the reasoning plan. -7. Give a step by step reasoning plan in order to answer user's question. -8. The reasoning plan should be in the language same as the language user provided in the input. -9. Don't include SQL in the reasoning plan. -10. Each step in the reasoning plan must start with a number, a title(in bold format in markdown), and a reasoning for the step. -11. Do not include ```markdown or ``` in the answer. -12. A table name in the reasoning plan must be in this format: `table: `. -13. A column name in the reasoning plan must be in this format: `column: .`. -14. ONLY SHOWING the reasoning plan in bullet points. - -### FINAL ANSWER FORMAT ### -The final answer must be a reasoning plan in plain Markdown string format -""" - - TEXT_TO_SQL_RULES = """ ### SQL RULES ### - ONLY USE SELECT statements, NO DELETE, UPDATE OR INSERT etc. statements that might change the data in the database. @@ -473,6 +446,33 @@ async def _classify_generation_result( """ +sql_generation_reasoning_system_prompt = """ +### TASK ### +You are a helpful data analyst who is great at thinking deeply and reasoning about the user's question and the database schema, and you provide a step-by-step reasoning plan in order to answer the user's question. + +### INSTRUCTIONS ### +1. Think deeply and reason about the user's question, the database schema, and the user's query history if provided. +2. Explicitly state the following information in the reasoning plan: +if the user puts any specific timeframe(e.g. YYYY-MM-DD) in the user's question(excluding the value of the current time), you will put the absolute time frame in the SQL query; +otherwise, you will put the relative timeframe in the SQL query. +3. For the ranking problem(e.g. "top x", "bottom x", "first x", "last x"), you must use the ranking function, `DENSE_RANK()` to rank the results and then use `WHERE` clause to filter the results. +4. For the ranking problem(e.g. "top x", "bottom x", "first x", "last x"), you must add the ranking column to the final SELECT clause. +5. If USER INSTRUCTIONS section is provided, make sure to consider them in the reasoning plan. +6. If SQL SAMPLES section is provided, make sure to consider them in the reasoning plan. +7. Give a step by step reasoning plan in order to answer user's question. +8. The reasoning plan should be in the language same as the language user provided in the input. +9. Don't include SQL in the reasoning plan. +10. Each step in the reasoning plan must start with a number, a title(in bold format in markdown), and a reasoning for the step. +11. Do not include ```markdown or ``` in the answer. +12. A table name in the reasoning plan must be in this format: `table: `. +13. A column name in the reasoning plan must be in this format: `column: .`. +14. ONLY SHOWING the reasoning plan in bullet points. + +### FINAL ANSWER FORMAT ### +The final answer must be a reasoning plan in plain Markdown string format +""" + + class SqlGenerationResult(BaseModel): sql: str From 96e1fff7c3cb9ad582a278d984fd60c18b2e8f68 Mon Sep 17 00:00:00 2001 From: Steven Lu Date: Mon, 17 Nov 2025 16:19:37 +0800 Subject: [PATCH 06/18] Add SQL knowledge retrieval to service container Enhanced the service container by integrating the `SqlKnowledges` retrieval component. This addition allows for improved handling of SQL knowledge within the AI service, ensuring better data processing capabilities. --- wren-ai-service/src/globals.py | 3 +++ wren-ai-service/src/pipelines/retrieval/__init__.py | 2 ++ 2 files changed, 5 insertions(+) diff --git a/wren-ai-service/src/globals.py b/wren-ai-service/src/globals.py index d6c40e05b4..92029c5082 100644 --- a/wren-ai-service/src/globals.py +++ b/wren-ai-service/src/globals.py @@ -150,6 +150,9 @@ def create_service_container( ), "sql_functions_retrieval": _sql_functions_retrieval_pipeline, "sql_diagnosis": _sql_diagnosis_pipeline, + "sql_knowledge_retrieval": retrieval.SqlKnowledges( + **pipe_components["sql_knowledge_retrieval"], + ), }, allow_intent_classification=settings.allow_intent_classification, allow_sql_generation_reasoning=settings.allow_sql_generation_reasoning, diff --git a/wren-ai-service/src/pipelines/retrieval/__init__.py b/wren-ai-service/src/pipelines/retrieval/__init__.py index 01d4429299..aa770f88f5 100644 --- a/wren-ai-service/src/pipelines/retrieval/__init__.py +++ b/wren-ai-service/src/pipelines/retrieval/__init__.py @@ -4,6 +4,7 @@ from .preprocess_sql_data import PreprocessSqlData from .sql_executor import SQLExecutor from .sql_functions import SqlFunctions +from .sql_knowledge import SqlKnowledges from .sql_pairs_retrieval import SqlPairsRetrieval __all__ = [ @@ -14,4 +15,5 @@ "SqlPairsRetrieval", "Instructions", "SqlFunctions", + "SqlKnowledges", ] From afe21663309e3dc6d502ddd93248bf537f5ddc3e Mon Sep 17 00:00:00 2001 From: Steven Lu Date: Mon, 17 Nov 2025 16:20:34 +0800 Subject: [PATCH 07/18] Refactor SQL generation components to utilize getter functions for instructions and prompts. This change enhances modularity and maintainability by centralizing the retrieval of SQL-related instructions and prompts across various generation classes. --- .../generation/followup_sql_generation.py | 25 +++- .../pipelines/generation/sql_correction.py | 18 ++- .../pipelines/generation/sql_generation.py | 26 +++- .../pipelines/generation/sql_regeneration.py | 34 +++-- .../src/pipelines/generation/utils/sql.py | 138 ++++++++++++++---- 5 files changed, 184 insertions(+), 57 deletions(-) diff --git a/wren-ai-service/src/pipelines/generation/followup_sql_generation.py b/wren-ai-service/src/pipelines/generation/followup_sql_generation.py index 5f889ff00c..3b9493447c 100644 --- a/wren-ai-service/src/pipelines/generation/followup_sql_generation.py +++ b/wren-ai-service/src/pipelines/generation/followup_sql_generation.py @@ -14,12 +14,12 @@ from src.pipelines.generation.utils.sql import ( SQL_GENERATION_MODEL_KWARGS, SQLGenPostProcessor, - calculated_field_instructions, construct_ask_history_messages, construct_instructions, - json_field_instructions, - metric_instructions, - sql_generation_system_prompt, + get_calculated_field_instructions, + get_json_field_instructions, + get_metric_instructions, + get_sql_generation_system_prompt, ) from src.pipelines.retrieval.sql_functions import SqlFunction from src.utils import trace_cost @@ -106,10 +106,12 @@ def prompt( instructions=instructions, ), calculated_field_instructions=( - calculated_field_instructions if has_calculated_field else "" + get_calculated_field_instructions() if has_calculated_field else "" + ), + metric_instructions=(get_metric_instructions() if has_metric else ""), + json_field_instructions=( + get_json_field_instructions() if has_json_field else "" ), - metric_instructions=(metric_instructions if has_metric else ""), - json_field_instructions=(json_field_instructions if has_json_field else ""), sql_samples=sql_samples, sql_functions=sql_functions, ) @@ -160,9 +162,11 @@ def __init__( document_store_provider.get_store("project_meta") ) + self._llm_provider = llm_provider + self._components = { "generator": llm_provider.get_generator( - system_prompt=sql_generation_system_prompt, + system_prompt=get_sql_generation_system_prompt(), generation_kwargs=SQL_GENERATION_MODEL_KWARGS, ), "generator_name": llm_provider.get_model(), @@ -200,6 +204,11 @@ async def run( else: metadata = {} + self._components["generator"] = self._llm_provider.get_generator( + system_prompt=get_sql_generation_system_prompt(), + generation_kwargs=SQL_GENERATION_MODEL_KWARGS, + ) + return await self._pipe.execute( ["post_process"], inputs={ diff --git a/wren-ai-service/src/pipelines/generation/sql_correction.py b/wren-ai-service/src/pipelines/generation/sql_correction.py index 86b9091f15..e1d6180d66 100644 --- a/wren-ai-service/src/pipelines/generation/sql_correction.py +++ b/wren-ai-service/src/pipelines/generation/sql_correction.py @@ -14,9 +14,9 @@ from src.pipelines.common import clean_up_new_lines, retrieve_metadata from src.pipelines.generation.utils.sql import ( SQL_GENERATION_MODEL_KWARGS, - TEXT_TO_SQL_RULES, SQLGenPostProcessor, construct_instructions, + get_text_to_sql_rules, ) from src.pipelines.retrieval.sql_functions import SqlFunction from src.utils import trace_cost @@ -24,7 +24,10 @@ logger = logging.getLogger("wren-ai-service") -sql_correction_system_prompt = f""" +def get_sql_correction_system_prompt() -> str: + text_to_sql_rules = get_text_to_sql_rules() + + return f""" ### TASK ### You are an ANSI SQL expert with exceptional logical thinking skills and debugging skills, you need to fix the syntactically incorrect ANSI SQL query. @@ -36,7 +39,7 @@ ### SQL RULES ### Make sure you follow the SQL Rules strictly. -{TEXT_TO_SQL_RULES} +{text_to_sql_rules} ### FINAL ANSWER FORMAT ### The final answer must be in JSON format: @@ -46,6 +49,7 @@ }} """ + sql_correction_user_prompt_template = """ {% if documents %} ### DATABASE SCHEMA ### @@ -136,10 +140,11 @@ def __init__( self._retriever = document_store_provider.get_retriever( document_store_provider.get_store("project_meta") ) + self._llm_provider = llm_provider self._components = { "generator": llm_provider.get_generator( - system_prompt=sql_correction_system_prompt, + system_prompt=get_sql_correction_system_prompt(), generation_kwargs=SQL_GENERATION_MODEL_KWARGS, ), "generator_name": llm_provider.get_model(), @@ -166,6 +171,11 @@ async def run( ): logger.info("SQLCorrection pipeline is running...") + self._components["generator"] = self._llm_provider.get_generator( + system_prompt=get_sql_correction_system_prompt(), + generation_kwargs=SQL_GENERATION_MODEL_KWARGS, + ) + if use_dry_plan: metadata = await retrieve_metadata(project_id or "", self._retriever) else: diff --git a/wren-ai-service/src/pipelines/generation/sql_generation.py b/wren-ai-service/src/pipelines/generation/sql_generation.py index 27d3bc5eab..7c09e6dee8 100644 --- a/wren-ai-service/src/pipelines/generation/sql_generation.py +++ b/wren-ai-service/src/pipelines/generation/sql_generation.py @@ -14,11 +14,11 @@ from src.pipelines.generation.utils.sql import ( SQL_GENERATION_MODEL_KWARGS, SQLGenPostProcessor, - calculated_field_instructions, construct_instructions, - json_field_instructions, - metric_instructions, - sql_generation_system_prompt, + get_calculated_field_instructions, + get_json_field_instructions, + get_metric_instructions, + get_sql_generation_system_prompt, ) from src.pipelines.retrieval.sql_functions import SqlFunction from src.utils import trace_cost @@ -102,10 +102,12 @@ def prompt( instructions=instructions, ), calculated_field_instructions=( - calculated_field_instructions if has_calculated_field else "" + get_calculated_field_instructions() if has_calculated_field else "" + ), + metric_instructions=(get_metric_instructions() if has_metric else ""), + json_field_instructions=( + get_json_field_instructions() if has_json_field else "" ), - metric_instructions=(metric_instructions if has_metric else ""), - json_field_instructions=(json_field_instructions if has_json_field else ""), sql_samples=sql_samples, sql_functions=sql_functions, ) @@ -157,9 +159,11 @@ def __init__( document_store_provider.get_store("project_meta") ) + self._llm_provider = llm_provider + self._components = { "generator": llm_provider.get_generator( - system_prompt=sql_generation_system_prompt, + system_prompt=get_sql_generation_system_prompt(), generation_kwargs=SQL_GENERATION_MODEL_KWARGS, ), "generator_name": llm_provider.get_model(), @@ -168,6 +172,7 @@ def __init__( ), "post_processor": SQLGenPostProcessor(engine=engine), } + print("get_sql_generation_system_prompt", get_sql_generation_system_prompt()) super().__init__( AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) @@ -197,6 +202,11 @@ async def run( else: metadata = {} + self._components["generator"] = self._llm_provider.get_generator( + system_prompt=get_sql_generation_system_prompt(), + generation_kwargs=SQL_GENERATION_MODEL_KWARGS, + ) + return await self._pipe.execute( ["post_process"], inputs={ diff --git a/wren-ai-service/src/pipelines/generation/sql_regeneration.py b/wren-ai-service/src/pipelines/generation/sql_regeneration.py index 0dbf2fd808..b6cebb530f 100644 --- a/wren-ai-service/src/pipelines/generation/sql_regeneration.py +++ b/wren-ai-service/src/pipelines/generation/sql_regeneration.py @@ -13,12 +13,12 @@ from src.pipelines.common import clean_up_new_lines from src.pipelines.generation.utils.sql import ( SQL_GENERATION_MODEL_KWARGS, - TEXT_TO_SQL_RULES, SQLGenPostProcessor, - calculated_field_instructions, construct_instructions, - json_field_instructions, - metric_instructions, + get_calculated_field_instructions, + get_json_field_instructions, + get_metric_instructions, + get_text_to_sql_rules, ) from src.pipelines.retrieval.sql_functions import SqlFunction from src.utils import trace_cost @@ -26,14 +26,17 @@ logger = logging.getLogger("wren-ai-service") -sql_regeneration_system_prompt = f""" +def get_sql_regeneration_system_prompt() -> str: + text_to_sql_rules = get_text_to_sql_rules() + + return f""" ### TASK ### You are a great ANSI SQL expert. Now you are given database schema, SQL generation reasoning and an original SQL query, please carefully review the reasoning, and then generate a new SQL query that matches the reasoning. While generating the new SQL query, you should use the original SQL query as a reference. While generating the new SQL query, make sure to use the database schema to generate the SQL query. -{TEXT_TO_SQL_RULES} +{text_to_sql_rules} ### FINAL ANSWER FORMAT ### The final answer must be a ANSI SQL query in JSON format: @@ -43,6 +46,7 @@ }} """ + sql_regeneration_user_prompt_template = """ ### DATABASE SCHEMA ### {% for document in documents %} @@ -115,10 +119,12 @@ def prompt( instructions=instructions, ), calculated_field_instructions=( - calculated_field_instructions if has_calculated_field else "" + get_calculated_field_instructions() if has_calculated_field else "" + ), + metric_instructions=(get_metric_instructions() if has_metric else ""), + json_field_instructions=( + get_json_field_instructions() if has_json_field else "" ), - metric_instructions=(metric_instructions if has_metric else ""), - json_field_instructions=(json_field_instructions if has_json_field else ""), sql_samples=sql_samples, sql_functions=sql_functions, ) @@ -157,9 +163,11 @@ def __init__( engine: Engine, **kwargs, ): + self._llm_provider = llm_provider + self._components = { "generator": llm_provider.get_generator( - system_prompt=sql_regeneration_system_prompt, + system_prompt=get_sql_regeneration_system_prompt(), generation_kwargs=SQL_GENERATION_MODEL_KWARGS, ), "generator_name": llm_provider.get_model(), @@ -188,6 +196,12 @@ async def run( sql_functions: list[SqlFunction] | None = None, ): logger.info("SQL Regeneration pipeline is running...") + + self._components["generator"] = self._llm_provider.get_generator( + system_prompt=get_sql_regeneration_system_prompt(), + generation_kwargs=SQL_GENERATION_MODEL_KWARGS, + ) + return await self._pipe.execute( ["post_process"], inputs={ diff --git a/wren-ai-service/src/pipelines/generation/utils/sql.py b/wren-ai-service/src/pipelines/generation/utils/sql.py index 0f1ddfc93c..c166920ab4 100644 --- a/wren-ai-service/src/pipelines/generation/utils/sql.py +++ b/wren-ai-service/src/pipelines/generation/utils/sql.py @@ -162,7 +162,7 @@ async def _classify_generation_result( return valid_generation_result, invalid_generation_result -TEXT_TO_SQL_RULES = """ +_DEFAULT_TEXT_TO_SQL_RULES = """ ### SQL RULES ### - ONLY USE SELECT statements, NO DELETE, UPDATE OR INSERT etc. statements that might change the data in the database. - ONLY USE the tables and columns mentioned in the database schema. @@ -222,7 +222,7 @@ async def _classify_generation_result( """ -calculated_field_instructions = """ +_DEFAULT_CALCULATED_FIELD_INSTRUCTIONS = """ #### Instructions for Calculated Field #### The first structure is the special column marked as "Calculated Field". You need to interpret the purpose and calculation basis for these columns, then utilize them in the following text-to-sql generation tasks. @@ -269,7 +269,7 @@ async def _classify_generation_result( SQL Query: SELECT AVG(Rating) FROM orders WHERE ReviewCount > 10 """ -metric_instructions = """ +_DEFAULT_METRIC_INSTRUCTIONS = """ #### Instructions for Metric #### Second, you will learn how to effectively utilize the special "metric" structure in text-to-SQL generation tasks. @@ -360,7 +360,7 @@ async def _classify_generation_result( PurchaseTimestamp < DATE_TRUNC('month', CURRENT_DATE) """ -json_field_instructions = """ +_DEFAULT_JSON_FIELD_INSTRUCTIONS = """ #### Instructions for JSON related functions #### - ONLY USE JSON_QUERY for querying fields if "json_type":"JSON" is identified in the columns comment, NOT the deprecated JSON_EXTRACT_SCALAR function. - DON'T USE CAST for JSON fields, ONLY USE the following funtions: @@ -422,29 +422,6 @@ async def _classify_generation_result( Learn about the usage of the schema structures and generate SQL based on them. """ -sql_generation_system_prompt = f""" -You are a helpful assistant that converts natural language queries into ANSI SQL queries. - -Given user's question, database schema, etc., you should think deeply and carefully and generate the SQL query based on the given reasoning plan step by step. - -### GENERAL RULES ### - -1. YOU MUST FOLLOW the instructions strictly to generate the SQL query if the section of USER INSTRUCTIONS is available in user's input. -2. YOU MUST ONLY CHOOSE the appropriate functions from the sql functions list and use them in the SQL query if the section of SQL FUNCTIONS is available in user's input. -3. YOU MUST REFER to the sql samples and learn the usage of the schema structures and how SQL is written based on them if the section of SQL SAMPLES is available in user's input. -4. YOU MUST FOLLOW the reasoning plan step by step strictly to generate the SQL query if the section of REASONING PLAN is available in user's input. -5. YOU MUST FOLLOW SQL Rules if they are not contradicted with instructions. - -{TEXT_TO_SQL_RULES} - -### FINAL ANSWER FORMAT ### -The final answer must be a ANSI SQL query in JSON format: - -{{ - "sql": -}} -""" - sql_generation_reasoning_system_prompt = """ ### TASK ### @@ -473,6 +450,113 @@ async def _classify_generation_result( """ +TEXT_TO_SQL_RULES: str | None = None +CALCULATED_FIELD_INSTRUCTIONS: str | None = None +METRIC_INSTRUCTIONS: str | None = None +JSON_FIELD_INSTRUCTIONS: str | None = None + + +def set_sql_knowledge(sql_knowledge=None): + global \ + TEXT_TO_SQL_RULES, \ + CALCULATED_FIELD_INSTRUCTIONS, \ + METRIC_INSTRUCTIONS, \ + JSON_FIELD_INSTRUCTIONS + + if sql_knowledge is not None: + from src.pipelines.retrieval.sql_knowledge import SqlKnowledge + + if isinstance(sql_knowledge, SqlKnowledge): + text_to_sql_rule = sql_knowledge.text_to_sql_rule + if text_to_sql_rule and text_to_sql_rule.strip(): + TEXT_TO_SQL_RULES = text_to_sql_rule + else: + TEXT_TO_SQL_RULES = _DEFAULT_TEXT_TO_SQL_RULES + + calculated_field_instruction = sql_knowledge.calculated_field_instructions + if calculated_field_instruction and calculated_field_instruction.strip(): + CALCULATED_FIELD_INSTRUCTIONS = calculated_field_instruction + else: + CALCULATED_FIELD_INSTRUCTIONS = _DEFAULT_CALCULATED_FIELD_INSTRUCTIONS + + metric_instruction = sql_knowledge.metric_instructions + if metric_instruction and metric_instruction.strip(): + METRIC_INSTRUCTIONS = metric_instruction + else: + METRIC_INSTRUCTIONS = _DEFAULT_METRIC_INSTRUCTIONS + + json_field_instruction = sql_knowledge.json_field_instructions + if json_field_instruction and json_field_instruction.strip(): + JSON_FIELD_INSTRUCTIONS = json_field_instruction + else: + JSON_FIELD_INSTRUCTIONS = _DEFAULT_JSON_FIELD_INSTRUCTIONS + return + + TEXT_TO_SQL_RULES = _DEFAULT_TEXT_TO_SQL_RULES + CALCULATED_FIELD_INSTRUCTIONS = _DEFAULT_CALCULATED_FIELD_INSTRUCTIONS + METRIC_INSTRUCTIONS = _DEFAULT_METRIC_INSTRUCTIONS + JSON_FIELD_INSTRUCTIONS = _DEFAULT_JSON_FIELD_INSTRUCTIONS + + +def get_text_to_sql_rules() -> str: + return ( + TEXT_TO_SQL_RULES + if TEXT_TO_SQL_RULES is not None + else _DEFAULT_TEXT_TO_SQL_RULES + ) + + +def get_calculated_field_instructions() -> str: + return ( + CALCULATED_FIELD_INSTRUCTIONS + if CALCULATED_FIELD_INSTRUCTIONS is not None + else _DEFAULT_CALCULATED_FIELD_INSTRUCTIONS + ) + + +def get_metric_instructions() -> str: + return ( + METRIC_INSTRUCTIONS + if METRIC_INSTRUCTIONS is not None + else _DEFAULT_METRIC_INSTRUCTIONS + ) + + +def get_json_field_instructions() -> str: + return ( + JSON_FIELD_INSTRUCTIONS + if JSON_FIELD_INSTRUCTIONS is not None + else _DEFAULT_JSON_FIELD_INSTRUCTIONS + ) + + +def get_sql_generation_system_prompt() -> str: + text_to_sql_rules = get_text_to_sql_rules() + + return f""" +You are a helpful assistant that converts natural language queries into ANSI SQL queries. + +Given user's question, database schema, etc., you should think deeply and carefully and generate the SQL query based on the given reasoning plan step by step. + +### GENERAL RULES ### + +1. YOU MUST FOLLOW the instructions strictly to generate the SQL query if the section of USER INSTRUCTIONS is available in user's input. +2. YOU MUST ONLY CHOOSE the appropriate functions from the sql functions list and use them in the SQL query if the section of SQL FUNCTIONS is available in user's input. +3. YOU MUST REFER to the sql samples and learn the usage of the schema structures and how SQL is written based on them if the section of SQL SAMPLES is available in user's input. +4. YOU MUST FOLLOW the reasoning plan step by step strictly to generate the SQL query if the section of REASONING PLAN is available in user's input. +5. YOU MUST FOLLOW SQL Rules if they are not contradicted with instructions. + +{text_to_sql_rules} + +### FINAL ANSWER FORMAT ### +The final answer must be a ANSI SQL query in JSON format: + +{{ + "sql": +}} +""" + + class SqlGenerationResult(BaseModel): sql: str From 7abf4b071a894611866f3346c14fb7b4fdf9c8c9 Mon Sep 17 00:00:00 2001 From: Steven Lu Date: Mon, 17 Nov 2025 16:54:30 +0800 Subject: [PATCH 08/18] Add SQL knowledge retrieval configuration to settings and service container This update introduces the `allow_sql_knowledge_retrieval` setting to the configuration, enabling the integration of SQL knowledge retrieval within the service container. This enhancement improves the overall functionality and data processing capabilities of the AI service. --- wren-ai-service/src/config.py | 1 + wren-ai-service/src/globals.py | 13 +++++++++++++ 2 files changed, 14 insertions(+) diff --git a/wren-ai-service/src/config.py b/wren-ai-service/src/config.py index 08ff7b7542..20637032a4 100644 --- a/wren-ai-service/src/config.py +++ b/wren-ai-service/src/config.py @@ -41,6 +41,7 @@ class Settings(BaseSettings): allow_sql_generation_reasoning: bool = Field(default=True) allow_sql_functions_retrieval: bool = Field(default=True) allow_sql_diagnosis: bool = Field(default=True) + allow_sql_knowledge_retrieval: bool = Field(default=True) max_histories: int = Field(default=5) max_sql_correction_retries: int = Field(default=3) diff --git a/wren-ai-service/src/globals.py b/wren-ai-service/src/globals.py index 92029c5082..24a562e81d 100644 --- a/wren-ai-service/src/globals.py +++ b/wren-ai-service/src/globals.py @@ -158,6 +158,7 @@ def create_service_container( allow_sql_generation_reasoning=settings.allow_sql_generation_reasoning, allow_sql_functions_retrieval=settings.allow_sql_functions_retrieval, allow_sql_diagnosis=settings.allow_sql_diagnosis, + allow_sql_knowledge_retrieval=settings.allow_sql_knowledge_retrieval, max_histories=settings.max_histories, enable_column_pruning=settings.enable_column_pruning, max_sql_correction_retries=settings.max_sql_correction_retries, @@ -174,9 +175,13 @@ def create_service_container( ), "sql_correction": _sql_correction_pipeline, "sql_diagnosis": _sql_diagnosis_pipeline, + "sql_knowledge_retrieval": retrieval.SqlKnowledges( + **pipe_components["sql_knowledge_retrieval"], + ), }, allow_sql_functions_retrieval=settings.allow_sql_functions_retrieval, allow_sql_diagnosis=settings.allow_sql_diagnosis, + allow_sql_knowledge_retrieval=settings.allow_sql_knowledge_retrieval, **query_cache, ), chart_service=services.ChartService( @@ -228,8 +233,12 @@ def create_service_container( "sql_pairs_retrieval": _sql_pair_retrieval_pipeline, "instructions_retrieval": _instructions_retrieval_pipeline, "sql_functions_retrieval": _sql_functions_retrieval_pipeline, + "sql_knowledge_retrieval": retrieval.SqlKnowledges( + **pipe_components["sql_knowledge_retrieval"], + ), }, allow_sql_functions_retrieval=settings.allow_sql_functions_retrieval, + allow_sql_knowledge_retrieval=settings.allow_sql_knowledge_retrieval, **query_cache, ), sql_pairs_service=services.SqlPairsService( @@ -259,7 +268,11 @@ def create_service_container( ), "db_schema_retrieval": _db_schema_retrieval_pipeline, "sql_correction": _sql_correction_pipeline, + "sql_knowledge_retrieval": retrieval.SqlKnowledges( + **pipe_components["sql_knowledge_retrieval"], + ), }, + allow_sql_knowledge_retrieval=settings.allow_sql_knowledge_retrieval, **query_cache, ), ) From b5029fb77ceba0229bb331a5239228cb2e5f7ace Mon Sep 17 00:00:00 2001 From: Steven Lu Date: Mon, 17 Nov 2025 16:58:20 +0800 Subject: [PATCH 09/18] Add support for SQL knowledge retrieval in multiple services This update introduces the `allow_sql_knowledge_retrieval` parameter across various services, including AskFeedbackService, AskService, QuestionRecommendation, and SqlCorrectionService. This enhancement allows for conditional execution of SQL knowledge retrieval, improving the flexibility and functionality of the AI service's data processing capabilities. --- wren-ai-service/src/web/v1/services/ask.py | 7 +++++++ wren-ai-service/src/web/v1/services/ask_feedback.py | 7 +++++++ .../src/web/v1/services/question_recommendation.py | 7 +++++++ wren-ai-service/src/web/v1/services/sql_corrections.py | 7 +++++++ 4 files changed, 28 insertions(+) diff --git a/wren-ai-service/src/web/v1/services/ask.py b/wren-ai-service/src/web/v1/services/ask.py index a6a7f8bff6..676467071f 100644 --- a/wren-ai-service/src/web/v1/services/ask.py +++ b/wren-ai-service/src/web/v1/services/ask.py @@ -102,6 +102,7 @@ def __init__( allow_sql_generation_reasoning: bool = True, allow_sql_functions_retrieval: bool = True, allow_sql_diagnosis: bool = True, + allow_sql_knowledge_retrieval: bool = True, enable_column_pruning: bool = False, max_sql_correction_retries: int = 3, max_histories: int = 5, @@ -116,6 +117,7 @@ def __init__( self._allow_sql_functions_retrieval = allow_sql_functions_retrieval self._allow_intent_classification = allow_intent_classification self._allow_sql_diagnosis = allow_sql_diagnosis + self._allow_sql_knowledge_retrieval = allow_sql_knowledge_retrieval self._enable_column_pruning = enable_column_pruning self._max_histories = max_histories self._max_sql_correction_retries = max_sql_correction_retries @@ -339,6 +341,11 @@ async def ask( is_followup=True if histories else False, ) + if self._allow_sql_knowledge_retrieval: + await self._pipelines["sql_knowledge_retrieval"].run( + project_id=ask_request.project_id, + ) + retrieval_result = await self._pipelines["db_schema_retrieval"].run( query=user_query, histories=histories, diff --git a/wren-ai-service/src/web/v1/services/ask_feedback.py b/wren-ai-service/src/web/v1/services/ask_feedback.py index 80e6fc809f..8ff35ac975 100644 --- a/wren-ai-service/src/web/v1/services/ask_feedback.py +++ b/wren-ai-service/src/web/v1/services/ask_feedback.py @@ -59,6 +59,7 @@ class AskFeedbackService: def __init__( self, pipelines: Dict[str, BasicPipeline], + allow_sql_knowledge_retrieval: bool = True, allow_sql_functions_retrieval: bool = True, allow_sql_diagnosis: bool = True, maxsize: int = 1_000_000, @@ -68,6 +69,7 @@ def __init__( self._ask_feedback_results: Dict[str, AskFeedbackResultResponse] = TTLCache( maxsize=maxsize, ttl=ttl ) + self._allow_sql_knowledge_retrieval = allow_sql_knowledge_retrieval self._allow_sql_functions_retrieval = allow_sql_functions_retrieval self._allow_sql_diagnosis = allow_sql_diagnosis @@ -161,6 +163,11 @@ async def ask_feedback( trace_id=trace_id, ) + if self._allow_sql_knowledge_retrieval: + await self._pipelines["sql_knowledge_retrieval"].run( + project_id=ask_feedback_request.project_id, + ) + text_to_sql_generation_results = await self._pipelines[ "sql_regeneration" ].run( diff --git a/wren-ai-service/src/web/v1/services/question_recommendation.py b/wren-ai-service/src/web/v1/services/question_recommendation.py index 24df1e3e62..75fb41f012 100644 --- a/wren-ai-service/src/web/v1/services/question_recommendation.py +++ b/wren-ai-service/src/web/v1/services/question_recommendation.py @@ -33,12 +33,14 @@ def __init__( allow_sql_functions_retrieval: bool = True, maxsize: int = 1_000_000, ttl: int = 120, + allow_sql_knowledge_retrieval: bool = True, ): self._pipelines = pipelines self._cache: Dict[str, QuestionRecommendation.Event] = TTLCache( maxsize=maxsize, ttl=ttl ) self._allow_sql_functions_retrieval = allow_sql_functions_retrieval + self._allow_sql_knowledge_retrieval = allow_sql_knowledge_retrieval def _handle_exception( self, @@ -112,6 +114,11 @@ async def _instructions_retrieval() -> list[dict]: else: sql_functions = [] + if self._allow_sql_knowledge_retrieval: + await self._pipelines["sql_knowledge_retrieval"].run( + project_id=project_id, + ) + generated_sql = await self._pipelines["sql_generation"].run( query=candidate["question"], contexts=table_ddls, diff --git a/wren-ai-service/src/web/v1/services/sql_corrections.py b/wren-ai-service/src/web/v1/services/sql_corrections.py index 116330ec57..b68bc3e214 100644 --- a/wren-ai-service/src/web/v1/services/sql_corrections.py +++ b/wren-ai-service/src/web/v1/services/sql_corrections.py @@ -31,9 +31,11 @@ def __init__( pipelines: dict[str, BasicPipeline], maxsize: int = 1_000_000, ttl: int = 120, + allow_sql_knowledge_retrieval: bool = True, ): self._pipelines = pipelines self._cache: dict[str, self.Event] = TTLCache(maxsize=maxsize, ttl=ttl) + self._allow_sql_knowledge_retrieval = allow_sql_knowledge_retrieval def _handle_exception( self, @@ -92,6 +94,11 @@ async def correct( ) )["post_process"] + if self._allow_sql_knowledge_retrieval: + await self._pipelines["sql_knowledge_retrieval"].run( + project_id=project_id, + ) + documents = ( ( await self._pipelines["db_schema_retrieval"].run( From 5f548dbe9fae7c3b1336a46417ca02f0b3359031 Mon Sep 17 00:00:00 2001 From: Steven Lu Date: Mon, 17 Nov 2025 17:02:50 +0800 Subject: [PATCH 10/18] Implement SQL knowledge retrieval pipeline and refactor SQL generation class This update introduces the `SqlKnowledge` class and the `SqlKnowledges` pipeline for enhanced SQL knowledge retrieval. Additionally, a print statement in the `SQLGeneration` class has been removed to streamline the code. These changes improve the modularity and functionality of the AI service's data processing capabilities. --- .../pipelines/generation/sql_generation.py | 1 - .../src/pipelines/retrieval/sql_knowledge.py | 140 ++++++++++++++++++ 2 files changed, 140 insertions(+), 1 deletion(-) create mode 100644 wren-ai-service/src/pipelines/retrieval/sql_knowledge.py diff --git a/wren-ai-service/src/pipelines/generation/sql_generation.py b/wren-ai-service/src/pipelines/generation/sql_generation.py index 7c09e6dee8..3869528255 100644 --- a/wren-ai-service/src/pipelines/generation/sql_generation.py +++ b/wren-ai-service/src/pipelines/generation/sql_generation.py @@ -172,7 +172,6 @@ def __init__( ), "post_processor": SQLGenPostProcessor(engine=engine), } - print("get_sql_generation_system_prompt", get_sql_generation_system_prompt()) super().__init__( AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) diff --git a/wren-ai-service/src/pipelines/retrieval/sql_knowledge.py b/wren-ai-service/src/pipelines/retrieval/sql_knowledge.py new file mode 100644 index 0000000000..aff7712825 --- /dev/null +++ b/wren-ai-service/src/pipelines/retrieval/sql_knowledge.py @@ -0,0 +1,140 @@ +import logging +import sys +from typing import Dict, Optional + +import aiohttp +from cachetools import TTLCache +from hamilton import base +from hamilton.async_driver import AsyncDriver +from langfuse.decorators import observe + +from src.core.engine import Engine +from src.core.pipeline import BasicPipeline +from src.core.provider import DocumentStoreProvider +from src.pipelines.common import retrieve_metadata +from src.pipelines.generation.utils.sql import set_sql_knowledge +from src.providers.engine.wren import WrenIbis + +logger = logging.getLogger("wren-ai-service") + + +class SqlKnowledge: + _data: Dict = None + + def __init__(self, sql_knowledge: dict): + self._data = sql_knowledge + + @classmethod + def empty(cls, sql_knowledge: dict): + return ( + not sql_knowledge + or not sql_knowledge.get("text_to_sql_rule") + or not sql_knowledge.get("instructions") + ) + + @property + def text_to_sql_rule(self) -> str: + return self._data.get("text_to_sql_rule", "") + + @property + def instructions(self) -> dict: + return self._data.get("instructions", {}) + + @property + def calculated_field_instructions(self) -> str: + return self.instructions.get("calculated_field_instructions", "") + + @property + def metric_instructions(self) -> str: + return self.instructions.get("metric_instructions", "") + + @property + def json_field_instructions(self) -> str: + return self.instructions.get("json_field_instructions", "") + + def __str__(self): + return f"text_to_sql_rule: {self.text_to_sql_rule}, instructions: {self.instructions}" + + def __repr__(self): + return self.__str__() + + +## Start of Pipeline +@observe(capture_input=False) +async def get_knowledge( + engine: WrenIbis, + data_source: str, +) -> Optional[SqlKnowledge]: + async with aiohttp.ClientSession() as session: + knowledge_dict = await engine.get_sql_knowledge( + session=session, + data_source=data_source, + ) + + if not knowledge_dict or SqlKnowledge.empty(knowledge_dict): + return None + + return SqlKnowledge(sql_knowledge=knowledge_dict) + + +@observe(capture_input=False) +def cache( + data_source: str, + get_knowledge: Optional[SqlKnowledge], + ttl_cache: TTLCache, +) -> Optional[SqlKnowledge]: + if get_knowledge: + ttl_cache[data_source] = get_knowledge + set_sql_knowledge(get_knowledge) + + return + + +## End of Pipeline + + +class SqlKnowledges(BasicPipeline): + def __init__( + self, + engine: Engine, + document_store_provider: DocumentStoreProvider, + ttl: int = 60 * 60 * 24, + **kwargs, + ) -> None: + self._retriever = document_store_provider.get_retriever( + document_store_provider.get_store("project_meta") + ) + self._cache = TTLCache(maxsize=100, ttl=ttl) + self._components = { + "engine": engine, + "ttl_cache": self._cache, + } + + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + + @observe(name="SQL Knowledge Retrieval") + async def run( + self, + project_id: Optional[str] = None, + ) -> Optional[SqlKnowledge]: + logger.info( + f"Project ID: {project_id} SQL Knowledge Retrieval pipeline is running..." + ) + + metadata = await retrieve_metadata(project_id or "", self._retriever) + _data_source = metadata.get("data_source", "local_file") + + if _data_source in self._cache: + logger.info(f"Hit cache of SQL Knowledge for {_data_source}") + set_sql_knowledge(self._cache[_data_source]) + return self._cache[_data_source] + + input = { + "data_source": _data_source, + "project_id": project_id, + **self._components, + } + result = await self._pipe.execute(["cache"], inputs=input) + return result["cache"] From 25b6bd007170ccb8a0ff1d6d9d3ec6ef5678c9db Mon Sep 17 00:00:00 2001 From: Steven Lu Date: Mon, 17 Nov 2025 17:08:40 +0800 Subject: [PATCH 11/18] Remove generator initialization from SQL generation classes to streamline component setup. This change enhances modularity by focusing on the model name and prompt builder, improving maintainability across the SQL generation pipeline. --- .../src/pipelines/generation/followup_sql_generation.py | 4 ---- wren-ai-service/src/pipelines/generation/sql_correction.py | 4 ---- wren-ai-service/src/pipelines/generation/sql_generation.py | 4 ---- wren-ai-service/src/pipelines/generation/sql_regeneration.py | 4 ---- 4 files changed, 16 deletions(-) diff --git a/wren-ai-service/src/pipelines/generation/followup_sql_generation.py b/wren-ai-service/src/pipelines/generation/followup_sql_generation.py index 3b9493447c..6e4a35866d 100644 --- a/wren-ai-service/src/pipelines/generation/followup_sql_generation.py +++ b/wren-ai-service/src/pipelines/generation/followup_sql_generation.py @@ -165,10 +165,6 @@ def __init__( self._llm_provider = llm_provider self._components = { - "generator": llm_provider.get_generator( - system_prompt=get_sql_generation_system_prompt(), - generation_kwargs=SQL_GENERATION_MODEL_KWARGS, - ), "generator_name": llm_provider.get_model(), "prompt_builder": PromptBuilder( template=text_to_sql_with_followup_user_prompt_template diff --git a/wren-ai-service/src/pipelines/generation/sql_correction.py b/wren-ai-service/src/pipelines/generation/sql_correction.py index e1d6180d66..f5123c616d 100644 --- a/wren-ai-service/src/pipelines/generation/sql_correction.py +++ b/wren-ai-service/src/pipelines/generation/sql_correction.py @@ -143,10 +143,6 @@ def __init__( self._llm_provider = llm_provider self._components = { - "generator": llm_provider.get_generator( - system_prompt=get_sql_correction_system_prompt(), - generation_kwargs=SQL_GENERATION_MODEL_KWARGS, - ), "generator_name": llm_provider.get_model(), "prompt_builder": PromptBuilder( template=sql_correction_user_prompt_template diff --git a/wren-ai-service/src/pipelines/generation/sql_generation.py b/wren-ai-service/src/pipelines/generation/sql_generation.py index 3869528255..7a7b66f080 100644 --- a/wren-ai-service/src/pipelines/generation/sql_generation.py +++ b/wren-ai-service/src/pipelines/generation/sql_generation.py @@ -162,10 +162,6 @@ def __init__( self._llm_provider = llm_provider self._components = { - "generator": llm_provider.get_generator( - system_prompt=get_sql_generation_system_prompt(), - generation_kwargs=SQL_GENERATION_MODEL_KWARGS, - ), "generator_name": llm_provider.get_model(), "prompt_builder": PromptBuilder( template=sql_generation_user_prompt_template diff --git a/wren-ai-service/src/pipelines/generation/sql_regeneration.py b/wren-ai-service/src/pipelines/generation/sql_regeneration.py index b6cebb530f..eb99492e66 100644 --- a/wren-ai-service/src/pipelines/generation/sql_regeneration.py +++ b/wren-ai-service/src/pipelines/generation/sql_regeneration.py @@ -166,10 +166,6 @@ def __init__( self._llm_provider = llm_provider self._components = { - "generator": llm_provider.get_generator( - system_prompt=get_sql_regeneration_system_prompt(), - generation_kwargs=SQL_GENERATION_MODEL_KWARGS, - ), "generator_name": llm_provider.get_model(), "prompt_builder": PromptBuilder( template=sql_regeneration_user_prompt_template From a129c8f6f9c70aee110eb06a318a6cae769c9df4 Mon Sep 17 00:00:00 2001 From: Steven Lu Date: Mon, 17 Nov 2025 23:36:44 +0800 Subject: [PATCH 12/18] Enhance SQL generation and correction pipelines with SqlKnowledge integration This update modifies the SQL generation, correction, and regeneration classes to utilize the `SqlKnowledge` class for improved instruction retrieval. The changes include passing `sql_knowledge` to various functions and methods, ensuring that SQL-related instructions are dynamically fetched based on the available knowledge. This enhancement improves the flexibility and accuracy of SQL query generation and correction processes across the AI service. --- .../generation/followup_sql_generation.py | 16 +++- .../pipelines/generation/sql_correction.py | 8 +- .../pipelines/generation/sql_generation.py | 21 +++-- .../pipelines/generation/sql_regeneration.py | 22 ++++-- .../src/pipelines/generation/utils/sql.py | 79 ++++++++----------- .../src/pipelines/retrieval/sql_knowledge.py | 5 +- wren-ai-service/src/web/v1/services/ask.py | 17 ++-- .../src/web/v1/services/ask_feedback.py | 16 ++-- .../v1/services/question_recommendation.py | 5 +- .../src/web/v1/services/sql_corrections.py | 4 +- 10 files changed, 115 insertions(+), 78 deletions(-) diff --git a/wren-ai-service/src/pipelines/generation/followup_sql_generation.py b/wren-ai-service/src/pipelines/generation/followup_sql_generation.py index 6e4a35866d..ebc4a1dffe 100644 --- a/wren-ai-service/src/pipelines/generation/followup_sql_generation.py +++ b/wren-ai-service/src/pipelines/generation/followup_sql_generation.py @@ -22,6 +22,7 @@ get_sql_generation_system_prompt, ) from src.pipelines.retrieval.sql_functions import SqlFunction +from src.pipelines.retrieval.sql_knowledge import SqlKnowledge from src.utils import trace_cost from src.web.v1.services.ask import AskHistory @@ -97,6 +98,7 @@ def prompt( has_metric: bool = False, has_json_field: bool = False, sql_functions: list[SqlFunction] | None = None, + sql_knowledge: SqlKnowledge | None = None, ) -> dict: _prompt = prompt_builder.run( query=query, @@ -106,11 +108,15 @@ def prompt( instructions=instructions, ), calculated_field_instructions=( - get_calculated_field_instructions() if has_calculated_field else "" + get_calculated_field_instructions(sql_knowledge) + if has_calculated_field + else "" + ), + metric_instructions=( + get_metric_instructions(sql_knowledge) if has_metric else "" ), - metric_instructions=(get_metric_instructions() if has_metric else ""), json_field_instructions=( - get_json_field_instructions() if has_json_field else "" + get_json_field_instructions(sql_knowledge) if has_json_field else "" ), sql_samples=sql_samples, sql_functions=sql_functions, @@ -192,6 +198,7 @@ async def run( sql_functions: list[SqlFunction] | None = None, use_dry_plan: bool = False, allow_dry_plan_fallback: bool = True, + sql_knowledge: SqlKnowledge | None = None, ): logger.info("Follow-Up SQL Generation pipeline is running...") @@ -201,7 +208,7 @@ async def run( metadata = {} self._components["generator"] = self._llm_provider.get_generator( - system_prompt=get_sql_generation_system_prompt(), + system_prompt=get_sql_generation_system_prompt(sql_knowledge), generation_kwargs=SQL_GENERATION_MODEL_KWARGS, ) @@ -222,6 +229,7 @@ async def run( "use_dry_plan": use_dry_plan, "allow_dry_plan_fallback": allow_dry_plan_fallback, "data_source": metadata.get("data_source", "local_file"), + "sql_knowledge": sql_knowledge, **self._components, }, ) diff --git a/wren-ai-service/src/pipelines/generation/sql_correction.py b/wren-ai-service/src/pipelines/generation/sql_correction.py index f5123c616d..0d50a628f4 100644 --- a/wren-ai-service/src/pipelines/generation/sql_correction.py +++ b/wren-ai-service/src/pipelines/generation/sql_correction.py @@ -19,13 +19,14 @@ get_text_to_sql_rules, ) from src.pipelines.retrieval.sql_functions import SqlFunction +from src.pipelines.retrieval.sql_knowledge import SqlKnowledge from src.utils import trace_cost logger = logging.getLogger("wren-ai-service") -def get_sql_correction_system_prompt() -> str: - text_to_sql_rules = get_text_to_sql_rules() +def get_sql_correction_system_prompt(sql_knowledge: SqlKnowledge | None = None) -> str: + text_to_sql_rules = get_text_to_sql_rules(sql_knowledge) return f""" ### TASK ### @@ -164,11 +165,12 @@ async def run( project_id: str | None = None, use_dry_plan: bool = False, allow_dry_plan_fallback: bool = True, + sql_knowledge: SqlKnowledge | None = None, ): logger.info("SQLCorrection pipeline is running...") self._components["generator"] = self._llm_provider.get_generator( - system_prompt=get_sql_correction_system_prompt(), + system_prompt=get_sql_correction_system_prompt(sql_knowledge), generation_kwargs=SQL_GENERATION_MODEL_KWARGS, ) diff --git a/wren-ai-service/src/pipelines/generation/sql_generation.py b/wren-ai-service/src/pipelines/generation/sql_generation.py index 7a7b66f080..0a726119da 100644 --- a/wren-ai-service/src/pipelines/generation/sql_generation.py +++ b/wren-ai-service/src/pipelines/generation/sql_generation.py @@ -21,6 +21,7 @@ get_sql_generation_system_prompt, ) from src.pipelines.retrieval.sql_functions import SqlFunction +from src.pipelines.retrieval.sql_knowledge import SqlKnowledge from src.utils import trace_cost logger = logging.getLogger("wren-ai-service") @@ -93,6 +94,7 @@ def prompt( has_metric: bool = False, has_json_field: bool = False, sql_functions: list[SqlFunction] | None = None, + sql_knowledge: SqlKnowledge | None = None, ) -> dict: _prompt = prompt_builder.run( query=query, @@ -102,11 +104,15 @@ def prompt( instructions=instructions, ), calculated_field_instructions=( - get_calculated_field_instructions() if has_calculated_field else "" + get_calculated_field_instructions(sql_knowledge) + if has_calculated_field + else "" + ), + metric_instructions=( + get_metric_instructions(sql_knowledge) if has_metric else "" ), - metric_instructions=(get_metric_instructions() if has_metric else ""), json_field_instructions=( - get_json_field_instructions() if has_json_field else "" + get_json_field_instructions(sql_knowledge) if has_json_field else "" ), sql_samples=sql_samples, sql_functions=sql_functions, @@ -189,6 +195,7 @@ async def run( use_dry_plan: bool = False, allow_dry_plan_fallback: bool = True, allow_data_preview: bool = False, + sql_knowledge: SqlKnowledge | None = None, ): logger.info("SQL Generation pipeline is running...") @@ -198,10 +205,13 @@ async def run( metadata = {} self._components["generator"] = self._llm_provider.get_generator( - system_prompt=get_sql_generation_system_prompt(), + system_prompt=get_sql_generation_system_prompt(sql_knowledge), generation_kwargs=SQL_GENERATION_MODEL_KWARGS, ) - + print( + "get_sql_generation_system_prompt(sql_knowledge):", + get_sql_generation_system_prompt(sql_knowledge), + ) return await self._pipe.execute( ["post_process"], inputs={ @@ -219,6 +229,7 @@ async def run( "allow_dry_plan_fallback": allow_dry_plan_fallback, "data_source": metadata.get("data_source", "local_file"), "allow_data_preview": allow_data_preview, + "sql_knowledge": sql_knowledge, **self._components, }, ) diff --git a/wren-ai-service/src/pipelines/generation/sql_regeneration.py b/wren-ai-service/src/pipelines/generation/sql_regeneration.py index eb99492e66..5e074016a3 100644 --- a/wren-ai-service/src/pipelines/generation/sql_regeneration.py +++ b/wren-ai-service/src/pipelines/generation/sql_regeneration.py @@ -21,13 +21,16 @@ get_text_to_sql_rules, ) from src.pipelines.retrieval.sql_functions import SqlFunction +from src.pipelines.retrieval.sql_knowledge import SqlKnowledge from src.utils import trace_cost logger = logging.getLogger("wren-ai-service") -def get_sql_regeneration_system_prompt() -> str: - text_to_sql_rules = get_text_to_sql_rules() +def get_sql_regeneration_system_prompt( + sql_knowledge: SqlKnowledge | None = None, +) -> str: + text_to_sql_rules = get_text_to_sql_rules(sql_knowledge) return f""" ### TASK ### @@ -110,6 +113,7 @@ def prompt( has_metric: bool = False, has_json_field: bool = False, sql_functions: list[SqlFunction] | None = None, + sql_knowledge: SqlKnowledge | None = None, ) -> dict: _prompt = prompt_builder.run( sql=sql, @@ -119,11 +123,15 @@ def prompt( instructions=instructions, ), calculated_field_instructions=( - get_calculated_field_instructions() if has_calculated_field else "" + get_calculated_field_instructions(sql_knowledge) + if has_calculated_field + else "" + ), + metric_instructions=( + get_metric_instructions(sql_knowledge) if has_metric else "" ), - metric_instructions=(get_metric_instructions() if has_metric else ""), json_field_instructions=( - get_json_field_instructions() if has_json_field else "" + get_json_field_instructions(sql_knowledge) if has_json_field else "" ), sql_samples=sql_samples, sql_functions=sql_functions, @@ -190,11 +198,12 @@ async def run( has_metric: bool = False, has_json_field: bool = False, sql_functions: list[SqlFunction] | None = None, + sql_knowledge: SqlKnowledge | None = None, ): logger.info("SQL Regeneration pipeline is running...") self._components["generator"] = self._llm_provider.get_generator( - system_prompt=get_sql_regeneration_system_prompt(), + system_prompt=get_sql_regeneration_system_prompt(sql_knowledge), generation_kwargs=SQL_GENERATION_MODEL_KWARGS, ) @@ -211,6 +220,7 @@ async def run( "has_metric": has_metric, "has_json_field": has_json_field, "sql_functions": sql_functions, + "sql_knowledge": sql_knowledge, **self._components, }, ) diff --git a/wren-ai-service/src/pipelines/generation/utils/sql.py b/wren-ai-service/src/pipelines/generation/utils/sql.py index c166920ab4..dd427ab9c9 100644 --- a/wren-ai-service/src/pipelines/generation/utils/sql.py +++ b/wren-ai-service/src/pipelines/generation/utils/sql.py @@ -11,6 +11,7 @@ Engine, clean_generation_result, ) +from src.pipelines.retrieval.sql_knowledge import SqlKnowledge from src.web.v1.services.ask import AskHistory logger = logging.getLogger("wren-ai-service") @@ -456,49 +457,22 @@ async def _classify_generation_result( JSON_FIELD_INSTRUCTIONS: str | None = None -def set_sql_knowledge(sql_knowledge=None): - global \ - TEXT_TO_SQL_RULES, \ - CALCULATED_FIELD_INSTRUCTIONS, \ - METRIC_INSTRUCTIONS, \ - JSON_FIELD_INSTRUCTIONS - - if sql_knowledge is not None: - from src.pipelines.retrieval.sql_knowledge import SqlKnowledge +def _extract_from_sql_knowledge( + sql_knowledge: SqlKnowledge | None, attribute_name: str, default_value: str +) -> str: + if sql_knowledge is None: + return default_value - if isinstance(sql_knowledge, SqlKnowledge): - text_to_sql_rule = sql_knowledge.text_to_sql_rule - if text_to_sql_rule and text_to_sql_rule.strip(): - TEXT_TO_SQL_RULES = text_to_sql_rule - else: - TEXT_TO_SQL_RULES = _DEFAULT_TEXT_TO_SQL_RULES + value = getattr(sql_knowledge, attribute_name, "") + return value if value and value.strip() else default_value - calculated_field_instruction = sql_knowledge.calculated_field_instructions - if calculated_field_instruction and calculated_field_instruction.strip(): - CALCULATED_FIELD_INSTRUCTIONS = calculated_field_instruction - else: - CALCULATED_FIELD_INSTRUCTIONS = _DEFAULT_CALCULATED_FIELD_INSTRUCTIONS - - metric_instruction = sql_knowledge.metric_instructions - if metric_instruction and metric_instruction.strip(): - METRIC_INSTRUCTIONS = metric_instruction - else: - METRIC_INSTRUCTIONS = _DEFAULT_METRIC_INSTRUCTIONS - - json_field_instruction = sql_knowledge.json_field_instructions - if json_field_instruction and json_field_instruction.strip(): - JSON_FIELD_INSTRUCTIONS = json_field_instruction - else: - JSON_FIELD_INSTRUCTIONS = _DEFAULT_JSON_FIELD_INSTRUCTIONS - return - - TEXT_TO_SQL_RULES = _DEFAULT_TEXT_TO_SQL_RULES - CALCULATED_FIELD_INSTRUCTIONS = _DEFAULT_CALCULATED_FIELD_INSTRUCTIONS - METRIC_INSTRUCTIONS = _DEFAULT_METRIC_INSTRUCTIONS - JSON_FIELD_INSTRUCTIONS = _DEFAULT_JSON_FIELD_INSTRUCTIONS +def get_text_to_sql_rules(sql_knowledge: SqlKnowledge | None = None) -> str: + if sql_knowledge is not None: + return _extract_from_sql_knowledge( + sql_knowledge, "text_to_sql_rule", _DEFAULT_TEXT_TO_SQL_RULES + ) -def get_text_to_sql_rules() -> str: return ( TEXT_TO_SQL_RULES if TEXT_TO_SQL_RULES is not None @@ -506,7 +480,14 @@ def get_text_to_sql_rules() -> str: ) -def get_calculated_field_instructions() -> str: +def get_calculated_field_instructions(sql_knowledge: SqlKnowledge | None = None) -> str: + if sql_knowledge is not None: + return _extract_from_sql_knowledge( + sql_knowledge, + "calculated_field_instructions", + _DEFAULT_CALCULATED_FIELD_INSTRUCTIONS, + ) + return ( CALCULATED_FIELD_INSTRUCTIONS if CALCULATED_FIELD_INSTRUCTIONS is not None @@ -514,7 +495,12 @@ def get_calculated_field_instructions() -> str: ) -def get_metric_instructions() -> str: +def get_metric_instructions(sql_knowledge: SqlKnowledge | None = None) -> str: + if sql_knowledge is not None: + return _extract_from_sql_knowledge( + sql_knowledge, "metric_instructions", _DEFAULT_METRIC_INSTRUCTIONS + ) + return ( METRIC_INSTRUCTIONS if METRIC_INSTRUCTIONS is not None @@ -522,7 +508,12 @@ def get_metric_instructions() -> str: ) -def get_json_field_instructions() -> str: +def get_json_field_instructions(sql_knowledge: SqlKnowledge | None = None) -> str: + if sql_knowledge is not None: + return _extract_from_sql_knowledge( + sql_knowledge, "json_field_instructions", _DEFAULT_JSON_FIELD_INSTRUCTIONS + ) + return ( JSON_FIELD_INSTRUCTIONS if JSON_FIELD_INSTRUCTIONS is not None @@ -530,8 +521,8 @@ def get_json_field_instructions() -> str: ) -def get_sql_generation_system_prompt() -> str: - text_to_sql_rules = get_text_to_sql_rules() +def get_sql_generation_system_prompt(sql_knowledge: SqlKnowledge | None = None) -> str: + text_to_sql_rules = get_text_to_sql_rules(sql_knowledge) return f""" You are a helpful assistant that converts natural language queries into ANSI SQL queries. diff --git a/wren-ai-service/src/pipelines/retrieval/sql_knowledge.py b/wren-ai-service/src/pipelines/retrieval/sql_knowledge.py index aff7712825..a44d8cb98b 100644 --- a/wren-ai-service/src/pipelines/retrieval/sql_knowledge.py +++ b/wren-ai-service/src/pipelines/retrieval/sql_knowledge.py @@ -12,7 +12,6 @@ from src.core.pipeline import BasicPipeline from src.core.provider import DocumentStoreProvider from src.pipelines.common import retrieve_metadata -from src.pipelines.generation.utils.sql import set_sql_knowledge from src.providers.engine.wren import WrenIbis logger = logging.getLogger("wren-ai-service") @@ -85,9 +84,8 @@ def cache( ) -> Optional[SqlKnowledge]: if get_knowledge: ttl_cache[data_source] = get_knowledge - set_sql_knowledge(get_knowledge) - return + return get_knowledge ## End of Pipeline @@ -128,7 +126,6 @@ async def run( if _data_source in self._cache: logger.info(f"Hit cache of SQL Knowledge for {_data_source}") - set_sql_knowledge(self._cache[_data_source]) return self._cache[_data_source] input = { diff --git a/wren-ai-service/src/web/v1/services/ask.py b/wren-ai-service/src/web/v1/services/ask.py index 676467071f..aa26fa3f81 100644 --- a/wren-ai-service/src/web/v1/services/ask.py +++ b/wren-ai-service/src/web/v1/services/ask.py @@ -170,10 +170,12 @@ async def ask( ) allow_sql_functions_retrieval = self._allow_sql_functions_retrieval allow_sql_diagnosis = self._allow_sql_diagnosis + allow_sql_knowledge_retrieval = self._allow_sql_knowledge_retrieval max_sql_correction_retries = self._max_sql_correction_retries current_sql_correction_retries = 0 use_dry_plan = ask_request.use_dry_plan allow_dry_plan_fallback = ask_request.allow_dry_plan_fallback + sql_knowledge = None try: user_query = ask_request.query @@ -341,11 +343,6 @@ async def ask( is_followup=True if histories else False, ) - if self._allow_sql_knowledge_retrieval: - await self._pipelines["sql_knowledge_retrieval"].run( - project_id=ask_request.project_id, - ) - retrieval_result = await self._pipelines["db_schema_retrieval"].run( query=user_query, histories=histories, @@ -449,6 +446,13 @@ async def ask( else: sql_functions = [] + if allow_sql_knowledge_retrieval: + sql_knowledge = await self._pipelines[ + "sql_knowledge_retrieval" + ].run( + project_id=ask_request.project_id, + ) + has_calculated_field = _retrieval_result.get( "has_calculated_field", False ) @@ -472,6 +476,7 @@ async def ask( sql_functions=sql_functions, use_dry_plan=use_dry_plan, allow_dry_plan_fallback=allow_dry_plan_fallback, + sql_knowledge=sql_knowledge, ) else: text_to_sql_generation_results = await self._pipelines[ @@ -489,6 +494,7 @@ async def ask( sql_functions=sql_functions, use_dry_plan=use_dry_plan, allow_dry_plan_fallback=allow_dry_plan_fallback, + sql_knowledge=sql_knowledge, ) if sql_valid_result := text_to_sql_generation_results["post_process"][ @@ -554,6 +560,7 @@ async def ask( use_dry_plan=use_dry_plan, allow_dry_plan_fallback=allow_dry_plan_fallback, sql_functions=sql_functions, + sql_knowledge=sql_knowledge, ) if valid_generation_result := sql_correction_results[ diff --git a/wren-ai-service/src/web/v1/services/ask_feedback.py b/wren-ai-service/src/web/v1/services/ask_feedback.py index 8ff35ac975..9c5b06a772 100644 --- a/wren-ai-service/src/web/v1/services/ask_feedback.py +++ b/wren-ai-service/src/web/v1/services/ask_feedback.py @@ -104,6 +104,8 @@ async def ask_feedback( api_results = [] error_message = None invalid_sql = None + sql_knowledge = None + allow_sql_knowledge_retrieval = self._allow_sql_knowledge_retrieval try: if not self._is_stopped(query_id, self._ask_feedback_results): @@ -141,6 +143,13 @@ async def ask_feedback( else: sql_functions = [] + if allow_sql_knowledge_retrieval: + sql_knowledge = await self._pipelines[ + "sql_knowledge_retrieval" + ].run( + project_id=ask_feedback_request.project_id, + ) + # Extract results from completed tasks _retrieval_result = retrieval_task.get( "construct_retrieval_results", {} @@ -163,11 +172,6 @@ async def ask_feedback( trace_id=trace_id, ) - if self._allow_sql_knowledge_retrieval: - await self._pipelines["sql_knowledge_retrieval"].run( - project_id=ask_feedback_request.project_id, - ) - text_to_sql_generation_results = await self._pipelines[ "sql_regeneration" ].run( @@ -181,6 +185,7 @@ async def ask_feedback( has_metric=has_metric, has_json_field=has_json_field, sql_functions=sql_functions, + sql_knowledge=sql_knowledge, ) if sql_valid_result := text_to_sql_generation_results["post_process"][ @@ -235,6 +240,7 @@ async def ask_feedback( }, project_id=ask_feedback_request.project_id, sql_functions=sql_functions, + sql_knowledge=sql_knowledge, ) if valid_generation_result := sql_correction_results[ diff --git a/wren-ai-service/src/web/v1/services/question_recommendation.py b/wren-ai-service/src/web/v1/services/question_recommendation.py index 75fb41f012..6033237a45 100644 --- a/wren-ai-service/src/web/v1/services/question_recommendation.py +++ b/wren-ai-service/src/web/v1/services/question_recommendation.py @@ -115,9 +115,11 @@ async def _instructions_retrieval() -> list[dict]: sql_functions = [] if self._allow_sql_knowledge_retrieval: - await self._pipelines["sql_knowledge_retrieval"].run( + sql_knowledge = await self._pipelines["sql_knowledge_retrieval"].run( project_id=project_id, ) + else: + sql_knowledge = None generated_sql = await self._pipelines["sql_generation"].run( query=candidate["question"], @@ -130,6 +132,7 @@ async def _instructions_retrieval() -> list[dict]: has_json_field=has_json_field, sql_functions=sql_functions, allow_data_preview=allow_data_preview, + sql_knowledge=sql_knowledge, ) post_process = generated_sql["post_process"] diff --git a/wren-ai-service/src/web/v1/services/sql_corrections.py b/wren-ai-service/src/web/v1/services/sql_corrections.py index b68bc3e214..86d0f55301 100644 --- a/wren-ai-service/src/web/v1/services/sql_corrections.py +++ b/wren-ai-service/src/web/v1/services/sql_corrections.py @@ -80,6 +80,7 @@ async def correct( retrieved_tables = request.retrieved_tables use_dry_plan = request.use_dry_plan allow_dry_plan_fallback = request.allow_dry_plan_fallback + sql_knowledge = None try: _invalid = { @@ -95,7 +96,7 @@ async def correct( )["post_process"] if self._allow_sql_knowledge_retrieval: - await self._pipelines["sql_knowledge_retrieval"].run( + sql_knowledge = await self._pipelines["sql_knowledge_retrieval"].run( project_id=project_id, ) @@ -117,6 +118,7 @@ async def correct( project_id=project_id, use_dry_plan=use_dry_plan, allow_dry_plan_fallback=allow_dry_plan_fallback, + sql_knowledge=sql_knowledge, ) post_process = res["post_process"] From fceeb3eff989f693113458e75b52177817a91dd5 Mon Sep 17 00:00:00 2001 From: Steven Lu Date: Tue, 18 Nov 2025 00:02:48 +0800 Subject: [PATCH 13/18] Refactor SQLGeneration class to remove print statement and enhance code clarity. Update SqlKnowledge class to initialize _data directly from the constructor, improving type consistency. --- wren-ai-service/src/pipelines/generation/sql_generation.py | 5 +---- wren-ai-service/src/pipelines/retrieval/sql_knowledge.py | 4 +--- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/wren-ai-service/src/pipelines/generation/sql_generation.py b/wren-ai-service/src/pipelines/generation/sql_generation.py index 0a726119da..e74fa966e9 100644 --- a/wren-ai-service/src/pipelines/generation/sql_generation.py +++ b/wren-ai-service/src/pipelines/generation/sql_generation.py @@ -208,10 +208,7 @@ async def run( system_prompt=get_sql_generation_system_prompt(sql_knowledge), generation_kwargs=SQL_GENERATION_MODEL_KWARGS, ) - print( - "get_sql_generation_system_prompt(sql_knowledge):", - get_sql_generation_system_prompt(sql_knowledge), - ) + return await self._pipe.execute( ["post_process"], inputs={ diff --git a/wren-ai-service/src/pipelines/retrieval/sql_knowledge.py b/wren-ai-service/src/pipelines/retrieval/sql_knowledge.py index a44d8cb98b..167969047c 100644 --- a/wren-ai-service/src/pipelines/retrieval/sql_knowledge.py +++ b/wren-ai-service/src/pipelines/retrieval/sql_knowledge.py @@ -18,10 +18,8 @@ class SqlKnowledge: - _data: Dict = None - def __init__(self, sql_knowledge: dict): - self._data = sql_knowledge + self._data: Dict = sql_knowledge @classmethod def empty(cls, sql_knowledge: dict): From aa1b5cc3fb883f3170df624e66bdec1e2da8bb38 Mon Sep 17 00:00:00 2001 From: ChihYu Yeh Date: Tue, 18 Nov 2025 16:30:06 +0800 Subject: [PATCH 14/18] update --- wren-ai-service/tools/dev/.env | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wren-ai-service/tools/dev/.env b/wren-ai-service/tools/dev/.env index 0b6825a77f..299cd0d30b 100644 --- a/wren-ai-service/tools/dev/.env +++ b/wren-ai-service/tools/dev/.env @@ -12,7 +12,7 @@ IBIS_SERVER_PORT=8000 # CHANGE THIS TO THE LATEST VERSION WREN_PRODUCT_VERSION=development WREN_ENGINE_VERSION=0.21.3 -WREN_AI_SERVICE_VERSION=0.27.14 +WREN_AI_SERVICE_VERSION=0.29.0 IBIS_SERVER_VERSION=0.21.3 WREN_UI_VERSION=0.31.3 WREN_BOOTSTRAP_VERSION=0.1.5 From 844a6b52ddc20a6d0bf5fa812b31cb38a2916da4 Mon Sep 17 00:00:00 2001 From: ChihYu Yeh Date: Tue, 18 Nov 2025 17:49:16 +0800 Subject: [PATCH 15/18] update --- wren-ai-service/src/config.py | 2 +- .../src/pipelines/generation/utils/sql.py | 30 +++---------------- .../tools/config/config.example.yaml | 3 ++ 3 files changed, 8 insertions(+), 27 deletions(-) diff --git a/wren-ai-service/src/config.py b/wren-ai-service/src/config.py index 20637032a4..c5acf4ae47 100644 --- a/wren-ai-service/src/config.py +++ b/wren-ai-service/src/config.py @@ -41,7 +41,7 @@ class Settings(BaseSettings): allow_sql_generation_reasoning: bool = Field(default=True) allow_sql_functions_retrieval: bool = Field(default=True) allow_sql_diagnosis: bool = Field(default=True) - allow_sql_knowledge_retrieval: bool = Field(default=True) + allow_sql_knowledge_retrieval: bool = Field(default=False) max_histories: int = Field(default=5) max_sql_correction_retries: int = Field(default=3) diff --git a/wren-ai-service/src/pipelines/generation/utils/sql.py b/wren-ai-service/src/pipelines/generation/utils/sql.py index dd427ab9c9..088282574e 100644 --- a/wren-ai-service/src/pipelines/generation/utils/sql.py +++ b/wren-ai-service/src/pipelines/generation/utils/sql.py @@ -451,12 +451,6 @@ async def _classify_generation_result( """ -TEXT_TO_SQL_RULES: str | None = None -CALCULATED_FIELD_INSTRUCTIONS: str | None = None -METRIC_INSTRUCTIONS: str | None = None -JSON_FIELD_INSTRUCTIONS: str | None = None - - def _extract_from_sql_knowledge( sql_knowledge: SqlKnowledge | None, attribute_name: str, default_value: str ) -> str: @@ -473,11 +467,7 @@ def get_text_to_sql_rules(sql_knowledge: SqlKnowledge | None = None) -> str: sql_knowledge, "text_to_sql_rule", _DEFAULT_TEXT_TO_SQL_RULES ) - return ( - TEXT_TO_SQL_RULES - if TEXT_TO_SQL_RULES is not None - else _DEFAULT_TEXT_TO_SQL_RULES - ) + return _DEFAULT_TEXT_TO_SQL_RULES def get_calculated_field_instructions(sql_knowledge: SqlKnowledge | None = None) -> str: @@ -488,11 +478,7 @@ def get_calculated_field_instructions(sql_knowledge: SqlKnowledge | None = None) _DEFAULT_CALCULATED_FIELD_INSTRUCTIONS, ) - return ( - CALCULATED_FIELD_INSTRUCTIONS - if CALCULATED_FIELD_INSTRUCTIONS is not None - else _DEFAULT_CALCULATED_FIELD_INSTRUCTIONS - ) + return _DEFAULT_CALCULATED_FIELD_INSTRUCTIONS def get_metric_instructions(sql_knowledge: SqlKnowledge | None = None) -> str: @@ -501,11 +487,7 @@ def get_metric_instructions(sql_knowledge: SqlKnowledge | None = None) -> str: sql_knowledge, "metric_instructions", _DEFAULT_METRIC_INSTRUCTIONS ) - return ( - METRIC_INSTRUCTIONS - if METRIC_INSTRUCTIONS is not None - else _DEFAULT_METRIC_INSTRUCTIONS - ) + return _DEFAULT_METRIC_INSTRUCTIONS def get_json_field_instructions(sql_knowledge: SqlKnowledge | None = None) -> str: @@ -514,11 +496,7 @@ def get_json_field_instructions(sql_knowledge: SqlKnowledge | None = None) -> st sql_knowledge, "json_field_instructions", _DEFAULT_JSON_FIELD_INSTRUCTIONS ) - return ( - JSON_FIELD_INSTRUCTIONS - if JSON_FIELD_INSTRUCTIONS is not None - else _DEFAULT_JSON_FIELD_INSTRUCTIONS - ) + return _DEFAULT_JSON_FIELD_INSTRUCTIONS def get_sql_generation_system_prompt(sql_knowledge: SqlKnowledge | None = None) -> str: diff --git a/wren-ai-service/tools/config/config.example.yaml b/wren-ai-service/tools/config/config.example.yaml index 74b0174e5a..dd3904529b 100644 --- a/wren-ai-service/tools/config/config.example.yaml +++ b/wren-ai-service/tools/config/config.example.yaml @@ -174,6 +174,9 @@ pipes: llm: litellm_llm.default - name: sql_diagnosis llm: litellm_llm.default + - name: sql_knowledge_retrieval + engine: wren_ibis + document_store: qdrant --- settings: From af3ec914bb0bd030c944fe2767a0777da5064501 Mon Sep 17 00:00:00 2001 From: ChihYu Yeh Date: Wed, 19 Nov 2025 09:52:53 +0800 Subject: [PATCH 16/18] fix --- wren-ai-service/src/force_update_config.py | 5 ++++- wren-ai-service/src/providers/llm/litellm.py | 2 ++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/wren-ai-service/src/force_update_config.py b/wren-ai-service/src/force_update_config.py index 0a3d7feba9..21c3456336 100644 --- a/wren-ai-service/src/force_update_config.py +++ b/wren-ai-service/src/force_update_config.py @@ -18,7 +18,10 @@ def update_config(): # Update engine name in all pipelines for pipe in doc.get("pipes", []): if "engine" in pipe: - if pipe["name"] == "sql_functions_retrieval": + if pipe["name"] in [ + "sql_functions_retrieval", + "sql_knowledge_retrieval", + ]: pipe["engine"] = "wren_ibis" else: pipe["engine"] = "wren_ui" diff --git a/wren-ai-service/src/providers/llm/litellm.py b/wren-ai-service/src/providers/llm/litellm.py index 3748a945da..6fe5722522 100644 --- a/wren-ai-service/src/providers/llm/litellm.py +++ b/wren-ai-service/src/providers/llm/litellm.py @@ -94,6 +94,8 @@ async def _run( convert_message_to_openai_format(message) for message in messages ] + print(f"prompt: {openai_formatted_messages}") + generation_kwargs = { **combined_generation_kwargs, **(generation_kwargs or {}), From 81c19675afb867982e84ae0d3b15992eb7c00cfc Mon Sep 17 00:00:00 2001 From: ChihYu Yeh Date: Wed, 19 Nov 2025 09:58:12 +0800 Subject: [PATCH 17/18] update --- docker/config.example.yaml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docker/config.example.yaml b/docker/config.example.yaml index 0b9c40204c..2a22e788e9 100644 --- a/docker/config.example.yaml +++ b/docker/config.example.yaml @@ -163,6 +163,9 @@ pipes: llm: litellm_llm.default - name: sql_diagnosis llm: litellm_llm.default + - name: sql_knowledge_retrieval + engine: wren_ibis + document_store: qdrant --- settings: From 4fa6c11a41c428287b56f6159bbfafa7e335d24d Mon Sep 17 00:00:00 2001 From: ChihYu Yeh Date: Thu, 11 Dec 2025 07:56:46 +0800 Subject: [PATCH 18/18] update configs --- wren-ai-service/docs/config_examples/config.anthropic.yaml | 4 +++- wren-ai-service/docs/config_examples/config.azure.yaml | 4 +++- wren-ai-service/docs/config_examples/config.bedrock.yaml | 4 +++- wren-ai-service/docs/config_examples/config.deepseek.yaml | 4 +++- .../docs/config_examples/config.google_ai_studio.yaml | 4 +++- .../docs/config_examples/config.google_vertexai.yaml | 4 +++- wren-ai-service/docs/config_examples/config.grok.yaml | 4 +++- wren-ai-service/docs/config_examples/config.groq.yaml | 4 +++- wren-ai-service/docs/config_examples/config.lm_studio.yaml | 4 +++- wren-ai-service/docs/config_examples/config.ollama.yaml | 4 +++- wren-ai-service/docs/config_examples/config.open_router.yaml | 4 +++- wren-ai-service/docs/config_examples/config.qwen3.yaml | 4 +++- wren-ai-service/docs/config_examples/config.zhipu.yaml | 4 +++- 13 files changed, 39 insertions(+), 13 deletions(-) diff --git a/wren-ai-service/docs/config_examples/config.anthropic.yaml b/wren-ai-service/docs/config_examples/config.anthropic.yaml index cc7ebe2d89..76e5d96526 100644 --- a/wren-ai-service/docs/config_examples/config.anthropic.yaml +++ b/wren-ai-service/docs/config_examples/config.anthropic.yaml @@ -142,7 +142,9 @@ pipes: llm: litellm_llm.default - name: sql_diagnosis llm: litellm_llm.default - + - name: sql_knowledge_retrieval + engine: wren_ibis + document_store: qdrant --- settings: engine_timeout: 30 diff --git a/wren-ai-service/docs/config_examples/config.azure.yaml b/wren-ai-service/docs/config_examples/config.azure.yaml index 604f7b157b..9319394727 100644 --- a/wren-ai-service/docs/config_examples/config.azure.yaml +++ b/wren-ai-service/docs/config_examples/config.azure.yaml @@ -155,7 +155,9 @@ pipes: llm: litellm_llm.default - name: sql_diagnosis llm: litellm_llm.default - + - name: sql_knowledge_retrieval + engine: wren_ibis + document_store: qdrant --- settings: engine_timeout: 30 diff --git a/wren-ai-service/docs/config_examples/config.bedrock.yaml b/wren-ai-service/docs/config_examples/config.bedrock.yaml index 7e3ac8f211..a94474de67 100644 --- a/wren-ai-service/docs/config_examples/config.bedrock.yaml +++ b/wren-ai-service/docs/config_examples/config.bedrock.yaml @@ -158,7 +158,9 @@ pipes: llm: litellm_llm.default - name: sql_diagnosis llm: litellm_llm.default - + - name: sql_knowledge_retrieval + engine: wren_ibis + document_store: qdrant --- settings: engine_timeout: 30 diff --git a/wren-ai-service/docs/config_examples/config.deepseek.yaml b/wren-ai-service/docs/config_examples/config.deepseek.yaml index 899254474c..3a00b30ad8 100644 --- a/wren-ai-service/docs/config_examples/config.deepseek.yaml +++ b/wren-ai-service/docs/config_examples/config.deepseek.yaml @@ -165,7 +165,9 @@ pipes: llm: litellm_llm.default - name: sql_diagnosis llm: litellm_llm.default - + - name: sql_knowledge_retrieval + engine: wren_ibis + document_store: qdrant --- settings: engine_timeout: 30 diff --git a/wren-ai-service/docs/config_examples/config.google_ai_studio.yaml b/wren-ai-service/docs/config_examples/config.google_ai_studio.yaml index 3cb6db70e0..9d087accb3 100644 --- a/wren-ai-service/docs/config_examples/config.google_ai_studio.yaml +++ b/wren-ai-service/docs/config_examples/config.google_ai_studio.yaml @@ -151,7 +151,9 @@ pipes: llm: litellm_llm.default - name: sql_diagnosis llm: litellm_llm.default - + - name: sql_knowledge_retrieval + engine: wren_ibis + document_store: qdrant --- settings: engine_timeout: 30 diff --git a/wren-ai-service/docs/config_examples/config.google_vertexai.yaml b/wren-ai-service/docs/config_examples/config.google_vertexai.yaml index 262231d19f..0b29acb7e5 100644 --- a/wren-ai-service/docs/config_examples/config.google_vertexai.yaml +++ b/wren-ai-service/docs/config_examples/config.google_vertexai.yaml @@ -159,7 +159,9 @@ pipes: llm: litellm_llm.default - name: sql_diagnosis llm: litellm_llm.default - + - name: sql_knowledge_retrieval + engine: wren_ibis + document_store: qdrant --- settings: engine_timeout: 30 diff --git a/wren-ai-service/docs/config_examples/config.grok.yaml b/wren-ai-service/docs/config_examples/config.grok.yaml index 1fe3ac6c91..b01de55b7e 100644 --- a/wren-ai-service/docs/config_examples/config.grok.yaml +++ b/wren-ai-service/docs/config_examples/config.grok.yaml @@ -147,7 +147,9 @@ pipes: llm: litellm_llm.default - name: sql_diagnosis llm: litellm_llm.default - + - name: sql_knowledge_retrieval + engine: wren_ibis + document_store: qdrant --- settings: engine_timeout: 30 diff --git a/wren-ai-service/docs/config_examples/config.groq.yaml b/wren-ai-service/docs/config_examples/config.groq.yaml index 886c6ad786..a07a577e9d 100644 --- a/wren-ai-service/docs/config_examples/config.groq.yaml +++ b/wren-ai-service/docs/config_examples/config.groq.yaml @@ -146,7 +146,9 @@ pipes: llm: litellm_llm.default - name: sql_diagnosis llm: litellm_llm.default - + - name: sql_knowledge_retrieval + engine: wren_ibis + document_store: qdrant --- settings: engine_timeout: 30 diff --git a/wren-ai-service/docs/config_examples/config.lm_studio.yaml b/wren-ai-service/docs/config_examples/config.lm_studio.yaml index 2c486caaa8..131ac04384 100644 --- a/wren-ai-service/docs/config_examples/config.lm_studio.yaml +++ b/wren-ai-service/docs/config_examples/config.lm_studio.yaml @@ -145,7 +145,9 @@ pipes: llm: litellm_llm.default - name: sql_diagnosis llm: litellm_llm.default - + - name: sql_knowledge_retrieval + engine: wren_ibis + document_store: qdrant --- settings: engine_timeout: 30 diff --git a/wren-ai-service/docs/config_examples/config.ollama.yaml b/wren-ai-service/docs/config_examples/config.ollama.yaml index 2a5f99064a..5f8e6c4ea2 100644 --- a/wren-ai-service/docs/config_examples/config.ollama.yaml +++ b/wren-ai-service/docs/config_examples/config.ollama.yaml @@ -145,7 +145,9 @@ pipes: llm: litellm_llm.default - name: sql_diagnosis llm: litellm_llm.default - + - name: sql_knowledge_retrieval + engine: wren_ibis + document_store: qdrant --- settings: engine_timeout: 30 diff --git a/wren-ai-service/docs/config_examples/config.open_router.yaml b/wren-ai-service/docs/config_examples/config.open_router.yaml index 347388dc8b..ecbcaa0731 100644 --- a/wren-ai-service/docs/config_examples/config.open_router.yaml +++ b/wren-ai-service/docs/config_examples/config.open_router.yaml @@ -143,7 +143,9 @@ pipes: llm: litellm_llm.default - name: sql_diagnosis llm: litellm_llm.default - + - name: sql_knowledge_retrieval + engine: wren_ibis + document_store: qdrant --- settings: engine_timeout: 30 diff --git a/wren-ai-service/docs/config_examples/config.qwen3.yaml b/wren-ai-service/docs/config_examples/config.qwen3.yaml index e321448c1a..0ebaed162a 100644 --- a/wren-ai-service/docs/config_examples/config.qwen3.yaml +++ b/wren-ai-service/docs/config_examples/config.qwen3.yaml @@ -185,7 +185,9 @@ pipes: llm: litellm_llm.default - name: sql_diagnosis llm: litellm_llm.default - + - name: sql_knowledge_retrieval + engine: wren_ibis + document_store: qdrant --- settings: engine_timeout: 30 diff --git a/wren-ai-service/docs/config_examples/config.zhipu.yaml b/wren-ai-service/docs/config_examples/config.zhipu.yaml index c7c72ead1a..8d0db87d6c 100644 --- a/wren-ai-service/docs/config_examples/config.zhipu.yaml +++ b/wren-ai-service/docs/config_examples/config.zhipu.yaml @@ -193,7 +193,9 @@ pipes: llm: litellm_llm.default - name: sql_diagnosis llm: litellm_llm.default - + - name: sql_knowledge_retrieval + engine: wren_ibis + document_store: qdrant --- settings: engine_timeout: 30