Commit 61de6c8

committed: tried to fix eval
1 parent eb9575e commit 61de6c8

3 files changed

Lines changed: 175 additions & 57 deletions

File tree

wren-ai-service/eval/pipelines.py
wren-ai-service/eval/prediction.py
wren-ai-service/src/globals.py

wren-ai-service/eval/pipelines.py

Lines changed: 144 additions & 38 deletions
@@ -2,14 +2,35 @@
 import os
 import re
 import sys
+import uuid
 from abc import abstractmethod
 from pathlib import Path
 from typing import Any, Dict, List, Literal
 
 import orjson
+import json
 from haystack import Document
 from langfuse.decorators import langfuse_context, observe
 from tqdm.asyncio import tqdm_asyncio
+from src.config import settings
+from src.providers import generate_components
+from src.web.v1.services.semantics_preparation import (
+    SemanticsPreparationRequest,
+    SemanticsPreparationService,
+)
+from src.web.v1.services.ask import (
+    AskRequest,
+    AskResultRequest,
+    AskResultResponse,
+    AskService,
+)
+from src.pipelines.generation import (
+    data_assistance,
+    intent_classification,
+    sql_correction,
+    sql_generation,
+)
+from src.pipelines.retrieval import historical_question, retrieval
 
 sys.path.append(f"{Path().parent.resolve()}")
 
@@ -32,15 +53,15 @@
 from src.core.engine import Engine
 from src.core.provider import DocumentStoreProvider, EmbedderProvider, LLMProvider
 from src.pipelines.generation import sql_generation
-from src.pipelines.indexing import indexing
 from src.pipelines.retrieval import retrieval
+from src.pipelines import indexing
 
 
-def deploy_model(mdl: str, pipe: indexing.Indexing) -> None:
-    async def wrapper():
-        await pipe.run(orjson.dumps(mdl).decode())
+# def deploy_model(mdl: str, pipe: indexing.Indexing) -> None:
+#     async def wrapper():
+#         await pipe.run(orjson.dumps(mdl).decode())
 
-    asyncio.run(wrapper())
+#     asyncio.run(wrapper())
 
 
 def extract_units(docs: list) -> list:
@@ -107,6 +128,7 @@ def split(queries: list, batch_size: int) -> list[list]:
             ]
 
         async def wrapper(batch: list):
+            # self() will call sub-class's __call__ in every service
            tasks = [self(query) for query in batch]
            results = await tqdm_asyncio.gather(*tasks, desc="Generating Predictions")
            await asyncio.sleep(self._batch_interval)
@@ -188,7 +210,7 @@ def __init__(
             embedder_provider=embedder_provider,
             document_store_provider=document_store_provider,
         )
-        deploy_model(mdl, _indexing)
+        # deploy_model(mdl, _indexing)
 
         self._retrieval = retrieval.Retrieval(
             llm_provider=llm_provider,
@@ -288,36 +310,82 @@ def mertics(
         }
 
 
+
 class AskPipeline(Eval):
+    def indexing_service(self):
+
+        return SemanticsPreparationService(
+            {
+                "db_schema": indexing.DBSchema(
+                    **self.pipe_components["db_schema_indexing"],
+                ),
+                "historical_question": indexing.HistoricalQuestion(
+                    **self.pipe_components["historical_question_indexing"],
+                ),
+                "table_description": indexing.TableDescription(
+                    **self.pipe_components["table_description_indexing"],
+                ),
+            }
+        )
+
+    def ask_service(self):
+
+        return AskService(
+            {
+                "intent_classification": intent_classification.IntentClassification(
+                    **self.pipe_components["intent_classification"],
+                ),
+                "data_assistance": data_assistance.DataAssistance(
+                    **self.pipe_components["data_assistance"],
+                ),
+                "retrieval": retrieval.Retrieval(
+                    **self.pipe_components["db_schema_retrieval"],
+                ),
+                "historical_question": historical_question.HistoricalQuestion(
+                    **self.pipe_components["historical_question_retrieval"],
+                ),
+                "sql_generation": sql_generation.SQLGeneration(
+                    **self.pipe_components["sql_generation"],
+                ),
+                "sql_correction": sql_correction.SQLCorrection(
+                    **self.pipe_components["sql_correction"],
+                ),
+            }
+        )
+    def dict_to_string(self, d: dict) -> str:
+        if not isinstance(d, dict):
+            return str(d)
+
+        result = "{"
+        for key, value in d.items():
+            result += f"'{key}': {self.dict_to_string(value)}, "
+        result = result.rstrip(", ") + "}"
+        return result
+
     def __init__(
         self,
         meta: dict,
         mdl: dict,
-        llm_provider: LLMProvider,
-        embedder_provider: EmbedderProvider,
-        document_store_provider: DocumentStoreProvider,
-        engine: Engine,
-        **kwargs,
+        service_metadata,
+        pipe_components,
     ):
         super().__init__(meta, 3)
-
-        document_store_provider.get_store(recreate_index=True)
-        _indexing = indexing.Indexing(
-            embedder_provider=embedder_provider,
-            document_store_provider=document_store_provider,
-        )
-        deploy_model(mdl, _indexing)
-
+        self.service_metadata = service_metadata
+
+        # document_store_provider.get_store(recreate_index=True)
+        # _indexing = indexing.Indexing(
+        #     embedder_provider=embedder_provider,
+        #     document_store_provider=document_store_provider,
+        # )
+        # deploy_model(mdl, _indexing)
+        self.pipe_components = pipe_components
+        self.project_id = str(uuid.uuid4())
+        self.indexing_service_var = self.indexing_service()
+        self.mdl_str_var = json.dumps(mdl)
+        self.ask_service_var = self.ask_service()
+        self.service_metadata = service_metadata
         self._mdl = mdl
-        self._retrieval = retrieval.Retrieval(
-            llm_provider=llm_provider,
-            embedder_provider=embedder_provider,
-            document_store_provider=document_store_provider,
-        )
-        self._generation = sql_generation.SQLGeneration(
-            llm_provider=llm_provider,
-            engine=engine,
-        )
+        self.mdl_hash = str(hash(self.mdl_str_var))
 
     async def _flat(self, prediction: dict, actual: str) -> dict:
         prediction["actual_output"] = actual
@@ -327,17 +395,54 @@ async def _flat(self, prediction: dict, actual: str) -> dict:
         return prediction
 
     async def _process(self, prediction: dict, **_) -> dict:
-        result = await self._retrieval.run(query=prediction["input"])
-        documents = result.get("construct_retrieval_results", [])
-        actual_output = await self._generation.run(
+
+        await self.indexing_service_var.prepare_semantics(
+            SemanticsPreparationRequest(
+                mdl=self.mdl_str_var,
+                mdl_hash=self.mdl_hash,
+                project_id=self.project_id
+            ),
+            service_metadata=self.service_metadata,
+        )
+
+        # asking
+        ask_request = AskRequest(
             query=prediction["input"],
-            contexts=documents,
-            samples=prediction["samples"],
-            exclude=[],
+            mdl_hash=self.mdl_hash,
+            project_id = self.project_id,
+
+        )
+        ask_request.query_id = str(uuid.uuid4())
+        await self.ask_service_var.ask(ask_request, service_metadata=self.service_metadata)
+        # getting ask result
+        ask_result_response = self.ask_service_var.get_ask_result(
+            AskResultRequest(
+                query_id=ask_request.query_id,
+            )
         )
 
-        prediction["actual_output"] = actual_output
-        prediction["retrieval_context"] = extract_units(documents)
+        while (
+            ask_result_response.status != "finished"
+            and ask_result_response.status != "failed"
+        ):
+            # getting ask result
+            ask_result_response = self.ask_service_var.get_ask_result(
+                AskResultRequest(
+                    query_id=ask_request.query_id,
+                )
+            )
+
+        # result = await self._retrieval.run(query=prediction["input"])
+        # documents = result.get("construct_retrieval_results", [])
+        # actual_output = await self._generation.run(
+        #     query=prediction["input"],
+        #     contexts=documents,
+        #     samples=prediction["samples"],
+        #     exclude=[],
+        # )
+
+        prediction["actual_output"] = ask_result_response.response[0].sql
+        #prediction["retrieval_context"] = extract_units(documents)
 
         return prediction
 
@@ -377,9 +482,10 @@ def init(
     name: Literal["retrieval", "generation", "ask"],
     meta: dict,
     mdl: dict,
-    providers: Dict[str, Any],
+    service_metadata,
+    pipe_components: Dict[str, Any],
 ) -> Eval:
-    args = {"meta": meta, "mdl": mdl, **providers}
+    args = {"meta": meta, "mdl": mdl, "service_metadata":service_metadata,"pipe_components":pipe_components}
     match name:
         case "retrieval":
             return RetrievalPipeline(**args)
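Aside on the _process polling loop above: get_ask_result is called in a loop that never awaits, so the coroutine busy-waits and can starve the event loop while a query runs. A sketch of the same wait with a small delay, assuming the AskService interface shown in this diff; the poll_interval parameter is illustrative, not from the repo:

import asyncio

from src.web.v1.services.ask import AskResultRequest

async def wait_for_ask_result(ask_service, query_id: str, poll_interval: float = 0.5):
    # Poll until the service reports a terminal status, yielding to the
    # event loop between checks instead of spinning in a tight loop.
    while True:
        response = ask_service.get_ask_result(AskResultRequest(query_id=query_id))
        if response.status in ("finished", "failed"):
            return response
        await asyncio.sleep(poll_interval)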

wren-ai-service/eval/prediction.py

Lines changed: 24 additions & 19 deletions
@@ -14,13 +14,17 @@
 from tomlkit import document, dumps
 
 sys.path.append(f"{Path().parent.resolve()}")
+from src.config import settings
+from src.providers import generate_components
 import eval.pipelines as pipelines
-import src.providers as provider
 import src.utils as utils
 from eval.utils import parse_toml
 from src.core.engine import EngineConfig
 from src.core.provider import EmbedderProvider, LLMProvider
-
+from src.globals import (
+    create_service_container,
+    create_service_metadata,
+)
 
 def generate_meta(
     path: str,
@@ -46,10 +50,10 @@ def generate_meta(
         "commit": obtain_commit_hash(),
         "embedding_model": embedder_provider.get_model(),
         "generation_model": llm_provider.get_model(),
-        "column_indexing_batch_size": int(os.getenv("COLUMN_INDEXING_BATCH_SIZE"))
+        "column_indexing_batch_size": int(settings.column_indexing_batch_size)
         or 50,
-        "table_retrieval_size": int(os.getenv("TABLE_RETRIEVAL_SIZE")) or 10,
-        "table_column_retrieval_size": int(os.getenv("TABLE_COLUMN_RETRIEVAL_SIZE"))
+        "table_retrieval_size": int(settings.table_retrieval_size) or 10,
+        "table_column_retrieval_size": int(settings.table_column_retrieval_size)
         or 100,
         "pipeline": pipe,
         "batch_size": os.getenv("BATCH_SIZE") or 4,
@@ -84,11 +88,11 @@ def write_prediction(
 def obtain_commit_hash() -> str:
     repo = Repo(search_parent_directories=True)
 
-    if repo.untracked_files:
-        raise Exception("There are untracked files in the repository.")
+    # if repo.untracked_files:
+    #     raise Exception("There are untracked files in the repository.")
 
-    if repo.index.diff(None):
-        raise Exception("There are uncommitted changes in the repository.")
+    # if repo.index.diff(None):
+    #     raise Exception("There are uncommitted changes in the repository.")
 
     branch = repo.active_branch
     return f"{repo.head.commit}@{branch.name}"
@@ -138,12 +142,12 @@ def init_providers(mdl: dict) -> dict:
     if engine_config is None:
         raise ValueError("Invalid datasource")
 
-    providers = provider.init_providers(engine_config=engine_config)
+    providers_inner = provider.init_providers(engine_config=engine_config)
     return {
-        "llm_provider": providers[0],
-        "embedder_provider": providers[1],
-        "document_store_provider": providers[2],
-        "engine": providers[3],
+        "llm_provider": providers_inner[0],
+        "embedder_provider": providers_inner[1],
+        "document_store_provider": providers_inner[2],
+        "engine": providers_inner[3],
     }
 
 
@@ -174,23 +178,24 @@ def parse_args() -> Tuple[str]:
     utils.init_langfuse()
 
     dataset = parse_toml(path)
-    providers = init_providers(dataset["mdl"])
 
+    pipe_components = generate_components(settings.components)
     meta = generate_meta(
         path=path,
         dataset=dataset,
         pipe=pipe_name,
-        **providers,
+        **pipe_components["db_schema_retrieval"],
     )
-
+    service_metadata = create_service_metadata(pipe_components)
     pipe = pipelines.init(
         pipe_name,
         meta,
         mdl=dataset["mdl"],
-        providers=providers,
+        service_metadata=service_metadata,
+        pipe_components=pipe_components,
    )
 
-    predictions = pipe.predict(dataset["eval_dataset"])
+    predictions = pipe.predict([dataset["eval_dataset"][0]])
     meta["expected_batch_size"] = meta["query_count"] * pipe.candidate_size
     meta["actual_batch_size"] = len(predictions) - meta["query_count"]
 
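Two asides on this hunk: init_providers above still calls provider.init_providers(...) even though import src.providers as provider was removed in the first hunk, so it would raise NameError if anything still called it; and pipe.predict([dataset["eval_dataset"][0]]) hard-codes a one-example run, which reads like a temporary smoke test. A sketch that makes the limit explicit; the EVAL_LIMIT variable is hypothetical, not a flag in this repo:

import os

eval_dataset = dataset["eval_dataset"]
limit = os.getenv("EVAL_LIMIT")  # hypothetical knob for quick smoke tests
if limit:
    eval_dataset = eval_dataset[: int(limit)]  # slicing keeps the list shape
predictions = pipe.predict(eval_dataset)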

wren-ai-service/src/globals.py

Lines changed: 7 additions & 0 deletions
@@ -67,6 +67,13 @@ class ServiceContainer:
 class ServiceMetadata:
     pipes_metadata: dict
     service_version: str
+    def get(self, key: str):
+        if key=="service_version":
+            return self.service_version
+        elif key=="pipes_metadata":
+            return self.pipes_metadata
+        else:
+            return None
 
 
 def create_service_container(
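Aside on the new ServiceMetadata.get: since the accepted keys are exactly the declared fields, getattr collapses the if/elif chain. A sketch assuming ServiceMetadata is a dataclass (its decorator sits outside this hunk):

from dataclasses import dataclass

@dataclass
class ServiceMetadata:
    pipes_metadata: dict
    service_version: str

    def get(self, key: str):
        # Only expose declared fields; unknown keys return None,
        # matching dict.get semantics.
        if key in ("pipes_metadata", "service_version"):
            return getattr(self, key)
        return None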
