Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions wren-ai-service/demo/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
st.session_state["query"] = None
if "asks_results" not in st.session_state:
st.session_state["asks_results"] = None
if "retrieved_tables" not in st.session_state:
st.session_state["retrieved_tables"] = None
if "asks_results_type" not in st.session_state:
st.session_state["asks_results_type"] = None
if "chosen_query_result" not in st.session_state:
Expand Down
28 changes: 20 additions & 8 deletions wren-ai-service/demo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,9 @@ def on_change_sql_generation_reasoning():
]


def on_click_regenerate_sql():
def on_click_regenerate_sql(changed_sql_generation_reasoning: str):
    """Re-run SQL generation, feeding back the user-edited reasoning text.

    Args:
        changed_sql_generation_reasoning: the reasoning plan as edited by the
            user in the Streamlit text area.
    """
    # Pair the edited reasoning with the SQL produced by the first ask result.
    current_sql = st.session_state["asks_results"]["response"][0]["sql"]
    ask_feedback(changed_sql_generation_reasoning, current_sql)

Expand Down Expand Up @@ -223,27 +223,35 @@ def show_asks_results():
st.markdown("### Question")
st.markdown(f"{st.session_state['query']}")

st.markdown("### Retrieved Tables")
st.markdown(st.session_state["retrieved_tables"])

