Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions wren-ai-service/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions wren-ai-service/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ tiktoken = "^0.8.0"
jsonschema = "^4.23.0"
litellm = "^1.60.5"
boto3 = "^1.35.90"
qdrant-client = "==1.11.0"
Comment thread
cyyeh marked this conversation as resolved.

[tool.poetry.group.dev.dependencies]
pre-commit = "^3.7.1"
Expand Down
3 changes: 2 additions & 1 deletion wren-ai-service/src/pipelines/generation/sql_answer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,11 @@
4. Generate a concise and clear answer in string format to answer the user's question based on the data and sql.
5. If answer is in list format, only list top few examples, and tell users there are more results omitted.
6. Answer must be in the same language user specified.
7. Do not include ```markdown or ``` in the answer.

### OUTPUT FORMAT

Please provide your response in proper Markdown format.
Please provide your response in proper Markdown string format.
"""

sql_to_answer_user_prompt_template = """
Expand Down
4 changes: 3 additions & 1 deletion wren-ai-service/src/pipelines/indexing/table_description.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def _additional_meta() -> Dict[str, Any]:
"id": str(uuid.uuid4()),
"meta": {
"type": "TABLE_DESCRIPTION",
"name": chunk["name"],
**_additional_meta(),
},
"content": str(chunk),
Expand All @@ -53,6 +54,7 @@ def _structure_data(mdl_type: str, payload: Dict[str, Any]) -> Dict[str, Any]:
return {
"mdl_type": mdl_type,
"name": payload.get("name"),
"columns": [column["name"] for column in payload.get("columns", [])],
"properties": payload.get("properties", {}),
}

