Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions wren-ai-service/Justfile
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ demo:
test test_args='': up && down
poetry run pytest -s {{test_args}} --ignore tests/pytest/test_usecases.py

test-usecases usecases='all':
poetry run python -m tests.pytest.test_usecases --usecases {{usecases}}
test-usecases usecases='all' lang='en':
poetry run python -m tests.pytest.test_usecases --usecases {{usecases}} --lang {{lang}}

load-test:
poetry run python -m tests.locust.locust_script
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,16 @@ class OutputFormatter:
documents=List[Optional[Dict]],
)
def run(self, documents: List[Document]):
list = []

for doc in documents:
formatted = {
list = [
{
"question": doc.content,
"summary": doc.meta.get("summary", ""),
"statement": doc.meta.get("statement") or doc.meta.get("sql"),
"viewId": doc.meta.get("viewId", ""),
"sqlpairId": doc.meta.get("sql_pair_id", ""),
}
list.append(formatted)
for doc in documents
]

return {"documents": list}

Expand Down
10 changes: 8 additions & 2 deletions wren-ai-service/src/web/v1/services/ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,9 @@ class StopAskResponse(BaseModel):
# GET /v1/asks/{query_id}/result
class AskResult(BaseModel):
sql: str
type: Literal["llm", "view"] = "llm"
type: Literal["llm", "view", "sql_pair"] = "llm"
viewId: Optional[str] = None
sqlpairId: Optional[str] = None


class AskError(BaseModel):
Expand Down Expand Up @@ -243,8 +244,13 @@ async def ask(
AskResult(
**{
"sql": result.get("statement"),
"type": "view",
"type": "view"
if result.get("viewId")
else "sql_pair"
if result.get("sqlpairId")
else "llm",
"viewId": result.get("viewId"),
"sqlpairId": result.get("sqlpairId"),
}
)
for result in historical_question_result
Expand Down
33 changes: 27 additions & 6 deletions wren-ai-service/tests/pytest/test_usecases.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,19 @@ def deploy_mdl(mdl_str: str, url: str):
return semantics_preperation_id


async def ask_question(question: str, url: str, semantics_preperation_id: str):
async def ask_question(
question: str, url: str, semantics_preperation_id: str, lang: str = "English"
):
print(f"preparing to ask question: {question}")
async with aiohttp.ClientSession() as session:
start = time.time()
response = await session.post(
f"{url}/v1/asks", json={"query": question, "id": semantics_preperation_id}
f"{url}/v1/asks",
json={
"query": question,
"id": semantics_preperation_id,
"configurations": {"language": lang},
},
)
assert response.status == 200

Expand All @@ -133,11 +140,13 @@ async def ask_question(question: str, url: str, semantics_preperation_id: str):
return result


async def ask_questions(questions: list[str], url: str, semantics_preperation_id: str):
async def ask_questions(
questions: list[str], url: str, semantics_preperation_id: str, lang: str = "English"
):
tasks = []
for question in questions:
task = asyncio.ensure_future(
ask_question(question, url, semantics_preperation_id)
ask_question(question, url, semantics_preperation_id, lang)
)
tasks.append(task)
await asyncio.sleep(10)
Expand All @@ -160,7 +169,7 @@ def str_presenter(dumper, data):
"woocommerce": "bigquery",
"stripe": "bigquery",
"ecommerce": "duckdb",
"hr": "duckdb",
# "hr": "duckdb",
"facebook_marketing": "bigquery",
"google_ads": "bigquery",
}
Expand All @@ -174,11 +183,23 @@ def str_presenter(dumper, data):
default=["all"],
choices=["all"] + usecases,
)
parser.add_argument(
"--lang",
type=str,
choices=["en", "zh-TW", "zh-CN"],
default="en",
)
args = parser.parse_args()

if "all" not in args.usecases:
usecases = args.usecases

lang = {
"en": "English",
"zh-TW": "Traditional Chinese",
"zh-CN": "Simplified Chinese",
}[args.lang]

url = "http://localhost:5556"

assert is_ai_service_ready(
Expand All @@ -197,7 +218,7 @@ def str_presenter(dumper, data):

# ask questions
results = asyncio.run(
ask_questions(data["questions"], url, semantics_preperation_id)
ask_questions(data["questions"], url, semantics_preperation_id, lang)
)
assert len(results) == len(data["questions"])

Expand Down
Loading