st.markdown("### SQL Generation Reasoning")
st.text_area(
changed_sql_generation_reasoning = st.text_area(
"SQL Generation Reasoning",
st.session_state["sql_generation_reasoning"],
key="sql_generation_reasoning_input",
height=250,
on_change=on_change_sql_generation_reasoning,
)

st.button("Regenerate SQL", on_click=on_click_regenerate_sql)
st.button(
"Regenerate SQL",
on_click=on_click_regenerate_sql,
args=(changed_sql_generation_reasoning,),
)

st.markdown("### SQL Query Result")
if st.session_state["asks_results_type"] == "TEXT_TO_SQL":
edited_sql = st.text_area(
label="",
label="SQL Query Result",
value=sqlparse.format(
st.session_state["asks_results"]["response"][0]["sql"],
reindent=True,
keyword_case="upper",
),
height=250,
label_visibility="hidden",
)
st.button(
"Save Question-SQL pair",
Expand Down Expand Up @@ -523,6 +531,7 @@ def prepare_semantics(mdl_json: dict):
st.session_state["preview_sql"] = None
st.session_state["query_history"] = None
st.session_state["sql_generation_reasoning"] = None
st.session_state["retrieved_tables"] = None

if st.session_state["semantics_preparation_status"] == "failed":
st.toast("An error occurred while preparing the semantics", icon="🚨")
Expand Down Expand Up @@ -569,12 +578,15 @@ def ask(query: str, timezone: str, query_history: Optional[dict] = None):
if asks_status == "finished":
st.session_state["asks_results_type"] = asks_type
if asks_type == "GENERAL":
display_general_response(query_id)
display_streaming_response(query_id)
elif asks_type == "TEXT_TO_SQL":
st.session_state["asks_results"] = asks_status_response.json()
st.session_state["sql_generation_reasoning"] = st.session_state[
"asks_results"
]["sql_generation_reasoning"]
st.session_state["retrieved_tables"] = ", ".join(
st.session_state["asks_results"]["retrieved_tables"]
)
else:
st.session_state["asks_results"] = asks_type
elif asks_status == "failed":
Expand Down Expand Up @@ -613,7 +625,7 @@ def ask_feedback(sql_generation_reasoning: str, sql: str):
st.toast(f"The query processing status: {ask_feedback_status}")
time.sleep(POLLING_INTERVAL)

if ask_feedback_status_response == "finished":
if ask_feedback_status == "finished":
st.session_state["asks_results_type"] = "TEXT_TO_SQL"
st.session_state["asks_results"] = ask_feedback_status_response.json()
elif ask_feedback_status == "failed":
Expand Down Expand Up @@ -661,7 +673,7 @@ def save_sql_pair(question: str, sql: str):
)


def display_general_response(query_id: str):
def display_streaming_response(query_id: str):
url = f"{WREN_AI_SERVICE_BASE_URL}/v1/asks/{query_id}/streaming-result"
headers = {"Accept": "text/event-stream"}
response = with_requests(url, headers)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import asyncio
import logging
import sys
from typing import Any, List, Optional

import orjson
from hamilton import base
from hamilton.async_driver import AsyncDriver
from haystack.components.builders.prompt_builder import PromptBuilder
from langfuse.decorators import observe
from pydantic import BaseModel

from src.core.pipeline import BasicPipeline
from src.core.provider import LLMProvider
Expand All @@ -28,13 +27,10 @@
5. Don't include SQL in the reasoning plan.
6. Each step in the reasoning plan must start with a number, and a reasoning for the step.
7. If SQL SAMPLES are provided, make sure to consider them in the reasoning plan.
8. Do not include ```markdown or ``` in the answer.

### FINAL ANSWER FORMAT ###
The final answer must be a reasoning plan in JSON format:

{
"reasoning_plan": <REASONING_PLAN_STRING>
}
The final answer must be a reasoning plan in plain Markdown string format
"""

sql_generation_reasoning_user_prompt_template = """
Expand Down Expand Up @@ -82,36 +78,21 @@ def prompt(


@observe(as_type="generation", capture_input=False)
async def generate_sql_reasoning(prompt: dict, generator: Any, query_id: str) -> dict:
    """Invoke the LLM generator with the rendered prompt.

    Args:
        prompt: output of the prompt-builder step; the rendered text lives
            under the "prompt" key.
        generator: async LLM generation callable provided by the LLM provider.
        query_id: identifier forwarded to the generator so streaming chunks
            can be routed to the right per-query queue.

    Returns:
        The raw generator response dict (expected to contain "replies").
    """
    return await generator(prompt=prompt.get("prompt"), query_id=query_id)


@observe()
def post_process(
    generate_sql_reasoning: dict,
) -> dict:
    """Extract the reasoning-plan text from the generator response.

    Args:
        generate_sql_reasoning: raw generator output; replies are plain
            Markdown strings (the pipeline now uses a text response format).

    Returns:
        The first reply, or {} when no replies were produced — guarding
        against an IndexError on an empty/missing "replies" list.
    """
    replies = generate_sql_reasoning.get("replies")
    if not replies:
        return {}
    return replies[0]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Guard against empty or missing replies.
Accessing generate_sql_reasoning.get("replies")[0] can raise an IndexError if replies is empty. Add a safety check to avoid runtime failures.

-return generate_sql_reasoning.get("replies")[0]
+replies = generate_sql_reasoning.get("replies")
+if not replies:
+    return {}
+return replies[0]
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
return generate_sql_reasoning.get("replies")[0]
replies = generate_sql_reasoning.get("replies")
if not replies:
return {}
return replies[0]



## End of Pipeline


class SqlGenerationReasoningResult(BaseModel):
reasoning_plan: str


SQL_GENERATION_REASONING_MODEL_KWARGS = {
"response_format": {
"type": "json_schema",
"json_schema": {
"name": "sql_generation_reasoning_results",
"schema": SqlGenerationReasoningResult.model_json_schema(),
},
}
}
SQL_GENERATION_REASONING_MODEL_KWARGS = {"response_format": {"type": "text"}}


class SQLGenerationReasoning(BasicPipeline):
Expand All @@ -120,10 +101,12 @@ def __init__(
llm_provider: LLMProvider,
**kwargs,
):
self._user_queues = {}
self._components = {
"generator": llm_provider.get_generator(
system_prompt=sql_generation_reasoning_system_prompt,
generation_kwargs=SQL_GENERATION_REASONING_MODEL_KWARGS,
streaming_callback=self._streaming_callback,
),
"prompt_builder": PromptBuilder(
template=sql_generation_reasoning_user_prompt_template
Expand All @@ -134,13 +117,49 @@ def __init__(
AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult())
)

def _streaming_callback(self, chunk, query_id):
if query_id not in self._user_queues:
self._user_queues[
query_id
] = asyncio.Queue() # Create a new queue for the user if it doesn't exist
# Put the chunk content into the user's queue
asyncio.create_task(self._user_queues[query_id].put(chunk.content))
if chunk.meta.get("finish_reason"):
asyncio.create_task(self._user_queues[query_id].put("<DONE>"))

async def get_streaming_results(self, query_id):
    """Yield streamed reasoning chunks for *query_id* until done or timeout.

    Consumes the per-query asyncio.Queue fed by `_streaming_callback`,
    stopping (and cleaning up the queue) on the "<DONE>" sentinel, or
    giving up after 120s of silence.
    """
    if query_id not in self._user_queues:
        # Ensure the queue exists even if the consumer attaches before the
        # first chunk arrives.
        self._user_queues[query_id] = asyncio.Queue()
    while True:
        try:
            # Keep the chunk in a local variable: a shared instance
            # attribute would let concurrent queries clobber each other.
            chunk = await asyncio.wait_for(
                self._user_queues[query_id].get(), timeout=120
            )
        except asyncio.TimeoutError:
            # asyncio.TimeoutError (distinct from builtin TimeoutError on
            # Python < 3.11) is what wait_for raises.
            break
        if chunk == "<DONE>":  # end-of-stream signal from the callback
            del self._user_queues[query_id]
            break
        if chunk:
            yield chunk

Comment on lines +130 to +154
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Avoid storing streaming data in a shared instance property.
Using self._streaming_results can cause concurrency issues if multiple queries are handled simultaneously, as they would overwrite each other's result state. Use a local variable or wrap the logic in a dedicated queue read instead.

-            self._streaming_results = await asyncio.wait_for(
-                _get_streaming_results(query_id), timeout=120
-            )
-            if self._streaming_results == "<DONE>":
+            next_chunk = await asyncio.wait_for(
+                _get_streaming_results(query_id), timeout=120
+            )
+            if next_chunk == "<DONE>":
                 ...
-            if self._streaming_results:
-                yield self._streaming_results
-                self._streaming_results = ""
+            if next_chunk:
+                yield next_chunk
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
async def get_streaming_results(self, query_id):
async def _get_streaming_results(query_id):
return await self._user_queues[query_id].get()
if query_id not in self._user_queues:
self._user_queues[
query_id
] = asyncio.Queue() # Ensure the user's queue exists
while True:
try:
# Wait for an item from the user's queue
self._streaming_results = await asyncio.wait_for(
_get_streaming_results(query_id), timeout=120
)
if (
self._streaming_results == "<DONE>"
): # Check for end-of-stream signal
del self._user_queues[query_id]
break
if self._streaming_results: # Check if there are results to yield
yield self._streaming_results
self._streaming_results = "" # Clear after yielding
except TimeoutError:
break
async def get_streaming_results(self, query_id):
async def _get_streaming_results(query_id):
return await self._user_queues[query_id].get()
if query_id not in self._user_queues:
self._user_queues[
query_id
] = asyncio.Queue() # Ensure the user's queue exists
while True:
try:
# Wait for an item from the user's queue
next_chunk = await asyncio.wait_for(
_get_streaming_results(query_id), timeout=120
)
if next_chunk == "<DONE>": # Check for end-of-stream signal
del self._user_queues[query_id]
break
if next_chunk: # Check if there are results to yield
yield next_chunk
except TimeoutError:
break

@observe(name="SQL Generation Reasoning")
async def run(
self,
query: str,
contexts: List[str],
sql_samples: Optional[List[str]] = None,
configuration: Configuration = Configuration(),
query_id: Optional[str] = None,
):
logger.info("SQL Generation Reasoning pipeline is running...")
return await self._pipe.execute(
Expand All @@ -150,6 +169,7 @@ async def run(
"documents": contexts,
"sql_samples": sql_samples or [],
"configuration": configuration,
"query_id": query_id,
**self._components,
},
)
Expand Down
45 changes: 24 additions & 21 deletions wren-ai-service/src/web/v1/services/ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,17 +352,14 @@ async def ask(
)["formatted_output"].get("documents", [])

sql_generation_reasoning = (
(
await self._pipelines["sql_generation_reasoning"].run(
query=user_query,
contexts=table_ddls,
sql_samples=sql_samples,
configuration=ask_request.configurations,
)
await self._pipelines["sql_generation_reasoning"].run(
query=user_query,
contexts=table_ddls,
sql_samples=sql_samples,
configuration=ask_request.configurations,
query_id=query_id,
)
.get("post_process", {})
.get("reasoning_plan")
)
).get("post_process", {})

self._ask_results[query_id] = AskResultResponse(
status="planning",
Expand Down Expand Up @@ -548,17 +545,23 @@ async def get_ask_streaming_result(
self,
query_id: str,
):
if (
self._ask_results.get(query_id)
and self._ask_results.get(query_id).type == "GENERAL"
):
async for chunk in self._pipelines["data_assistance"].get_streaming_results(
query_id
):
event = SSEEvent(
data=SSEEvent.SSEEventMessage(message=chunk),
)
yield event.serialize()
if self._ask_results.get(query_id):
if self._ask_results.get(query_id).type == "GENERAL":
async for chunk in self._pipelines[
"data_assistance"
].get_streaming_results(query_id):
event = SSEEvent(
data=SSEEvent.SSEEventMessage(message=chunk),
)
yield event.serialize()
elif self._ask_results.get(query_id).status == "planning":
async for chunk in self._pipelines[
"sql_generation_reasoning"
].get_streaming_results(query_id):
event = SSEEvent(
data=SSEEvent.SSEEventMessage(message=chunk),
)
yield event.serialize()

@observe(name="Ask Feedback")
@trace_metadata
Expand Down
14 changes: 5 additions & 9 deletions wren-ai-service/src/web/v1/services/question_recommendation.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,16 +81,12 @@ async def _validate_question(
has_metric = _retrieval_result.get("has_metric", False)

sql_generation_reasoning = (
(
await self._pipelines["sql_generation_reasoning"].run(
query=candidate["question"],
contexts=table_ddls,
configuration=configuration,
)
await self._pipelines["sql_generation_reasoning"].run(
query=candidate["question"],
contexts=table_ddls,
configuration=configuration,
)
.get("post_process", {})
.get("reasoning_plan")
)
).get("post_process", {})

generated_sql = await self._pipelines["sql_generation"].run(
query=candidate["question"],
Expand Down
Loading