Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions wren-ai-service/demo/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
st.session_state["query"] = None
if "asks_results" not in st.session_state:
st.session_state["asks_results"] = None
if "retrieved_tables" not in st.session_state:
st.session_state["retrieved_tables"] = None
if "asks_results_type" not in st.session_state:
st.session_state["asks_results_type"] = None
if "chosen_query_result" not in st.session_state:
Expand Down
28 changes: 20 additions & 8 deletions wren-ai-service/demo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,9 @@ def on_change_sql_generation_reasoning():
]


def on_click_regenerate_sql(changed_sql_generation_reasoning: str):
    """Regenerate SQL using the user-edited reasoning text.

    Passed as the ``on_click`` callback of the "Regenerate SQL" button; the
    edited reasoning is forwarded explicitly via ``args`` instead of being
    read back from session state.

    Args:
        changed_sql_generation_reasoning: Current contents of the
            "SQL Generation Reasoning" text area.
    """
    ask_feedback(
        changed_sql_generation_reasoning,
        # Feed back the first (top-ranked) SQL candidate from the last ask run.
        st.session_state["asks_results"]["response"][0]["sql"],
    )

Expand Down Expand Up @@ -223,27 +223,35 @@ def show_asks_results():
st.markdown("### Question")
st.markdown(f"{st.session_state['query']}")

st.markdown("### Retrieved Tables")
st.markdown(st.session_state["retrieved_tables"])

