From 8096a49100b9b31179357142cb8ffa852b53fd61 Mon Sep 17 00:00:00 2001 From: ChihYu Yeh Date: Thu, 13 Feb 2025 13:58:54 +0800 Subject: [PATCH 1/4] add retrieval tables and fix regenerating sql --- wren-ai-service/demo/app.py | 2 ++ wren-ai-service/demo/utils.py | 24 ++++++++++++++++++------ 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/wren-ai-service/demo/app.py b/wren-ai-service/demo/app.py index a5c31c9cf4..1e1978b290 100644 --- a/wren-ai-service/demo/app.py +++ b/wren-ai-service/demo/app.py @@ -40,6 +40,8 @@ st.session_state["query"] = None if "asks_results" not in st.session_state: st.session_state["asks_results"] = None +if "retrieved_tables" not in st.session_state: + st.session_state["retrieved_tables"] = None if "asks_results_type" not in st.session_state: st.session_state["asks_results_type"] = None if "chosen_query_result" not in st.session_state: diff --git a/wren-ai-service/demo/utils.py b/wren-ai-service/demo/utils.py index 3ee0827657..82e1f60292 100644 --- a/wren-ai-service/demo/utils.py +++ b/wren-ai-service/demo/utils.py @@ -181,9 +181,9 @@ def on_change_sql_generation_reasoning(): ] -def on_click_regenerate_sql(): +def on_click_regenerate_sql(changed_sql_generation_reasoning: str): ask_feedback( - st.session_state["sql_generation_reasoning"], + changed_sql_generation_reasoning, st.session_state["asks_results"]["response"][0]["sql"], ) @@ -223,8 +223,11 @@ def show_asks_results(): st.markdown("### Question") st.markdown(f"{st.session_state['query']}") + st.markdown("### Retrieved Tables") + st.markdown(st.session_state["retrieved_tables"]) + st.markdown("### SQL Generation Reasoning") - st.text_area( + changed_sql_generation_reasoning = st.text_area( "SQL Generation Reasoning", st.session_state["sql_generation_reasoning"], key="sql_generation_reasoning_input", @@ -232,18 +235,23 @@ def show_asks_results(): on_change=on_change_sql_generation_reasoning, ) - st.button("Regenerate SQL", on_click=on_click_regenerate_sql) + st.button( + "Regenerate SQL", + on_click=on_click_regenerate_sql, + args=(changed_sql_generation_reasoning,), + ) st.markdown("### SQL Query Result") if st.session_state["asks_results_type"] == "TEXT_TO_SQL": edited_sql = st.text_area( - label="", + label="SQL Query Result", value=sqlparse.format( st.session_state["asks_results"]["response"][0]["sql"], reindent=True, keyword_case="upper", ), height=250, + label_visibility="hidden", ) st.button( "Save Question-SQL pair", @@ -523,6 +531,7 @@ def prepare_semantics(mdl_json: dict): st.session_state["preview_sql"] = None st.session_state["query_history"] = None st.session_state["sql_generation_reasoning"] = None + st.session_state["retrieved_tables"] = None if st.session_state["semantics_preparation_status"] == "failed": st.toast("An error occurred while preparing the semantics", icon="🚨") @@ -575,6 +584,9 @@ def ask(query: str, timezone: str, query_history: Optional[dict] = None): st.session_state["sql_generation_reasoning"] = st.session_state[ "asks_results" ]["sql_generation_reasoning"] + st.session_state["retrieved_tables"] = ", ".join( + st.session_state["asks_results"]["retrieved_tables"] + ) else: st.session_state["asks_results"] = asks_type elif asks_status == "failed": @@ -613,7 +625,7 @@ def ask_feedback(sql_generation_reasoning: str, sql: str): st.toast(f"The query processing status: {ask_feedback_status}") time.sleep(POLLING_INTERVAL) - if ask_feedback_status_response == "finished": + if ask_feedback_status == "finished": st.session_state["asks_results_type"] = "TEXT_TO_SQL" st.session_state["asks_results"] = ask_feedback_status_response.json() elif ask_feedback_status == "failed": From b42b5c779d550ef78a018b331d8be49bfc32f6d8 Mon Sep 17 00:00:00 2001 From: ChihYu Yeh Date: Thu, 13 Feb 2025 15:56:57 +0800 Subject: [PATCH 2/4] allow streaming for sql generation reasoning --- wren-ai-service/demo/utils.py | 4 +- .../generation/sql_generation_reasoning.py | 72 ++++++++++++------- wren-ai-service/src/web/v1/routers/ask.py | 2 +- wren-ai-service/src/web/v1/services/ask.py | 45 ++++++------ 4 files changed, 73 insertions(+), 50 deletions(-) diff --git a/wren-ai-service/demo/utils.py b/wren-ai-service/demo/utils.py index 82e1f60292..cc44b3d0f0 100644 --- a/wren-ai-service/demo/utils.py +++ b/wren-ai-service/demo/utils.py @@ -578,7 +578,7 @@ def ask(query: str, timezone: str, query_history: Optional[dict] = None): if asks_status == "finished": st.session_state["asks_results_type"] = asks_type if asks_type == "GENERAL": - display_general_response(query_id) + display_streaming_response(query_id) elif asks_type == "TEXT_TO_SQL": st.session_state["asks_results"] = asks_status_response.json() st.session_state["sql_generation_reasoning"] = st.session_state[ @@ -673,7 +673,7 @@ def save_sql_pair(question: str, sql: str): ) -def display_general_response(query_id: str): +def display_streaming_response(query_id: str): url = f"{WREN_AI_SERVICE_BASE_URL}/v1/asks/{query_id}/streaming-result" headers = {"Accept": "text/event-stream"} response = with_requests(url, headers) diff --git a/wren-ai-service/src/pipelines/generation/sql_generation_reasoning.py b/wren-ai-service/src/pipelines/generation/sql_generation_reasoning.py index 8829a9cf03..f11a36053e 100644 --- a/wren-ai-service/src/pipelines/generation/sql_generation_reasoning.py +++ b/wren-ai-service/src/pipelines/generation/sql_generation_reasoning.py @@ -1,13 +1,12 @@ +import asyncio import logging import sys from typing import Any, List, Optional -import orjson from hamilton import base from hamilton.async_driver import AsyncDriver from haystack.components.builders.prompt_builder import PromptBuilder from langfuse.decorators import observe -from pydantic import BaseModel from src.core.pipeline import BasicPipeline from src.core.provider import LLMProvider @@ -28,13 +27,10 @@ 5. Don't include SQL in the reasoning plan. 6. Each step in the reasoning plan must start with a number, and a reasoning for the step. 7. If SQL SAMPLES are provided, make sure to consider them in the reasoning plan. +8. Do not include ```markdown or ``` in the answer. ### FINAL ANSWER FORMAT ### -The final answer must be a reasoning plan in JSON format: - -{ - "reasoning_plan": -} +The final answer must be a reasoning plan in plain Markdown string format """ sql_generation_reasoning_user_prompt_template = """ @@ -82,36 +78,21 @@ def prompt( @observe(as_type="generation", capture_input=False) -async def generate_sql_reasoning( - prompt: dict, - generator: Any, -) -> dict: - return await generator(prompt=prompt.get("prompt")) +async def generate_sql_reasoning(prompt: dict, generator: Any, query_id: str) -> dict: + return await generator(prompt=prompt.get("prompt"), query_id=query_id) @observe() def post_process( generate_sql_reasoning: dict, ) -> dict: - return orjson.loads(generate_sql_reasoning.get("replies")[0]) + return generate_sql_reasoning.get("replies")[0] ## End of Pipeline -class SqlGenerationReasoningResult(BaseModel): - reasoning_plan: str - - -SQL_GENERATION_REASONING_MODEL_KWARGS = { - "response_format": { - "type": "json_schema", - "json_schema": { - "name": "sql_generation_reasoning_results", - "schema": SqlGenerationReasoningResult.model_json_schema(), - }, - } -} +SQL_GENERATION_REASONING_MODEL_KWARGS = {"response_format": {"type": "text"}} class SQLGenerationReasoning(BasicPipeline): @@ -120,10 +101,12 @@ def __init__( llm_provider: LLMProvider, **kwargs, ): + self._user_queues = {} self._components = { "generator": llm_provider.get_generator( system_prompt=sql_generation_reasoning_system_prompt, generation_kwargs=SQL_GENERATION_REASONING_MODEL_KWARGS, + streaming_callback=self._streaming_callback, ), "prompt_builder": PromptBuilder( template=sql_generation_reasoning_user_prompt_template @@ -134,6 +117,41 @@ def __init__( AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) ) + def _streaming_callback(self, chunk, query_id): + if query_id not in self._user_queues: + self._user_queues[ + query_id + ] = asyncio.Queue() # Create a new queue for the user if it doesn't exist + # Put the chunk content into the user's queue + asyncio.create_task(self._user_queues[query_id].put(chunk.content)) + if chunk.meta.get("finish_reason"): + asyncio.create_task(self._user_queues[query_id].put("")) + + async def get_streaming_results(self, query_id): + async def _get_streaming_results(query_id): + return await self._user_queues[query_id].get() + + if query_id not in self._user_queues: + self._user_queues[ + query_id + ] = asyncio.Queue() # Ensure the user's queue exists + while True: + try: + # Wait for an item from the user's queue + self._streaming_results = await asyncio.wait_for( + _get_streaming_results(query_id), timeout=120 + ) + if ( + self._streaming_results == "" + ): # Check for end-of-stream signal + del self._user_queues[query_id] + break + if self._streaming_results: # Check if there are results to yield + yield self._streaming_results + self._streaming_results = "" # Clear after yielding + except TimeoutError: + break + @observe(name="SQL Generation Reasoning") async def run( self, @@ -141,6 +159,7 @@ async def run( contexts: List[str], sql_samples: Optional[List[str]] = None, configuration: Configuration = Configuration(), + query_id: Optional[str] = None, ): logger.info("SQL Generation Reasoning pipeline is running...") return await self._pipe.execute( @@ -150,6 +169,7 @@ async def run( "documents": contexts, "sql_samples": sql_samples or [], "configuration": configuration, + "query_id": query_id, **self._components, }, ) diff --git a/wren-ai-service/src/web/v1/routers/ask.py b/wren-ai-service/src/web/v1/routers/ask.py index 7998787e68..2b3181d0f0 100644 --- a/wren-ai-service/src/web/v1/routers/ask.py +++ b/wren-ai-service/src/web/v1/routers/ask.py @@ -93,7 +93,7 @@ async def ask( service_container: ServiceContainer = Depends(get_service_container), service_metadata: ServiceMetadata = Depends(get_service_metadata), ) -> AskResponse: - query_id = str(uuid.uuid4()) + query_id = "c1b011f4-1360-43e0-9c20-22c0c6025206" # str(uuid.uuid4()) ask_request.query_id = query_id service_container.ask_service._ask_results[query_id] = AskResultResponse( status="understanding", diff --git a/wren-ai-service/src/web/v1/services/ask.py b/wren-ai-service/src/web/v1/services/ask.py index b3f427ddeb..734976e115 100644 --- a/wren-ai-service/src/web/v1/services/ask.py +++ b/wren-ai-service/src/web/v1/services/ask.py @@ -352,17 +352,14 @@ async def ask( )["formatted_output"].get("documents", []) sql_generation_reasoning = ( - ( - await self._pipelines["sql_generation_reasoning"].run( - query=user_query, - contexts=table_ddls, - sql_samples=sql_samples, - configuration=ask_request.configurations, - ) + await self._pipelines["sql_generation_reasoning"].run( + query=user_query, + contexts=table_ddls, + sql_samples=sql_samples, + configuration=ask_request.configurations, + query_id=query_id, ) - .get("post_process", {}) - .get("reasoning_plan") - ) + ).get("post_process", {}) self._ask_results[query_id] = AskResultResponse( status="planning", @@ -548,17 +545,23 @@ async def get_ask_streaming_result( self, query_id: str, ): - if ( - self._ask_results.get(query_id) - and self._ask_results.get(query_id).type == "GENERAL" - ): - async for chunk in self._pipelines["data_assistance"].get_streaming_results( - query_id - ): - event = SSEEvent( - data=SSEEvent.SSEEventMessage(message=chunk), - ) - yield event.serialize() + if self._ask_results.get(query_id): + if self._ask_results.get(query_id).type == "GENERAL": + async for chunk in self._pipelines[ + "data_assistance" + ].get_streaming_results(query_id): + event = SSEEvent( + data=SSEEvent.SSEEventMessage(message=chunk), + ) + yield event.serialize() + elif self._ask_results.get(query_id).status == "planning": + async for chunk in self._pipelines[ + "sql_generation_reasoning" + ].get_streaming_results(query_id): + event = SSEEvent( + data=SSEEvent.SSEEventMessage(message=chunk), + ) + yield event.serialize() @observe(name="Ask Feedback") @trace_metadata From 70ee22bc584356f53b31afc9e42be172215db20f Mon Sep 17 00:00:00 2001 From: ChihYu Yeh Date: Thu, 13 Feb 2025 16:12:46 +0800 Subject: [PATCH 3/4] fix bug --- .../src/web/v1/services/question_recommendation.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/wren-ai-service/src/web/v1/services/question_recommendation.py b/wren-ai-service/src/web/v1/services/question_recommendation.py index e4524fd501..11dee11040 100644 --- a/wren-ai-service/src/web/v1/services/question_recommendation.py +++ b/wren-ai-service/src/web/v1/services/question_recommendation.py @@ -81,16 +81,12 @@ async def _validate_question( has_metric = _retrieval_result.get("has_metric", False) sql_generation_reasoning = ( - ( - await self._pipelines["sql_generation_reasoning"].run( - query=candidate["question"], - contexts=table_ddls, - configuration=configuration, - ) + await self._pipelines["sql_generation_reasoning"].run( + query=candidate["question"], + contexts=table_ddls, + configuration=configuration, ) - .get("post_process", {}) - .get("reasoning_plan") - ) + ).get("post_process", {}) generated_sql = await self._pipelines["sql_generation"].run( query=candidate["question"], From 73b4c0b9d008fc139a32b08fc4090450cc51c452 Mon Sep 17 00:00:00 2001 From: ChihYu Yeh Date: Thu, 13 Feb 2025 16:21:03 +0800 Subject: [PATCH 4/4] fix --- wren-ai-service/src/web/v1/routers/ask.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wren-ai-service/src/web/v1/routers/ask.py b/wren-ai-service/src/web/v1/routers/ask.py index 2b3181d0f0..7998787e68 100644 --- a/wren-ai-service/src/web/v1/routers/ask.py +++ b/wren-ai-service/src/web/v1/routers/ask.py @@ -93,7 +93,7 @@ async def ask( service_container: ServiceContainer = Depends(get_service_container), service_metadata: ServiceMetadata = Depends(get_service_metadata), ) -> AskResponse: - query_id = "c1b011f4-1360-43e0-9c20-22c0c6025206" # str(uuid.uuid4()) + query_id = str(uuid.uuid4()) ask_request.query_id = query_id service_container.ask_service._ask_results[query_id] = AskResultResponse( status="understanding",