Expand All @@ -65,8 +67,8 @@ def _structure_data(mdl_type: str, payload: Dict[str, Any]) -> Dict[str, Any]:
return [
{
"name": resource["name"],
"mdl_type": resource["mdl_type"],
"description": resource["properties"].get("description", ""),
"columns": ", ".join(resource["columns"]),
}
for resource in resources
if resource["name"] is not None
Expand Down
47 changes: 40 additions & 7 deletions wren-ai-service/src/pipelines/retrieval/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,19 +226,37 @@ def check_using_db_schemas_without_pruning(
for table_schema in construct_db_schemas:
if table_schema["type"] == "TABLE":
ddl, _has_calculated_field = build_table_ddl(table_schema)
retrieval_results.append(ddl)
retrieval_results.append(
{
"table_name": table_schema["name"],
"table_ddl": ddl,
}
)
has_calculated_field = has_calculated_field or _has_calculated_field

for document in dbschema_retrieval:
content = ast.literal_eval(document.content)

if content["type"] == "METRIC":
retrieval_results.append(_build_metric_ddl(content))
retrieval_results.append(
{
"table_name": content["name"],
"table_ddl": _build_metric_ddl(content),
}
)
has_metric = True
elif content["type"] == "VIEW":
retrieval_results.append(_build_view_ddl(content))
retrieval_results.append(
{
"table_name": content["name"],
"table_ddl": _build_view_ddl(content),
}
)

_token_count = len(encoding.encode(" ".join(retrieval_results)))
table_ddls = [
retrieval_result["table_ddl"] for retrieval_result in retrieval_results
]
_token_count = len(encoding.encode(" ".join(table_ddls)))
if _token_count > 100_000 or not allow_using_db_schemas_without_pruning:
return {
"db_schemas": [],
Expand Down Expand Up @@ -328,17 +346,32 @@ def construct_retrieval_results(
tables=tables,
)
has_calculated_field = has_calculated_field or _has_calculated_field
retrieval_results.append(ddl)
retrieval_results.append(
{
"table_name": table_schema["name"],
"table_ddl": ddl,
}
)

for document in dbschema_retrieval:
if document.meta["name"] in columns_and_tables_needed:
content = ast.literal_eval(document.content)

if content["type"] == "METRIC":
retrieval_results.append(_build_metric_ddl(content))
retrieval_results.append(
{
"table_name": content["name"],
"table_ddl": _build_metric_ddl(content),
}
)
has_metric = True
elif content["type"] == "VIEW":
retrieval_results.append(_build_view_ddl(content))
retrieval_results.append(
{
"table_name": content["name"],
"table_ddl": _build_view_ddl(content),
}
)

return {
"retrieval_results": retrieval_results,
Expand Down
17 changes: 13 additions & 4 deletions wren-ai-service/src/web/v1/services/ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ class AskResultResponse(BaseModel):
intent_reasoning: Optional[str] = None
sql_generation_reasoning: Optional[str] = None
type: Optional[Literal["MISLEADING_QUERY", "GENERAL", "TEXT_TO_SQL"]] = None
retrieved_tables: Optional[List[str]] = None
response: Optional[List[AskResult]] = None
error: Optional[AskError] = None

Expand Down Expand Up @@ -310,6 +311,8 @@ async def ask(
"construct_retrieval_results", {}
)
documents = _retrieval_result.get("retrieval_results", [])
table_names = [document.get("table_name") for document in documents]
table_ddls = [document.get("table_ddl") for document in documents]

if not documents:
logger.exception(f"ask pipeline - NO_RELEVANT_DATA: {user_query}")
Expand Down Expand Up @@ -338,6 +341,7 @@ async def ask(
type="TEXT_TO_SQL",
rephrased_question=rephrased_question,
intent_reasoning=intent_reasoning,
retrieved_tables=table_names,
)

sql_samples = (
Expand All @@ -351,7 +355,7 @@ async def ask(
(
await self._pipelines["sql_generation_reasoning"].run(
query=user_query,
contexts=documents,
contexts=table_ddls,
sql_samples=sql_samples,
configuration=ask_request.configurations,
)
Expand All @@ -365,6 +369,7 @@ async def ask(
type="TEXT_TO_SQL",
rephrased_question=rephrased_question,
intent_reasoning=intent_reasoning,
retrieved_tables=table_names,
sql_generation_reasoning=sql_generation_reasoning,
)

Expand All @@ -374,6 +379,7 @@ async def ask(
type="TEXT_TO_SQL",
rephrased_question=rephrased_question,
intent_reasoning=intent_reasoning,
retrieved_tables=table_names,
sql_generation_reasoning=sql_generation_reasoning,
)

Expand All @@ -387,7 +393,7 @@ async def ask(
"followup_sql_generation"
].run(
query=user_query,
contexts=documents,
contexts=table_ddls,
sql_generation_reasoning=sql_generation_reasoning,
history=ask_request.history,
project_id=ask_request.project_id,
Expand All @@ -401,7 +407,7 @@ async def ask(
"sql_generation"
].run(
query=user_query,
contexts=documents,
contexts=table_ddls,
sql_generation_reasoning=sql_generation_reasoning,
project_id=ask_request.project_id,
configuration=ask_request.configurations,
Expand Down Expand Up @@ -431,12 +437,13 @@ async def ask(
type="TEXT_TO_SQL",
rephrased_question=rephrased_question,
intent_reasoning=intent_reasoning,
retrieved_tables=table_names,
sql_generation_reasoning=sql_generation_reasoning,
)
sql_correction_results = await self._pipelines[
"sql_correction"
].run(
contexts=documents,
contexts=table_ddls,
invalid_generation_results=failed_dry_run_results,
project_id=ask_request.project_id,
)
Expand Down Expand Up @@ -468,6 +475,7 @@ async def ask(
response=api_results,
rephrased_question=rephrased_question,
intent_reasoning=intent_reasoning,
retrieved_tables=table_names,
sql_generation_reasoning=sql_generation_reasoning,
)
results["ask_result"] = api_results
Expand All @@ -484,6 +492,7 @@ async def ask(
),
rephrased_question=rephrased_question,
intent_reasoning=intent_reasoning,
retrieved_tables=table_names,
sql_generation_reasoning=sql_generation_reasoning,
)
results["metadata"]["error_type"] = "NO_RELEVANT_SQL"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,15 @@ async def _validate_question(
)
_retrieval_result = retrieval_result.get("construct_retrieval_results", {})
documents = _retrieval_result.get("retrieval_results", [])
table_ddls = [document.get("table_ddl") for document in documents]
has_calculated_field = _retrieval_result.get("has_calculated_field", False)
has_metric = _retrieval_result.get("has_metric", False)

sql_generation_reasoning = (
(
await self._pipelines["sql_generation_reasoning"].run(
query=candidate["question"],
contexts=documents,
contexts=table_ddls,
configuration=configuration,
)
)
Expand All @@ -93,7 +94,7 @@ async def _validate_question(

generated_sql = await self._pipelines["sql_generation"].run(
query=candidate["question"],
contexts=documents,
contexts=table_ddls,
sql_generation_reasoning=sql_generation_reasoning,
configuration=configuration,
project_id=project_id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,12 @@ def test_single_table_description():
assert len(actual["documents"]) == 1

document: Document = actual["documents"][0]
assert document.meta == {"type": "TABLE_DESCRIPTION"}
assert document.meta == {"type": "TABLE_DESCRIPTION", "name": "user"}
assert document.content == str(
{
"name": "user",
"mdl_type": "MODEL",
"description": "A table containing user information.",
"columns": "",
}
)

Expand Down Expand Up @@ -71,22 +71,23 @@ def test_multiple_table_descriptions():
document_1: Document = actual["documents"][0]
assert document_1.meta == {
"type": "TABLE_DESCRIPTION",
"name": "user",
}
assert document_1.content == str(
{
"name": "user",
"mdl_type": "MODEL",
"description": "A table containing user information.",
"columns": "",
}
)

document_2: Document = actual["documents"][1]
assert document_2.meta == {"type": "TABLE_DESCRIPTION"}
assert document_2.meta == {"type": "TABLE_DESCRIPTION", "name": "order"}
assert document_2.content == str(
{
"name": "order",
"mdl_type": "MODEL",
"description": "A table containing order details.",
"columns": "",
}
)

Expand Down Expand Up @@ -121,10 +122,8 @@ def test_table_description_missing_description():
assert len(actual["documents"]) == 1

document: Document = actual["documents"][0]
assert document.meta == {"type": "TABLE_DESCRIPTION"}
assert document.content == str(
{"name": "user", "mdl_type": "MODEL", "description": ""}
)
assert document.meta == {"type": "TABLE_DESCRIPTION", "name": "user"}
assert document.content == str({"name": "user", "description": "", "columns": ""})


@pytest.mark.asyncio
Expand Down