Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 40 additions & 16 deletions wren-ai-service/src/pipelines/retrieval/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import tiktoken
from hamilton import base
from hamilton.async_driver import AsyncDriver
from haystack import Document
from haystack.components.builders.prompt_builder import PromptBuilder
from langfuse.decorators import observe
from pydantic import BaseModel
Expand Down Expand Up @@ -166,12 +165,15 @@ async def table_retrieval(
@observe(capture_input=False)
async def dbschema_retrieval(
table_retrieval: dict, project_id: str, dbschema_retriever: Any
) -> list[Document]:
) -> dict:
tables = table_retrieval.get("documents", [])
table_names = []
# assign score to each table
table_scores = {}
for table in tables:
content = ast.literal_eval(table.content)
table_names.append(content["name"])
table_scores[content["name"]] = table.score

table_name_conditions = [
{"field": "name", "operator": "==", "value": table_name}
Expand All @@ -192,13 +194,17 @@ async def dbschema_retrieval(
)

results = await dbschema_retriever.run(query_embedding=[], filters=filters)
return results["documents"]

return {
"documents": results["documents"],
"table_scores": table_scores,
}


@observe()
def construct_db_schemas(dbschema_retrieval: list[Document]) -> list[dict]:
def construct_db_schemas(dbschema_retrieval: dict) -> dict:
db_schemas = {}
for document in dbschema_retrieval:
for document in dbschema_retrieval["documents"]:
content = ast.literal_eval(document.content)
if content["type"] == "TABLE":
if document.meta["name"] not in db_schemas:
Expand All @@ -220,39 +226,48 @@ def construct_db_schemas(dbschema_retrieval: list[Document]) -> list[dict]:
# remove incomplete schemas
db_schemas = {k: v for k, v in db_schemas.items() if "type" in v and "columns" in v}

return list(db_schemas.values())
return {
"db_schemas": list(db_schemas.values()),
"table_scores": dbschema_retrieval["table_scores"],
}


@observe(capture_input=False)
def check_using_db_schemas_without_pruning(
construct_db_schemas: list[dict],
dbschema_retrieval: list[Document],
construct_db_schemas: dict,
dbschema_retrieval: dict,
encoding: tiktoken.Encoding,
allow_using_db_schemas_without_pruning: bool,
) -> dict:
retrieval_results = []
has_calculated_field = False
has_metric = False

for table_schema in construct_db_schemas:
for table_schema in construct_db_schemas["db_schemas"]:
if table_schema["type"] == "TABLE":
ddl, _has_calculated_field = build_table_ddl(table_schema)
retrieval_results.append(
{
"table_name": table_schema["name"],
"table_ddl": ddl,
"table_score": construct_db_schemas["table_scores"][
table_schema["name"]
],
}
)
has_calculated_field = has_calculated_field or _has_calculated_field

for document in dbschema_retrieval:
for document in dbschema_retrieval["documents"]:
content = ast.literal_eval(document.content)

if content["type"] == "METRIC":
retrieval_results.append(
{
"table_name": content["name"],
"table_ddl": _build_metric_ddl(content),
"table_score": construct_db_schemas["table_scores"][
content["name"]
],
}
)
has_metric = True
Expand All @@ -261,6 +276,9 @@ def check_using_db_schemas_without_pruning(
{
"table_name": content["name"],
"table_ddl": _build_view_ddl(content),
"table_score": construct_db_schemas["table_scores"][
content["name"]
],
}
)

Expand All @@ -287,7 +305,7 @@ def check_using_db_schemas_without_pruning(
@observe(capture_input=False)
def prompt(
query: str,
construct_db_schemas: list[dict],
construct_db_schemas: dict,
prompt_builder: PromptBuilder,
check_using_db_schemas_without_pruning: dict,
histories: Optional[list[AskHistory]] = None,
Expand All @@ -298,7 +316,7 @@ def prompt(
)
db_schemas = [
build_table_ddl(construct_db_schema)[0]
for construct_db_schema in construct_db_schemas
for construct_db_schema in construct_db_schemas["db_schemas"]
]

previous_query_summaries = (
Expand All @@ -325,8 +343,8 @@ async def filter_columns_in_tables(
def construct_retrieval_results(
check_using_db_schemas_without_pruning: dict,
filter_columns_in_tables: dict,
construct_db_schemas: list[dict],
dbschema_retrieval: list[Document],
construct_db_schemas: dict,
dbschema_retrieval: dict,
) -> dict[str, Any]:
if filter_columns_in_tables:
columns_and_tables_needed = orjson.loads(
Expand All @@ -344,7 +362,7 @@ def construct_retrieval_results(
has_calculated_field = False
has_metric = False

for table_schema in construct_db_schemas:
for table_schema in construct_db_schemas["db_schemas"]:
if table_schema["type"] == "TABLE" and table_schema["name"] in tables:
ddl, _has_calculated_field = build_table_ddl(
table_schema,
Expand All @@ -361,7 +379,7 @@ def construct_retrieval_results(
}
)

for document in dbschema_retrieval:
for document in dbschema_retrieval["documents"]:
if document.meta["name"] in columns_and_tables_needed:
content = ast.literal_eval(document.content)

Expand All @@ -370,6 +388,9 @@ def construct_retrieval_results(
{
"table_name": content["name"],
"table_ddl": _build_metric_ddl(content),
"table_score": dbschema_retrieval["table_scores"][
content["name"]
],
Comment on lines +391 to +393
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Verification agent

🧩 Analysis chain

Verify potential KeyError issue

The code accesses table scores using content["name"] as the key. Ensure that all table names in these records exist in the table_scores dictionary to avoid potential KeyError exceptions.

Also applies to: 402-404


🏁 Script executed:

#!/bin/bash

# Check for other places where table_scores is accessed to see if there's error handling
rg "table_scores\[" -A 2 -B 2

Length of output: 473


Action Required: Add Key Existence Check for table_scores Access

It appears that both at lines 391–393 and 402–404 of wren-ai-service/src/pipelines/retrieval/retrieval.py, the code directly accesses table_scores using content["name"] without any check. Since we found that the key is inserted based solely on content["name"] in another part of the file, there's a risk of a KeyError if any table name is missing from this dictionary. Please consider adding an explicit check or handling for missing keys (e.g., using .get() with a fallback or wrapping the access in a try/except block) to ensure robustness.

}
)
has_metric = True
Expand All @@ -378,6 +399,9 @@ def construct_retrieval_results(
{
"table_name": content["name"],
"table_ddl": _build_view_ddl(content),
"table_score": dbschema_retrieval["table_scores"][
content["name"]
],
}
)

Expand Down
24 changes: 15 additions & 9 deletions wren-ai-service/src/web/v1/services/ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ class _AskResultResponse(BaseModel):
intent_reasoning: Optional[str] = None
sql_generation_reasoning: Optional[str] = None
type: Optional[Literal["MISLEADING_QUERY", "GENERAL", "TEXT_TO_SQL"]] = None
retrieved_tables: Optional[List[str]] = None
retrieved_tables: Optional[List[dict]] = None
response: Optional[List[AskResult]] = None
invalid_sql: Optional[str] = None
error: Optional[AskError] = None
Expand Down Expand Up @@ -219,7 +219,7 @@ async def ask(
sql_samples = []
instructions = []
api_results = []
table_names = []
retrieved_tables = []
error_message = None
invalid_sql = None

Expand Down Expand Up @@ -365,7 +365,13 @@ async def ask(
"construct_retrieval_results", {}
)
documents = _retrieval_result.get("retrieval_results", [])
table_names = [document.get("table_name") for document in documents]
retrieved_tables = [
{
"name": document.get("table_name"),
"score": document.get("table_score"),
}
for document in documents
]
table_ddls = [document.get("table_ddl") for document in documents]

if not documents:
Expand Down Expand Up @@ -397,7 +403,7 @@ async def ask(
type="TEXT_TO_SQL",
rephrased_question=rephrased_question,
intent_reasoning=intent_reasoning,
retrieved_tables=table_names,
retrieved_tables=retrieved_tables,
trace_id=trace_id,
is_followup=True if histories else False,
)
Expand Down Expand Up @@ -431,7 +437,7 @@ async def ask(
type="TEXT_TO_SQL",
rephrased_question=rephrased_question,
intent_reasoning=intent_reasoning,
retrieved_tables=table_names,
retrieved_tables=retrieved_tables,
sql_generation_reasoning=sql_generation_reasoning,
trace_id=trace_id,
is_followup=True if histories else False,
Expand All @@ -443,7 +449,7 @@ async def ask(
type="TEXT_TO_SQL",
rephrased_question=rephrased_question,
intent_reasoning=intent_reasoning,
retrieved_tables=table_names,
retrieved_tables=retrieved_tables,
sql_generation_reasoning=sql_generation_reasoning,
trace_id=trace_id,
is_followup=True if histories else False,
Expand Down Expand Up @@ -511,7 +517,7 @@ async def ask(
type="TEXT_TO_SQL",
rephrased_question=rephrased_question,
intent_reasoning=intent_reasoning,
retrieved_tables=table_names,
retrieved_tables=retrieved_tables,
sql_generation_reasoning=sql_generation_reasoning,
trace_id=trace_id,
is_followup=True if histories else False,
Expand Down Expand Up @@ -555,7 +561,7 @@ async def ask(
response=api_results,
rephrased_question=rephrased_question,
intent_reasoning=intent_reasoning,
retrieved_tables=table_names,
retrieved_tables=retrieved_tables,
sql_generation_reasoning=sql_generation_reasoning,
trace_id=trace_id,
is_followup=True if histories else False,
Expand All @@ -574,7 +580,7 @@ async def ask(
),
rephrased_question=rephrased_question,
intent_reasoning=intent_reasoning,
retrieved_tables=table_names,
retrieved_tables=retrieved_tables,
sql_generation_reasoning=sql_generation_reasoning,
invalid_sql=invalid_sql,
trace_id=trace_id,
Expand Down
Loading