st.markdown("### SQL Generation Reasoning")
st.text_area(
changed_sql_generation_reasoning = st.text_area(
"SQL Generation Reasoning",
st.session_state["sql_generation_reasoning"],
key="sql_generation_reasoning_input",
height=250,
on_change=on_change_sql_generation_reasoning,
)

st.button("Regenerate SQL", on_click=on_click_regenerate_sql)
st.button(
"Regenerate SQL",
on_click=on_click_regenerate_sql,
args=(changed_sql_generation_reasoning,),
)

st.markdown("### SQL Query Result")
if st.session_state["asks_results_type"] == "TEXT_TO_SQL":
edited_sql = st.text_area(
label="",
label="SQL Query Result",
value=sqlparse.format(
st.session_state["asks_results"]["response"][0]["sql"],
reindent=True,
keyword_case="upper",
),
height=250,
label_visibility="hidden",
)
st.button(
"Save Question-SQL pair",
Expand Down Expand Up @@ -523,6 +531,7 @@ def prepare_semantics(mdl_json: dict):
st.session_state["preview_sql"] = None
st.session_state["query_history"] = None
st.session_state["sql_generation_reasoning"] = None
st.session_state["retrieved_tables"] = None

if st.session_state["semantics_preparation_status"] == "failed":
st.toast("An error occurred while preparing the semantics", icon="🚨")
Expand Down Expand Up @@ -569,12 +578,15 @@ def ask(query: str, timezone: str, query_history: Optional[dict] = None):
if asks_status == "finished":
st.session_state["asks_results_type"] = asks_type
if asks_type == "GENERAL":
display_general_response(query_id)
display_streaming_response(query_id)
elif asks_type == "TEXT_TO_SQL":
st.session_state["asks_results"] = asks_status_response.json()
st.session_state["sql_generation_reasoning"] = st.session_state[
"asks_results"
]["sql_generation_reasoning"]
st.session_state["retrieved_tables"] = ", ".join(
st.session_state["asks_results"]["retrieved_tables"]
)
else:
st.session_state["asks_results"] = asks_type
elif asks_status == "failed":
Expand Down Expand Up @@ -613,7 +625,7 @@ def ask_feedback(sql_generation_reasoning: str, sql: str):
st.toast(f"The query processing status: {ask_feedback_status}")
time.sleep(POLLING_INTERVAL)

if ask_feedback_status_response == "finished":
if ask_feedback_status == "finished":
st.session_state["asks_results_type"] = "TEXT_TO_SQL"
st.session_state["asks_results"] = ask_feedback_status_response.json()
elif ask_feedback_status == "failed":
Expand Down Expand Up @@ -661,7 +673,7 @@ def save_sql_pair(question: str, sql: str):
)


def display_general_response(query_id: str):
def display_streaming_response(query_id: str):
url = f"{WREN_AI_SERVICE_BASE_URL}/v1/asks/{query_id}/streaming-result"
headers = {"Accept": "text/event-stream"}
response = with_requests(url, headers)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import asyncio
import logging
import sys
from typing import Any, List, Optional

import orjson
from hamilton import base
from hamilton.async_driver import AsyncDriver
from haystack.components.builders.prompt_builder import PromptBuilder
from langfuse.decorators import observe
from pydantic import BaseModel

from src.core.pipeline import BasicPipeline
from src.core.provider import LLMProvider
Expand All @@ -28,13 +27,10 @@
5. Don't include SQL in the reasoning plan.
6. Each step in the reasoning plan must start with a number, and a reasoning for the step.
7. If SQL SAMPLES are provided, make sure to consider them in the reasoning plan.
8. Do not include ```markdown or ``` in the answer.

### FINAL ANSWER FORMAT ###
The final answer must be a reasoning plan in JSON format:

{
"reasoning_plan": <REASONING_PLAN_STRING>
}
The final answer must be a reasoning plan in plain Markdown string format
"""

sql_generation_reasoning_user_prompt_template = """
Expand Down Expand Up @@ -82,36 +78,21 @@ def prompt(


@observe(as_type="generation", capture_input=False)
async def generate_sql_reasoning(prompt: dict, generator: Any, query_id: str) -> dict:
    """Run the LLM generator on the built prompt.

    Args:
        prompt: Output of the prompt builder; the rendered prompt string is
            under the "prompt" key.
        generator: Async LLM generator callable supplied by the pipeline
            components.
        query_id: Identifier used by the streaming callback to route chunks
            to the right per-query queue.

    Returns:
        The raw generator response dict (e.g. containing "replies").
    """
    return await generator(prompt=prompt.get("prompt"), query_id=query_id)


@observe()
def post_process(
    generate_sql_reasoning: dict,
) -> dict:
    """Extract the reasoning plan (plain Markdown string) from the generator output.

    The response format is now plain text, so no JSON decoding is performed.
    Guards against a missing or empty "replies" list to avoid an IndexError
    when the generator returned nothing.
    """
    replies = generate_sql_reasoning.get("replies")
    if not replies:
        return {}
    return replies[0]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Guard against empty or missing replies.
Accessing generate_sql_reasoning.get("replies")[0] can raise an IndexError if replies is empty. Add a safety check to avoid runtime failures.

-return generate_sql_reasoning.get("replies")[0]
+replies = generate_sql_reasoning.get("replies")
+if not replies:
+    return {}
+return replies[0]
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
return generate_sql_reasoning.get("replies")[0]
replies = generate_sql_reasoning.get("replies")
if not replies:
return {}
return replies[0]



## End of Pipeline


class SqlGenerationReasoningResult(BaseModel):
reasoning_plan: str


SQL_GENERATION_REASONING_MODEL_KWARGS = {
"response_format": {
"type": "json_schema",
"json_schema": {
"name": "sql_generation_reasoning_results",
"schema": SqlGenerationReasoningResult.model_json_schema(),
},
}
}
SQL_GENERATION_REASONING_MODEL_KWARGS = {"response_format": {"type": "text"}}


class SQLGenerationReasoning(BasicPipeline):
Expand All @@ -120,10 +101,12 @@ def __init__(
llm_provider: LLMProvider,
**kwargs,
):
self._user_queues = {}
self._components = {
"generator": llm_provider.get_generator(
system_prompt=sql_generation_reasoning_system_prompt,
generation_kwargs=SQL_GENERATION_REASONING_MODEL_KWARGS,
streaming_callback=self._streaming_callback,
),
"prompt_builder": PromptBuilder(
template=sql_generation_reasoning_user_prompt_template
Expand All @@ -134,13 +117,49 @@ def __init__(
AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult())
)

def _streaming_callback(self, chunk, query_id):
    """Route one streamed LLM chunk into the per-query asyncio queue.

    Lazily creates the queue for ``query_id`` on first use, enqueues the
    chunk's text content, and pushes a "<DONE>" sentinel once the model
    reports a finish reason.
    """
    queue = self._user_queues.get(query_id)
    if queue is None:
        # First chunk for this query: create its dedicated queue.
        queue = self._user_queues[query_id] = asyncio.Queue()
    asyncio.create_task(queue.put(chunk.content))
    if chunk.meta.get("finish_reason"):
        # Signal end-of-stream to the consumer side.
        asyncio.create_task(queue.put("<DONE>"))

async def get_streaming_results(self, query_id):
    """Async-yield streamed reasoning chunks for ``query_id``.

    Consumes the per-query queue fed by ``_streaming_callback`` until the
    "<DONE>" sentinel arrives or no chunk shows up within 120 seconds.

    Uses a local variable for the current chunk (not a shared instance
    attribute) so concurrently streamed queries cannot clobber each
    other's state.
    """
    if query_id not in self._user_queues:
        # Ensure the queue exists even if the consumer attaches before
        # the first chunk is produced.
        self._user_queues[query_id] = asyncio.Queue()

    queue = self._user_queues[query_id]
    while True:
        try:
            # Wait for the next chunk; bail out after 120s of silence.
            next_chunk = await asyncio.wait_for(queue.get(), timeout=120)
        except asyncio.TimeoutError:
            # asyncio.TimeoutError is the exception wait_for raises
            # (an alias of the builtin TimeoutError on Python >= 3.11).
            break
        if next_chunk == "<DONE>":  # End-of-stream signal
            del self._user_queues[query_id]
            break
        if next_chunk:
            yield next_chunk

Comment on lines +130 to +154
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Avoid storing streaming data in a shared instance property.
Using self._streaming_results can cause concurrency issues if multiple queries are handled simultaneously, as they would overwrite each other's result state. Use a local variable or wrap the logic in a dedicated queue read instead.

-            self._streaming_results = await asyncio.wait_for(
-                _get_streaming_results(query_id), timeout=120
-            )
-            if self._streaming_results == "<DONE>":
+            next_chunk = await asyncio.wait_for(
+                _get_streaming_results(query_id), timeout=120
+            )
+            if next_chunk == "<DONE>":
                 ...
-            if self._streaming_results:
-                yield self._streaming_results
-                self._streaming_results = ""
+            if next_chunk:
+                yield next_chunk
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
async def get_streaming_results(self, query_id):
async def _get_streaming_results(query_id):
return await self._user_queues[query_id].get()
if query_id not in self._user_queues:
self._user_queues[
query_id
] = asyncio.Queue() # Ensure the user's queue exists
while True:
try:
# Wait for an item from the user's queue
self._streaming_results = await asyncio.wait_for(
_get_streaming_results(query_id), timeout=120
)
if (
self._streaming_results == "<DONE>"
): # Check for end-of-stream signal
del self._user_queues[query_id]
break
if self._streaming_results: # Check if there are results to yield
yield self._streaming_results
self._streaming_results = "" # Clear after yielding
except TimeoutError:
break
async def get_streaming_results(self, query_id):
async def _get_streaming_results(query_id):
return await self._user_queues[query_id].get()
if query_id not in self._user_queues:
self._user_queues[
query_id
] = asyncio.Queue() # Ensure the user's queue exists
while True:
try:
# Wait for an item from the user's queue
next_chunk = await asyncio.wait_for(
_get_streaming_results(query_id), timeout=120
)
if next_chunk == "<DONE>": # Check for end-of-stream signal
del self._user_queues[query_id]
break
if next_chunk: # Check if there are results to yield
yield next_chunk
except TimeoutError:
break

@observe(name="SQL Generation Reasoning")
async def run(
self,
query: str,
contexts: List[str],
sql_samples: Optional[List[str]] = None,
configuration: Configuration = Configuration(),
query_id: Optional[str] = None,
):
logger.info("SQL Generation Reasoning pipeline is running...")
return await self._pipe.execute(
Expand All @@ -150,6 +169,7 @@ async def run(
"documents": contexts,
"sql_samples": sql_samples or [],
"configuration": configuration,
"query_id": query_id,
**self._components,
},
)
Expand Down
2 changes: 1 addition & 1 deletion wren-ai-service/src/web/v1/routers/ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ async def ask(
service_container: ServiceContainer = Depends(get_service_container),
service_metadata: ServiceMetadata = Depends(get_service_metadata),
) -> AskResponse:
query_id = str(uuid.uuid4())
query_id = "c1b011f4-1360-43e0-9c20-22c0c6025206" # str(uuid.uuid4())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Using a static UUID can cause collisions.
By hardcoding query_id, all new queries share the same ID, which breaks traceability and concurrency. Return to a dynamic UUID or otherwise ensure uniqueness if you need per-request separation.

-    query_id = "c1b011f4-1360-43e0-9c20-22c0c6025206"  # str(uuid.uuid4())
+    query_id = str(uuid.uuid4())
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
query_id = "c1b011f4-1360-43e0-9c20-22c0c6025206" # str(uuid.uuid4())
query_id = str(uuid.uuid4())

ask_request.query_id = query_id
service_container.ask_service._ask_results[query_id] = AskResultResponse(
status="understanding",
Expand Down
45 changes: 24 additions & 21 deletions wren-ai-service/src/web/v1/services/ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,17 +352,14 @@ async def ask(
)["formatted_output"].get("documents", [])

sql_generation_reasoning = (
(
await self._pipelines["sql_generation_reasoning"].run(
query=user_query,
contexts=table_ddls,
sql_samples=sql_samples,
configuration=ask_request.configurations,
)
await self._pipelines["sql_generation_reasoning"].run(
query=user_query,
contexts=table_ddls,
sql_samples=sql_samples,
configuration=ask_request.configurations,
query_id=query_id,
)
.get("post_process", {})
.get("reasoning_plan")
)
).get("post_process", {})

self._ask_results[query_id] = AskResultResponse(
status="planning",
Expand Down Expand Up @@ -548,17 +545,23 @@ async def get_ask_streaming_result(
self,
query_id: str,
):
if (
self._ask_results.get(query_id)
and self._ask_results.get(query_id).type == "GENERAL"
):
async for chunk in self._pipelines["data_assistance"].get_streaming_results(
query_id
):
event = SSEEvent(
data=SSEEvent.SSEEventMessage(message=chunk),
)
yield event.serialize()
if self._ask_results.get(query_id):
if self._ask_results.get(query_id).type == "GENERAL":
async for chunk in self._pipelines[
"data_assistance"
].get_streaming_results(query_id):
event = SSEEvent(
data=SSEEvent.SSEEventMessage(message=chunk),
)
yield event.serialize()
elif self._ask_results.get(query_id).status == "planning":
async for chunk in self._pipelines[
"sql_generation_reasoning"
].get_streaming_results(query_id):
event = SSEEvent(
data=SSEEvent.SSEEventMessage(message=chunk),
)
yield event.serialize()

@observe(name="Ask Feedback")
@trace_metadata
Expand Down
14 changes: 5 additions & 9 deletions wren-ai-service/src/web/v1/services/question_recommendation.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,16 +81,12 @@ async def _validate_question(
has_metric = _retrieval_result.get("has_metric", False)

sql_generation_reasoning = (
(
await self._pipelines["sql_generation_reasoning"].run(
query=candidate["question"],
contexts=table_ddls,
configuration=configuration,
)
await self._pipelines["sql_generation_reasoning"].run(
query=candidate["question"],
contexts=table_ddls,
configuration=configuration,
)
.get("post_process", {})
.get("reasoning_plan")
)
).get("post_process", {})

generated_sql = await self._pipelines["sql_generation"].run(
query=candidate["question"],
Expand Down
Loading