import os
import re
import sys
+ import uuid
from abc import abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Literal

import orjson
+ import json
from haystack import Document
from langfuse.decorators import langfuse_context, observe
from tqdm.asyncio import tqdm_asyncio
+ from src.config import settings
+ from src.providers import generate_components
+ from src.web.v1.services.semantics_preparation import (
+     SemanticsPreparationRequest,
+     SemanticsPreparationService,
+ )
+ from src.web.v1.services.ask import (
+     AskRequest,
+     AskResultRequest,
+     AskResultResponse,
+     AskService,
+ )
+ from src.pipelines.generation import (
+     data_assistance,
+     intent_classification,
+     sql_correction,
+     sql_generation,
+ )
+ from src.pipelines.retrieval import historical_question, retrieval

sys.path.append(f"{Path().parent.resolve()}")

@@ -32,15 +53,15 @@
from src.core.engine import Engine
from src.core.provider import DocumentStoreProvider, EmbedderProvider, LLMProvider
from src.pipelines.generation import sql_generation
- from src.pipelines.indexing import indexing
from src.pipelines.retrieval import retrieval
+ from src.pipelines import indexing


- def deploy_model(mdl: str, pipe: indexing.Indexing) -> None:
-     async def wrapper():
-         await pipe.run(orjson.dumps(mdl).decode())
+ # def deploy_model(mdl: str, pipe: indexing.Indexing) -> None:
+ #     async def wrapper():
+ #         await pipe.run(orjson.dumps(mdl).decode())

-     asyncio.run(wrapper())
+ #     asyncio.run(wrapper())
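+ # NOTE: indexing/deployment of the MDL is now handled through SemanticsPreparationService
+ # inside AskPipeline._process instead of this helper.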


def extract_units(docs: list) -> list:
@@ -107,6 +128,7 @@ def split(queries: list, batch_size: int) -> list[list]:
            ]

        async def wrapper(batch: list):
+             # self(query) invokes the concrete pipeline's __call__ for each query in the batch
            tasks = [self(query) for query in batch]
            results = await tqdm_asyncio.gather(*tasks, desc="Generating Predictions")
            await asyncio.sleep(self._batch_interval)
@@ -188,7 +210,7 @@ def __init__(
            embedder_provider=embedder_provider,
            document_store_provider=document_store_provider,
        )
-         deploy_model(mdl, _indexing)
+         # deploy_model(mdl, _indexing)

        self._retrieval = retrieval.Retrieval(
            llm_provider=llm_provider,
@@ -288,36 +310,82 @@ def mertics(
        }


+
class AskPipeline(Eval):
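+     # Builds the semantics-preparation (indexing) service from the shared pipeline components.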
+     def indexing_service(self):
+         return SemanticsPreparationService(
+             {
+                 "db_schema": indexing.DBSchema(
+                     **self.pipe_components["db_schema_indexing"],
+                 ),
+                 "historical_question": indexing.HistoricalQuestion(
+                     **self.pipe_components["historical_question_indexing"],
+                 ),
+                 "table_description": indexing.TableDescription(
+                     **self.pipe_components["table_description_indexing"],
+                 ),
+             }
+         )
+
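+     # Builds the ask service from the generation and retrieval pipeline components.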
+     def ask_service(self):
+         return AskService(
+             {
+                 "intent_classification": intent_classification.IntentClassification(
+                     **self.pipe_components["intent_classification"],
+                 ),
+                 "data_assistance": data_assistance.DataAssistance(
+                     **self.pipe_components["data_assistance"],
+                 ),
+                 "retrieval": retrieval.Retrieval(
+                     **self.pipe_components["db_schema_retrieval"],
+                 ),
+                 "historical_question": historical_question.HistoricalQuestion(
+                     **self.pipe_components["historical_question_retrieval"],
+                 ),
+                 "sql_generation": sql_generation.SQLGeneration(
+                     **self.pipe_components["sql_generation"],
+                 ),
+                 "sql_correction": sql_correction.SQLCorrection(
+                     **self.pipe_components["sql_correction"],
+                 ),
+             }
+         )
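+
+     # Recursively renders a (possibly nested) dict as a string with single-quoted keys.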
+     def dict_to_string(self, d: dict) -> str:
+         if not isinstance(d, dict):
+             return str(d)
+
+         result = "{"
+         for key, value in d.items():
+             result += f"'{key}': {self.dict_to_string(value)}, "
+         result = result.rstrip(", ") + "}"
+         return result
+
    def __init__(
        self,
        meta: dict,
        mdl: dict,
-         llm_provider: LLMProvider,
-         embedder_provider: EmbedderProvider,
-         document_store_provider: DocumentStoreProvider,
-         engine: Engine,
-         **kwargs,
+         service_metadata,
+         pipe_components,
    ):
        super().__init__(meta, 3)
-
-         document_store_provider.get_store(recreate_index=True)
-         _indexing = indexing.Indexing(
-             embedder_provider=embedder_provider,
-             document_store_provider=document_store_provider,
-         )
-         deploy_model(mdl, _indexing)
-
+         self.service_metadata = service_metadata
+
+         # document_store_provider.get_store(recreate_index=True)
+         # _indexing = indexing.Indexing(
+         #     embedder_provider=embedder_provider,
+         #     document_store_provider=document_store_provider,
+         # )
+         # deploy_model(mdl, _indexing)
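+         # per-run identifiers and service instances used by _process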
+         self.pipe_components = pipe_components
+         self.project_id = str(uuid.uuid4())
+         self.indexing_service_var = self.indexing_service()
+         self.mdl_str_var = json.dumps(mdl)
+         self.ask_service_var = self.ask_service()
        self._mdl = mdl
-         self._retrieval = retrieval.Retrieval(
-             llm_provider=llm_provider,
-             embedder_provider=embedder_provider,
-             document_store_provider=document_store_provider,
-         )
-         self._generation = sql_generation.SQLGeneration(
-             llm_provider=llm_provider,
-             engine=engine,
-         )
+         self.mdl_hash = str(hash(self.mdl_str_var))  # per-process hash used to identify this MDL

    async def _flat(self, prediction: dict, actual: str) -> dict:
        prediction["actual_output"] = actual
@@ -327,17 +395,54 @@ async def _flat(self, prediction: dict, actual: str) -> dict:
        return prediction

    async def _process(self, prediction: dict, **_) -> dict:
-         result = await self._retrieval.run(query=prediction["input"])
-         documents = result.get("construct_retrieval_results", [])
-         actual_output = await self._generation.run(
+
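+         # index the MDL through the semantics-preparation service before asking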
+         await self.indexing_service_var.prepare_semantics(
+             SemanticsPreparationRequest(
+                 mdl=self.mdl_str_var,
+                 mdl_hash=self.mdl_hash,
+                 project_id=self.project_id,
+             ),
+             service_metadata=self.service_metadata,
+         )
+
+         # asking
+         ask_request = AskRequest(
            query=prediction["input"],
-             contexts=documents,
-             samples=prediction["samples"],
-             exclude=[],
+             mdl_hash=self.mdl_hash,
+             project_id=self.project_id,
+         )
+         ask_request.query_id = str(uuid.uuid4())
+         await self.ask_service_var.ask(ask_request, service_metadata=self.service_metadata)
+         # getting ask result
+         ask_result_response = self.ask_service_var.get_ask_result(
+             AskResultRequest(
+                 query_id=ask_request.query_id,
+             )
        )

-         prediction["actual_output"] = actual_output
-         prediction["retrieval_context"] = extract_units(documents)
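+         # poll until the ask service reports a terminal status ("finished" or "failed")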
+         while (
+             ask_result_response.status != "finished"
+             and ask_result_response.status != "failed"
+         ):
+             # yield to the event loop, then poll for the ask result again
+             await asyncio.sleep(0.1)
+             ask_result_response = self.ask_service_var.get_ask_result(
+                 AskResultRequest(
+                     query_id=ask_request.query_id,
+                 )
+             )
+
+         # result = await self._retrieval.run(query=prediction["input"])
+         # documents = result.get("construct_retrieval_results", [])
+         # actual_output = await self._generation.run(
+         #     query=prediction["input"],
+         #     contexts=documents,
+         #     samples=prediction["samples"],
+         #     exclude=[],
+         # )
+
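+         # assumes the ask finished successfully and returned at least one SQL candidate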
+         prediction["actual_output"] = ask_result_response.response[0].sql
+         # prediction["retrieval_context"] = extract_units(documents)

        return prediction

@@ -377,9 +482,10 @@ def init(
    name: Literal["retrieval", "generation", "ask"],
    meta: dict,
    mdl: dict,
-     providers: Dict[str, Any],
+     service_metadata,
+     pipe_components: Dict[str, Any],
) -> Eval:
-     args = {"meta": meta, "mdl": mdl, **providers}
+     args = {"meta": meta, "mdl": mdl, "service_metadata": service_metadata, "pipe_components": pipe_components}
    match name:
        case "retrieval":
            return RetrievalPipeline(**args)