diff --git a/deployment/kustomizations/base/cm.yaml b/deployment/kustomizations/base/cm.yaml index 3513c93a41..72b0678bc9 100644 --- a/deployment/kustomizations/base/cm.yaml +++ b/deployment/kustomizations/base/cm.yaml @@ -158,13 +158,8 @@ data: llm: litellm_llm.default - name: relationship_recommendation llm: litellm_llm.default - engine: wren_ui - name: question_recommendation llm: litellm_llm.default - - name: question_recommendation_db_schema_retrieval - llm: litellm_llm.default - embedder: litellm_embedder.default - document_store: qdrant - name: question_recommendation_sql_generation llm: litellm_llm.default engine: wren_ui diff --git a/docker/config.example.yaml b/docker/config.example.yaml index 6a066d1eee..f8f56c7bb3 100644 --- a/docker/config.example.yaml +++ b/docker/config.example.yaml @@ -108,13 +108,8 @@ pipes: llm: litellm_llm.default - name: relationship_recommendation llm: litellm_llm.default - engine: wren_ui - name: question_recommendation llm: litellm_llm.default - - name: question_recommendation_db_schema_retrieval - llm: litellm_llm.default - embedder: litellm_embedder.default - document_store: qdrant - name: question_recommendation_sql_generation llm: litellm_llm.default engine: wren_ui diff --git a/wren-ai-service/src/globals.py b/wren-ai-service/src/globals.py index cae34c40a2..728d835c91 100644 --- a/wren-ai-service/src/globals.py +++ b/wren-ai-service/src/globals.py @@ -16,6 +16,7 @@ @dataclass class ServiceContainer: ask_service: services.AskService + ask_feedback_service: services.AskFeedbackService question_recommendation: services.QuestionRecommendation relationship_recommendation: services.RelationshipRecommendation semantics_description: services.SemanticsDescription @@ -47,6 +48,38 @@ def create_service_container( if not wren_ai_docs: logger.warning("Failed to fetch Wren AI docs or response was empty.") + _db_schema_retrieval_pipeline = retrieval.DbSchemaRetrieval( + **pipe_components["db_schema_retrieval"], + table_retrieval_size=settings.table_retrieval_size, + table_column_retrieval_size=settings.table_column_retrieval_size, + ) + _sql_pair_indexing_pipeline = indexing.SqlPairs( + **pipe_components["sql_pairs_indexing"], + sql_pairs_path=settings.sql_pairs_path, + ) + _instructions_indexing_pipeline = indexing.Instructions( + **pipe_components["instructions_indexing"], + ) + _sql_pair_retrieval_pipeline = retrieval.SqlPairsRetrieval( + **pipe_components["sql_pairs_retrieval"], + sql_pairs_similarity_threshold=settings.sql_pairs_similarity_threshold, + sql_pairs_retrieval_max_size=settings.sql_pairs_retrieval_max_size, + ) + _instructions_retrieval_pipeline = retrieval.Instructions( + **pipe_components["instructions_retrieval"], + similarity_threshold=settings.instructions_similarity_threshold, + top_k=settings.instructions_top_k, + ) + _sql_correction_pipeline = generation.SQLCorrection( + **pipe_components["sql_correction"], + ) + _sql_functions_retrieval_pipeline = retrieval.SqlFunctions( + **pipe_components["sql_functions_retrieval"], + ) + _sql_executor_pipeline = retrieval.SQLExecutor( + **pipe_components["sql_executor"], + ) + return ServiceContainer( semantics_description=services.SemanticsDescription( pipelines={ @@ -68,13 +101,8 @@ def create_service_container( "table_description": indexing.TableDescription( **pipe_components["table_description_indexing"], ), - "sql_pairs": indexing.SqlPairs( - **pipe_components["sql_pairs_indexing"], - sql_pairs_path=settings.sql_pairs_path, - ), - "instructions": indexing.Instructions( - **pipe_components["instructions_indexing"], - ), + "sql_pairs": _sql_pair_indexing_pipeline, + "instructions": _instructions_indexing_pipeline, "project_meta": indexing.ProjectMeta( **pipe_components["project_meta_indexing"], ), @@ -97,28 +125,15 @@ def create_service_container( **pipe_components["user_guide_assistance"], wren_ai_docs=wren_ai_docs, ), - "db_schema_retrieval": retrieval.DbSchemaRetrieval( - **pipe_components["db_schema_retrieval"], - table_retrieval_size=settings.table_retrieval_size, - table_column_retrieval_size=settings.table_column_retrieval_size, - ), + "db_schema_retrieval": _db_schema_retrieval_pipeline, "historical_question": retrieval.HistoricalQuestionRetrieval( **pipe_components["historical_question_retrieval"], historical_question_retrieval_similarity_threshold=settings.historical_question_retrieval_similarity_threshold, ), - "sql_pairs_retrieval": retrieval.SqlPairsRetrieval( - **pipe_components["sql_pairs_retrieval"], - sql_pairs_similarity_threshold=settings.sql_pairs_similarity_threshold, - sql_pairs_retrieval_max_size=settings.sql_pairs_retrieval_max_size, - ), - "instructions_retrieval": retrieval.Instructions( - **pipe_components["instructions_retrieval"], - similarity_threshold=settings.instructions_similarity_threshold, - top_k=settings.instructions_top_k, - ), + "sql_pairs_retrieval": _sql_pair_retrieval_pipeline, + "instructions_retrieval": _instructions_retrieval_pipeline, "sql_generation": generation.SQLGeneration( **pipe_components["sql_generation"], - engine_timeout=settings.engine_timeout, ), "sql_generation_reasoning": generation.SQLGenerationReasoning( **pipe_components["sql_generation_reasoning"], @@ -126,22 +141,11 @@ def create_service_container( "followup_sql_generation_reasoning": generation.FollowUpSQLGenerationReasoning( **pipe_components["followup_sql_generation_reasoning"], ), - "sql_correction": generation.SQLCorrection( - **pipe_components["sql_correction"], - engine_timeout=settings.engine_timeout, - ), + "sql_correction": _sql_correction_pipeline, "followup_sql_generation": generation.FollowUpSQLGeneration( **pipe_components["followup_sql_generation"], - engine_timeout=settings.engine_timeout, - ), - "sql_regeneration": generation.SQLRegeneration( - **pipe_components["sql_regeneration"], - engine_timeout=settings.engine_timeout, - ), - "sql_functions_retrieval": retrieval.SqlFunctions( - **pipe_components["sql_functions_retrieval"], - engine_timeout=settings.engine_timeout, ), + "sql_functions_retrieval": _sql_functions_retrieval_pipeline, }, allow_intent_classification=settings.allow_intent_classification, allow_sql_generation_reasoning=settings.allow_sql_generation_reasoning, @@ -151,12 +155,23 @@ def create_service_container( max_sql_correction_retries=settings.max_sql_correction_retries, **query_cache, ), - chart_service=services.ChartService( + ask_feedback_service=services.AskFeedbackService( pipelines={ - "sql_executor": retrieval.SQLExecutor( - **pipe_components["sql_executor"], - engine_timeout=settings.engine_timeout, + "db_schema_retrieval": _db_schema_retrieval_pipeline, + "sql_pairs_retrieval": _sql_pair_retrieval_pipeline, + "instructions_retrieval": _instructions_retrieval_pipeline, + "sql_functions_retrieval": _sql_functions_retrieval_pipeline, + "sql_regeneration": generation.SQLRegeneration( + **pipe_components["sql_regeneration"], ), + "sql_correction": _sql_correction_pipeline, + }, + allow_sql_functions_retrieval=settings.allow_sql_functions_retrieval, + **query_cache, + ), + chart_service=services.ChartService( + pipelines={ + "sql_executor": _sql_executor_pipeline, "chart_generation": generation.ChartGeneration( **pipe_components["chart_generation"], ), @@ -165,10 +180,7 @@ def create_service_container( ), chart_adjustment_service=services.ChartAdjustmentService( pipelines={ - "sql_executor": retrieval.SQLExecutor( - **pipe_components["sql_executor"], - engine_timeout=settings.engine_timeout, - ), + "sql_executor": _sql_executor_pipeline, "chart_adjustment": generation.ChartAdjustment( **pipe_components["chart_adjustment"], ), @@ -182,7 +194,6 @@ def create_service_container( ), "sql_answer": generation.SQLAnswer( **pipe_components["sql_answer"], - engine_timeout=settings.engine_timeout, ), }, **query_cache, @@ -191,7 +202,6 @@ def create_service_container( pipelines={ "relationship_recommendation": generation.RelationshipRecommendation( **pipe_components["relationship_recommendation"], - engine_timeout=settings.engine_timeout, ) }, **query_cache, @@ -201,39 +211,20 @@ def create_service_container( "question_recommendation": generation.QuestionRecommendation( **pipe_components["question_recommendation"], ), - "db_schema_retrieval": retrieval.DbSchemaRetrieval( - **pipe_components["question_recommendation_db_schema_retrieval"], - table_retrieval_size=settings.table_retrieval_size, - table_column_retrieval_size=settings.table_column_retrieval_size, - ), + "db_schema_retrieval": _db_schema_retrieval_pipeline, "sql_generation": generation.SQLGeneration( **pipe_components["question_recommendation_sql_generation"], - engine_timeout=settings.engine_timeout, - ), - "sql_pairs_retrieval": retrieval.SqlPairsRetrieval( - **pipe_components["sql_pairs_retrieval"], - sql_pairs_similarity_threshold=settings.sql_pairs_similarity_threshold, - sql_pairs_retrieval_max_size=settings.sql_pairs_retrieval_max_size, - ), - "instructions_retrieval": retrieval.Instructions( - **pipe_components["instructions_retrieval"], - similarity_threshold=settings.instructions_similarity_threshold, - top_k=settings.instructions_top_k, - ), - "sql_functions_retrieval": retrieval.SqlFunctions( - **pipe_components["sql_functions_retrieval"], - engine_timeout=settings.engine_timeout, ), + "sql_pairs_retrieval": _sql_pair_retrieval_pipeline, + "instructions_retrieval": _instructions_retrieval_pipeline, + "sql_functions_retrieval": _sql_functions_retrieval_pipeline, }, allow_sql_functions_retrieval=settings.allow_sql_functions_retrieval, **query_cache, ), sql_pairs_service=services.SqlPairsService( pipelines={ - "sql_pairs": indexing.SqlPairs( - **pipe_components["sql_pairs_indexing"], - sql_pairs_path=settings.sql_pairs_path, - ) + "sql_pairs": _sql_pair_indexing_pipeline, }, **query_cache, ), @@ -247,9 +238,7 @@ def create_service_container( ), instructions_service=services.InstructionsService( pipelines={ - "instructions_indexing": indexing.Instructions( - **pipe_components["instructions_indexing"], - ) + "instructions_indexing": _instructions_indexing_pipeline, }, **query_cache, ), @@ -258,15 +247,8 @@ def create_service_container( "sql_tables_extraction": generation.SQLTablesExtraction( **pipe_components["sql_tables_extraction"], ), - "db_schema_retrieval": retrieval.DbSchemaRetrieval( - **pipe_components["db_schema_retrieval"], - table_retrieval_size=settings.table_retrieval_size, - table_column_retrieval_size=settings.table_column_retrieval_size, - ), - "sql_correction": generation.SQLCorrection( - **pipe_components["sql_correction"], - engine_timeout=settings.engine_timeout, - ), + "db_schema_retrieval": _db_schema_retrieval_pipeline, + "sql_correction": _sql_correction_pipeline, }, **query_cache, ), diff --git a/wren-ai-service/src/pipelines/generation/followup_sql_generation.py b/wren-ai-service/src/pipelines/generation/followup_sql_generation.py index 519307e42f..5f889ff00c 100644 --- a/wren-ai-service/src/pipelines/generation/followup_sql_generation.py +++ b/wren-ai-service/src/pipelines/generation/followup_sql_generation.py @@ -131,7 +131,6 @@ async def generate_sql_in_followup( async def post_process( generate_sql_in_followup: dict, post_processor: SQLGenPostProcessor, - engine_timeout: float, data_source: str, project_id: str | None = None, use_dry_plan: bool = False, @@ -139,7 +138,6 @@ async def post_process( ) -> dict: return await post_processor.run( generate_sql_in_followup.get("replies"), - timeout=engine_timeout, project_id=project_id, use_dry_plan=use_dry_plan, data_source=data_source, @@ -156,7 +154,6 @@ def __init__( llm_provider: LLMProvider, document_store_provider: DocumentStoreProvider, engine: Engine, - engine_timeout: float = 30.0, **kwargs, ): self._retriever = document_store_provider.get_retriever( @@ -175,10 +172,6 @@ def __init__( "post_processor": SQLGenPostProcessor(engine=engine), } - self._configs = { - "engine_timeout": engine_timeout, - } - super().__init__( AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) ) @@ -225,6 +218,5 @@ async def run( "allow_dry_plan_fallback": allow_dry_plan_fallback, "data_source": metadata.get("data_source", "local_file"), **self._components, - **self._configs, }, ) diff --git a/wren-ai-service/src/pipelines/generation/question_recommendation.py b/wren-ai-service/src/pipelines/generation/question_recommendation.py index 9787a0e065..a6e7c17b02 100644 --- a/wren-ai-service/src/pipelines/generation/question_recommendation.py +++ b/wren-ai-service/src/pipelines/generation/question_recommendation.py @@ -17,76 +17,6 @@ logger = logging.getLogger("wren-ai-service") -## Start of Pipeline -@observe(capture_input=False) -def prompt( - mdl: dict, - previous_questions: list[str], - language: str, - max_questions: int, - max_categories: int, - prompt_builder: PromptBuilder, -) -> dict: - """ - If previous_questions is provided, the MDL is omitted to allow the LLM to focus on - generating recommendations based on the question history. This helps provide more - contextually relevant questions that build on previous questions. - """ - - _prompt = prompt_builder.run( - models=[] if previous_questions else mdl.get("models", []), - previous_questions=previous_questions, - language=language, - max_questions=max_questions, - max_categories=max_categories, - ) - return {"prompt": clean_up_new_lines(_prompt.get("prompt"))} - - -@observe(as_type="generation", capture_input=False) -@trace_cost -async def generate(prompt: dict, generator: Any, generator_name: str) -> dict: - return await generator(prompt=prompt.get("prompt")), generator_name - - -@observe(capture_input=False) -def normalized(generate: dict) -> dict: - def wrapper(text: str) -> list: - text = text.replace("\n", " ") - text = " ".join(text.split()) - try: - text_list = orjson.loads(text.strip()) - return text_list - except orjson.JSONDecodeError as e: - logger.error(f"Error decoding JSON: {e}") - return [] # Return an empty list if JSON decoding fails - - reply = generate.get("replies")[0] # Expecting only one reply - normalized = wrapper(reply) - - return normalized - - -## End of Pipeline -class Question(BaseModel): - question: str - category: str - - -class QuestionResult(BaseModel): - questions: list[Question] - - -QUESTION_RECOMMENDATION_MODEL_KWARGS = { - "response_format": { - "type": "json_schema", - "json_schema": { - "name": "question_recommendation", - "schema": QuestionResult.model_json_schema(), - }, - } -} - system_prompt = """ You are an expert in data analysis and SQL query generation. Given a data model specification, optionally a user's question, and a list of categories, your task is to generate insightful, specific questions that can be answered using the provided data model. Each question should be accompanied by a brief explanation of its relevance or importance. @@ -125,7 +55,7 @@ class QuestionResult(BaseModel): 3. **If a User Question is Provided:** - - Generate questions that are closely related to the user’s previous question, ensuring that the new questions build upon or provide deeper insights into the original query. + - Generate questions that are closely related to the user's previous question, ensuring that the new questions build upon or provide deeper insights into the original query. - Use **random category selection** to introduce diverse perspectives while maintaining a focus on the context of the previous question. - Apply the analysis techniques above to enhance the relevance and depth of the generated questions. @@ -137,7 +67,7 @@ class QuestionResult(BaseModel): 5. **General Guidelines for All Questions:** - Ensure questions can be answered using the data model. - Mix simple and complex questions. - - Avoid open-ended questions – each should have a definite answer. + - Avoid open-ended questions - each should have a definite answer. - Incorporate time-based analysis where relevant. - Combine multiple analysis techniques when appropriate for deeper insights. @@ -198,7 +128,7 @@ class QuestionResult(BaseModel): Ensure that categories are selected in a random order for each question generation session. - **Avoid Repetition:** - Ensure the same category doesn’t dominate the list by limiting the number of questions from any single category unless specified otherwise. + Ensure the same category doesn't dominate the list by limiting the number of questions from any single category unless specified otherwise. - **Diversity of Analysis:** Combine different analysis techniques (drill-down, roll-up, etc.) within the selected categories for richer insights. @@ -210,10 +140,6 @@ class QuestionResult(BaseModel): """ user_prompt_template = """ -{% if models %} -Data Model Specification: -{{models}} -{% endif %} {% if previous_questions %} Previous Questions: {{previous_questions}} @@ -223,10 +149,88 @@ class QuestionResult(BaseModel): Categories: {{categories}} {% endif %} +{% if documents %} +### DATABASE SCHEMA ### +{% for document in documents %} + {{ document }} +{% endfor %} +{% endif %} + Please generate {{max_questions}} insightful questions for each of the {{max_categories}} categories based on the provided data model. Both the questions and category names should be translated into {{language}}{% if user_question %} and be related to the user's question{% endif %}. The output format should maintain the structure but with localized text. """ +## Start of Pipeline +@observe(capture_input=False) +def prompt( + previous_questions: list[str], + documents: list, + language: str, + max_questions: int, + max_categories: int, + prompt_builder: PromptBuilder, +) -> dict: + """ + If previous_questions is provided, the MDL is omitted to allow the LLM to focus on + generating recommendations based on the question history. This helps provide more + contextually relevant questions that build on previous questions. + """ + + _prompt = prompt_builder.run( + documents=documents, + previous_questions=previous_questions, + language=language, + max_questions=max_questions, + max_categories=max_categories, + ) + return {"prompt": clean_up_new_lines(_prompt.get("prompt"))} + + +@observe(as_type="generation", capture_input=False) +@trace_cost +async def generate(prompt: dict, generator: Any, generator_name: str) -> dict: + return await generator(prompt=prompt.get("prompt")), generator_name + + +@observe(capture_input=False) +def normalized(generate: dict) -> dict: + def wrapper(text: str) -> list: + text = text.replace("\n", " ") + text = " ".join(text.split()) + try: + text_list = orjson.loads(text.strip()) + return text_list + except orjson.JSONDecodeError as e: + logger.error(f"Error decoding JSON: {e}") + return [] # Return an empty list if JSON decoding fails + + reply = generate.get("replies")[0] # Expecting only one reply + normalized = wrapper(reply) + + return normalized + + +## End of Pipeline +class Question(BaseModel): + question: str + category: str + + +class QuestionResult(BaseModel): + questions: list[Question] + + +QUESTION_RECOMMENDATION_MODEL_KWARGS = { + "response_format": { + "type": "json_schema", + "json_schema": { + "name": "question_recommendation", + "schema": QuestionResult.model_json_schema(), + }, + } +} + + class QuestionRecommendation(BasicPipeline): def __init__( self, @@ -251,7 +255,7 @@ def __init__( @observe(name="Question Recommendation") async def run( self, - mdl: dict, + contexts: list[str], previous_questions: list[str] = [], categories: list[str] = [], language: str = "en", @@ -263,7 +267,7 @@ async def run( return await self._pipe.execute( [self._final], inputs={ - "mdl": mdl, + "documents": contexts, "previous_questions": previous_questions, "categories": categories, "language": language, diff --git a/wren-ai-service/src/pipelines/generation/relationship_recommendation.py b/wren-ai-service/src/pipelines/generation/relationship_recommendation.py index 45a7481489..4b49fdeb77 100644 --- a/wren-ai-service/src/pipelines/generation/relationship_recommendation.py +++ b/wren-ai-service/src/pipelines/generation/relationship_recommendation.py @@ -10,7 +10,6 @@ from langfuse.decorators import observe from pydantic import BaseModel -from src.core.engine import Engine from src.core.pipeline import BasicPipeline from src.core.provider import LLMProvider from src.pipelines.common import clean_up_new_lines @@ -19,6 +18,59 @@ logger = logging.getLogger("wren-ai-service") +system_prompt = """ +You are an expert in database schema design and relationship recommendation. Given a data model specification that includes various models and their attributes, your task is to analyze the models and suggest appropriate relationships between them, but only if there are clear and beneficial relationships to recommend. For each valid relationship, provide the following details: + +- **name**: A descriptive name for the relationship. +- **fromModel**: The name of the source model. +- **fromColumn**: The column in the source model that forms the relationship. +- **type**: The type of relationship, which can be "MANY_TO_ONE", "ONE_TO_MANY" or "ONE_TO_ONE" only. +- **toModel**: The name of the target model. +- **toColumn**: The column in the target model that forms the relationship. +- **reason**: The reason for recommending this relationship. + +Important guidelines: +1. Do not recommend relationships within the same model (fromModel and toModel must be different). +2. Only suggest relationships if there is a clear and beneficial reason to do so. +3. If there are no good relationships to recommend or if there are fewer than two models, return an empty list of relationships. +4. Use "MANY_TO_ONE" and "ONE_TO_MANY" instead of "MANY_TO_MANY" relationships. + +Output all relationships in the following JSON structure: + +{ + "relationships": [ + { + "name": "", + "fromModel": "", + "fromColumn": "", + "type": "", + "toModel": "", + "toColumn": "", + "reason": "" + } + ... + ] +} + +If no relationships are recommended, return: + +{ + "relationships": [] +} +""" + +user_prompt_template = """ +Here is the relationship specification for my data model: + +{{models}} + +**Please analyze these models and suggest optimizations for their relationships.** +Take into account best practices in database design, opportunities for normalization, indexing strategies, and any additional relationships that could improve data integrity and enhance query performance. + +Use this for the relationship name and reason based on the localization language: {{language}} +""" + + ## Start of Pipeline @observe(capture_input=False) def cleaned_models(mdl: dict) -> dict: @@ -82,7 +134,7 @@ def wrapper(text: str) -> str: @observe(capture_input=False) -def validated(normalized: dict, engine: Engine) -> dict: +def validated(normalized: dict) -> dict: relationships = normalized.get("relationships", []) validated_relationships = [ @@ -91,9 +143,6 @@ def validated(normalized: dict, engine: Engine) -> dict: if RelationType.is_include(relationship.get("type")) ] - # todo: after wren-engine support function to validate the relationships, we will use that function to validate the relationships - # for now, we will just return the normalized relationships - return {"relationships": validated_relationships} @@ -131,65 +180,12 @@ class RelationshipResult(BaseModel): }, } } -system_prompt = """ -You are an expert in database schema design and relationship recommendation. Given a data model specification that includes various models and their attributes, your task is to analyze the models and suggest appropriate relationships between them, but only if there are clear and beneficial relationships to recommend. For each valid relationship, provide the following details: - -- **name**: A descriptive name for the relationship. -- **fromModel**: The name of the source model. -- **fromColumn**: The column in the source model that forms the relationship. -- **type**: The type of relationship, which can be "MANY_TO_ONE", "ONE_TO_MANY" or "ONE_TO_ONE" only. -- **toModel**: The name of the target model. -- **toColumn**: The column in the target model that forms the relationship. -- **reason**: The reason for recommending this relationship. - -Important guidelines: -1. Do not recommend relationships within the same model (fromModel and toModel must be different). -2. Only suggest relationships if there is a clear and beneficial reason to do so. -3. If there are no good relationships to recommend or if there are fewer than two models, return an empty list of relationships. -4. Use "MANY_TO_ONE" and "ONE_TO_MANY" instead of "MANY_TO_MANY" relationships. - -Output all relationships in the following JSON structure: - -{ - "relationships": [ - { - "name": "", - "fromModel": "", - "fromColumn": "", - "type": "", - "toModel": "", - "toColumn": "", - "reason": "" - } - ... - ] -} - -If no relationships are recommended, return: - -{ - "relationships": [] -} -""" - -user_prompt_template = """ -Here is the relationship specification for my data model: - -{{models}} - -**Please analyze these models and suggest optimizations for their relationships.** -Take into account best practices in database design, opportunities for normalization, indexing strategies, and any additional relationships that could improve data integrity and enhance query performance. - -Use this for the relationship name and reason based on the localization language: {{language}} -""" class RelationshipRecommendation(BasicPipeline): def __init__( self, llm_provider: LLMProvider, - engine: Engine, - engine_timeout: float = 30.0, **_, ): self._components = { @@ -199,11 +195,6 @@ def __init__( generation_kwargs=RELATIONSHIP_RECOMMENDATION_MODEL_KWARGS, ), "generator_name": llm_provider.get_model(), - "engine": engine, - } - - self._configs = { - "engine_timeout": engine_timeout, } self._final = "validated" @@ -225,6 +216,5 @@ async def run( "mdl": mdl, "language": language, **self._components, - **self._configs, }, ) diff --git a/wren-ai-service/src/pipelines/generation/semantics_description.py b/wren-ai-service/src/pipelines/generation/semantics_description.py index 395b44ab67..c83a05496f 100644 --- a/wren-ai-service/src/pipelines/generation/semantics_description.py +++ b/wren-ai-service/src/pipelines/generation/semantics_description.py @@ -17,6 +17,78 @@ logger = logging.getLogger("wren-ai-service") +system_prompt = """ +I have a data model represented in JSON format, with the following structure: + +``` +[ + {'name': 'model', 'columns': [ + {'name': 'column_1', 'type': 'type', 'properties': {} + }, + {'name': 'column_2', 'type': 'type', 'properties': {} + }, + {'name': 'column_3', 'type': 'type', 'properties': {} + } + ], 'properties': {} + } +] +``` + +Your task is to update this JSON structure by adding a `description` field inside both the `properties` attribute of each `column` and the `model` itself. +Each `description` should be derived from a user-provided input that explains the purpose or context of the `model` and its respective columns. +Follow these steps: +1. **For the `model`**: Prompt the user to provide a brief description of the model's overall purpose or its context. Insert this description in the `properties` field of the `model`. +2. **For each `column`**: Ask the user to describe each column's role or significance. Each column's description should be added under its respective `properties` field in the format: `'description': 'user-provided text'`. +3. Ensure that the output is a well-formatted JSON structure, preserving the input's original format and adding the appropriate `description` fields. + +### Output Format: + +``` +{ + "models": [ + { + "name": "model", + "columns": [ + { + "name": "column_1", + "properties": { + "description": "" + } + }, + { + "name": "column_2", + "properties": { + "description": "" + } + }, + { + "name": "column_3", + "properties": { + "description": "" + } + } + ], + "properties": { + "description": "" + } + } + ] +} +``` + +Make sure that the descriptions are concise, informative, and contextually appropriate based on the input provided by the user. +""" + +user_prompt_template = """ +### Input: +User's prompt: {{ user_prompt }} +Picked models: {{ picked_models }} +Localization Language: {{ language }} + +Please provide a brief description for the model and each column based on the user's prompt. +""" + + ## Start of Pipeline @observe(capture_input=False) def picked_models(mdl: dict, selected_models: list[str]) -> list[dict]: @@ -140,77 +212,6 @@ class SemanticResult(BaseModel): } } -system_prompt = """ -I have a data model represented in JSON format, with the following structure: - -``` -[ - {'name': 'model', 'columns': [ - {'name': 'column_1', 'type': 'type', 'properties': {} - }, - {'name': 'column_2', 'type': 'type', 'properties': {} - }, - {'name': 'column_3', 'type': 'type', 'properties': {} - } - ], 'properties': {} - } -] -``` - -Your task is to update this JSON structure by adding a `description` field inside both the `properties` attribute of each `column` and the `model` itself. -Each `description` should be derived from a user-provided input that explains the purpose or context of the `model` and its respective columns. -Follow these steps: -1. **For the `model`**: Prompt the user to provide a brief description of the model's overall purpose or its context. Insert this description in the `properties` field of the `model`. -2. **For each `column`**: Ask the user to describe each column's role or significance. Each column's description should be added under its respective `properties` field in the format: `'description': 'user-provided text'`. -3. Ensure that the output is a well-formatted JSON structure, preserving the input's original format and adding the appropriate `description` fields. - -### Output Format: - -``` -{ - "models": [ - { - "name": "model", - "columns": [ - { - "name": "column_1", - "properties": { - "description": "" - } - }, - { - "name": "column_2", - "properties": { - "description": "" - } - }, - { - "name": "column_3", - "properties": { - "description": "" - } - } - ], - "properties": { - "description": "" - } - } - ] -} -``` - -Make sure that the descriptions are concise, informative, and contextually appropriate based on the input provided by the user. -""" - -user_prompt_template = """ -### Input: -User's prompt: {{ user_prompt }} -Picked models: {{ picked_models }} -Localization Language: {{ language }} - -Please provide a brief description for the model and each column based on the user's prompt. -""" - class SemanticsDescription(BasicPipeline): def __init__(self, llm_provider: LLMProvider, **_): diff --git a/wren-ai-service/src/pipelines/generation/sql_correction.py b/wren-ai-service/src/pipelines/generation/sql_correction.py index 1cfe793283..a1c4c2852b 100644 --- a/wren-ai-service/src/pipelines/generation/sql_correction.py +++ b/wren-ai-service/src/pipelines/generation/sql_correction.py @@ -97,7 +97,6 @@ async def generate_sql_correction( async def post_process( generate_sql_correction: dict, post_processor: SQLGenPostProcessor, - engine_timeout: float, data_source: str, project_id: str | None = None, use_dry_plan: bool = False, @@ -105,7 +104,6 @@ async def post_process( ) -> dict: return await post_processor.run( generate_sql_correction.get("replies"), - timeout=engine_timeout, project_id=project_id, use_dry_plan=use_dry_plan, data_source=data_source, @@ -122,7 +120,6 @@ def __init__( llm_provider: LLMProvider, document_store_provider: DocumentStoreProvider, engine: Engine, - engine_timeout: float = 30.0, **kwargs, ): self._retriever = document_store_provider.get_retriever( @@ -141,10 +138,6 @@ def __init__( "post_processor": SQLGenPostProcessor(engine=engine), } - self._configs = { - "engine_timeout": engine_timeout, - } - super().__init__( AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) ) @@ -177,6 +170,5 @@ async def run( "allow_dry_plan_fallback": allow_dry_plan_fallback, "data_source": metadata.get("data_source", "local_file"), **self._components, - **self._configs, }, ) diff --git a/wren-ai-service/src/pipelines/generation/sql_generation.py b/wren-ai-service/src/pipelines/generation/sql_generation.py index 4a5d529218..27d3bc5eab 100644 --- a/wren-ai-service/src/pipelines/generation/sql_generation.py +++ b/wren-ai-service/src/pipelines/generation/sql_generation.py @@ -126,7 +126,6 @@ async def generate_sql( async def post_process( generate_sql: dict, post_processor: SQLGenPostProcessor, - engine_timeout: float, data_source: str, project_id: str | None = None, use_dry_plan: bool = False, @@ -135,7 +134,6 @@ async def post_process( ) -> dict: return await post_processor.run( generate_sql.get("replies"), - timeout=engine_timeout, project_id=project_id, use_dry_plan=use_dry_plan, data_source=data_source, @@ -153,7 +151,6 @@ def __init__( llm_provider: LLMProvider, document_store_provider: DocumentStoreProvider, engine: Engine, - engine_timeout: float = 30.0, **kwargs, ): self._retriever = document_store_provider.get_retriever( @@ -172,10 +169,6 @@ def __init__( "post_processor": SQLGenPostProcessor(engine=engine), } - self._configs = { - "engine_timeout": engine_timeout, - } - super().__init__( AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) ) @@ -222,6 +215,5 @@ async def run( "data_source": metadata.get("data_source", "local_file"), "allow_data_preview": allow_data_preview, **self._components, - **self._configs, }, ) diff --git a/wren-ai-service/src/pipelines/generation/sql_regeneration.py b/wren-ai-service/src/pipelines/generation/sql_regeneration.py index 6780116b06..0dbf2fd808 100644 --- a/wren-ai-service/src/pipelines/generation/sql_regeneration.py +++ b/wren-ai-service/src/pipelines/generation/sql_regeneration.py @@ -139,12 +139,10 @@ async def regenerate_sql( async def post_process( regenerate_sql: dict, post_processor: SQLGenPostProcessor, - engine_timeout: float, project_id: str | None = None, ) -> dict: return await post_processor.run( regenerate_sql.get("replies"), - timeout=engine_timeout, project_id=project_id, ) @@ -157,7 +155,6 @@ def __init__( self, llm_provider: LLMProvider, engine: Engine, - engine_timeout: float = 30.0, **kwargs, ): self._components = { @@ -172,10 +169,6 @@ def __init__( "post_processor": SQLGenPostProcessor(engine=engine), } - self._configs = { - "engine_timeout": engine_timeout, - } - super().__init__( AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) ) @@ -209,6 +202,5 @@ async def run( "has_json_field": has_json_field, "sql_functions": sql_functions, **self._components, - **self._configs, }, ) diff --git a/wren-ai-service/src/pipelines/generation/utils/sql.py b/wren-ai-service/src/pipelines/generation/utils/sql.py index 723a11f28a..b40528c870 100644 --- a/wren-ai-service/src/pipelines/generation/utils/sql.py +++ b/wren-ai-service/src/pipelines/generation/utils/sql.py @@ -29,7 +29,6 @@ def __init__(self, engine: Engine): async def run( self, replies: List[str] | List[List[str]], - timeout: float = 30.0, project_id: str | None = None, use_dry_plan: bool = False, allow_dry_plan_fallback: bool = True, @@ -51,7 +50,6 @@ async def run( ) = await self._classify_generation_result( cleaned_generation_result, project_id=project_id, - timeout=timeout, use_dry_plan=use_dry_plan, allow_dry_plan_fallback=allow_dry_plan_fallback, data_source=data_source, @@ -73,7 +71,6 @@ async def run( async def _classify_generation_result( self, generation_result: str, - timeout: float, project_id: str | None = None, use_dry_plan: bool = False, allow_dry_plan_fallback: bool = True, @@ -93,7 +90,6 @@ async def _classify_generation_result( session, quoted_sql, data_source, - timeout=timeout, allow_fallback=allow_dry_plan_fallback, ) @@ -116,7 +112,6 @@ async def _classify_generation_result( quoted_sql, session, project_id=project_id, - timeout=timeout, limit=1, dry_run=True, ) @@ -141,7 +136,6 @@ async def _classify_generation_result( quoted_sql, session, project_id=project_id, - timeout=timeout, limit=1, dry_run=False, ) diff --git a/wren-ai-service/src/pipelines/retrieval/sql_executor.py b/wren-ai-service/src/pipelines/retrieval/sql_executor.py index c828ef8504..8e9625645d 100644 --- a/wren-ai-service/src/pipelines/retrieval/sql_executor.py +++ b/wren-ai-service/src/pipelines/retrieval/sql_executor.py @@ -27,7 +27,6 @@ async def run( sql: str, project_id: str | None = None, limit: int = 500, - timeout: float = 30.0, ): async with aiohttp.ClientSession() as session: _, data, _ = await self._engine.execute_sql( @@ -36,7 +35,6 @@ async def run( project_id=project_id, dry_run=False, limit=limit, - timeout=timeout, ) return {"results": data} @@ -47,7 +45,6 @@ async def run( async def execute_sql( sql: str, data_fetcher: DataFetcher, - engine_timeout: float, project_id: str | None = None, limit: int = 500, ) -> dict: @@ -55,7 +52,6 @@ async def execute_sql( sql=sql, project_id=project_id, limit=limit, - timeout=engine_timeout, ) @@ -66,17 +62,12 @@ class SQLExecutor(BasicPipeline): def __init__( self, engine: Engine, - engine_timeout: float = 30.0, **kwargs, ): self._components = { "data_fetcher": DataFetcher(engine=engine), } - self._configs = { - "engine_timeout": engine_timeout, - } - super().__init__( AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) ) @@ -93,6 +84,5 @@ async def run( "project_id": project_id, "limit": limit, **self._components, - **self._configs, }, ) diff --git a/wren-ai-service/src/pipelines/retrieval/sql_functions.py b/wren-ai-service/src/pipelines/retrieval/sql_functions.py index 1e9f8e2c06..016aa9b1e6 100644 --- a/wren-ai-service/src/pipelines/retrieval/sql_functions.py +++ b/wren-ai-service/src/pipelines/retrieval/sql_functions.py @@ -52,13 +52,11 @@ def __repr__(self): async def get_functions( engine: WrenIbis, data_source: str, - engine_timeout: float = 30.0, ) -> List[SqlFunction]: async with aiohttp.ClientSession() as session: func_list = await engine.get_func_list( session=session, data_source=data_source, - timeout=engine_timeout, ) return [ @@ -86,7 +84,6 @@ def __init__( self, engine: Engine, document_store_provider: DocumentStoreProvider, - engine_timeout: float = 30.0, ttl: int = 60 * 60 * 24, **kwargs, ) -> None: @@ -99,10 +96,6 @@ def __init__( "ttl_cache": self._cache, } - self._configs = { - "engine_timeout": engine_timeout, - } - super().__init__( AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) ) @@ -127,7 +120,6 @@ async def run( "data_source": _data_source, "project_id": project_id, **self._components, - **self._configs, } result = await self._pipe.execute(["cache"], inputs=input) return result["cache"] diff --git a/wren-ai-service/src/providers/engine/wren.py b/wren-ai-service/src/providers/engine/wren.py index 081341dfd3..c04aea5fd5 100644 --- a/wren-ai-service/src/providers/engine/wren.py +++ b/wren-ai-service/src/providers/engine/wren.py @@ -7,6 +7,7 @@ import aiohttp import orjson +from src.config import settings from src.core.engine import Engine, remove_limit_statement from src.providers.loader import provider @@ -28,7 +29,7 @@ async def execute_sql( session: aiohttp.ClientSession, project_id: str | None = None, dry_run: bool = True, - timeout: float = 30.0, + timeout: float = settings.engine_timeout, limit: int = 500, **kwargs, ) -> Tuple[bool, Optional[Dict[str, Any]], Optional[Dict[str, Any]]]: @@ -128,7 +129,7 @@ async def execute_sql( sql: str, session: aiohttp.ClientSession, dry_run: bool = True, - timeout: float = 30.0, + timeout: float = settings.engine_timeout, limit: int = 500, **kwargs, ) -> Tuple[bool, Optional[Dict[str, Any]]]: @@ -178,7 +179,7 @@ async def dry_plan( session: aiohttp.ClientSession, sql: str, data_source: str, - timeout: float = 30.0, + timeout: float = settings.engine_timeout, allow_fallback: bool = True, **kwargs, ) -> Tuple[bool, str]: @@ -212,7 +213,7 @@ async def get_func_list( self, session: aiohttp.ClientSession, data_source: str, - timeout: float = 30.0, + timeout: float = settings.engine_timeout, ) -> list[str]: api_endpoint = f"{self._endpoint}/v3/connector/{data_source}/functions" try: @@ -247,7 +248,7 @@ async def execute_sql( sql: str, session: aiohttp.ClientSession, dry_run: bool = True, - timeout: float = 30.0, + timeout: float = settings.engine_timeout, limit: int = 500, **kwargs, ) -> Tuple[bool, Optional[Dict[str, Any]], Optional[str]]: diff --git a/wren-ai-service/src/providers/llm/__init__.py b/wren-ai-service/src/providers/llm/__init__.py index f597d1b2ab..c3f2a47b3f 100644 --- a/wren-ai-service/src/providers/llm/__init__.py +++ b/wren-ai-service/src/providers/llm/__init__.py @@ -1,11 +1,134 @@ import logging -from typing import Any, List - -from haystack.dataclasses import ChatMessage, StreamingChunk +from dataclasses import asdict, dataclass, field +from enum import Enum +from typing import Any, Dict, List, Optional logger = logging.getLogger("wren-ai-service") +class ChatRole(str, Enum): + """Enumeration representing the roles within a chat.""" + + ASSISTANT = "assistant" + USER = "user" + SYSTEM = "system" + FUNCTION = "function" + + +@dataclass +class ChatMessage: + """ + Represents a message in a LLM chat conversation. + + :param content: The text content of the message. + :param role: The role of the entity sending the message. + :param name: The name of the function being called (only applicable for role FUNCTION). + :param meta: Additional metadata associated with the message. + """ + + content: str + role: ChatRole + name: Optional[str] = None + image_url: Optional[str] = None + meta: Dict[str, Any] = field(default_factory=dict, hash=False) + + def is_from(self, role: ChatRole) -> bool: + """ + Check if the message is from a specific role. + + :param role: The role to check against. + :returns: True if the message is from the specified role, False otherwise. + """ + return self.role == role + + @classmethod + def from_assistant( + cls, content: str, meta: Optional[Dict[str, Any]] = None + ) -> "ChatMessage": + """ + Create a message from the assistant. + + :param content: The text content of the message. + :param meta: Additional metadata associated with the message. + :returns: A new ChatMessage instance. + """ + return cls( + content, ChatRole.ASSISTANT, name=None, image_url=None, meta=meta or {} + ) + + @classmethod + def from_user(cls, content: str, image_url: Optional[str] = None) -> "ChatMessage": + """ + Create a message from the user. + + :param content: The text content of the message. + :returns: A new ChatMessage instance. + """ + return cls(content, ChatRole.USER, name=None, image_url=image_url) + + @classmethod + def from_system(cls, content: str) -> "ChatMessage": + """ + Create a message from the system. + + :param content: The text content of the message. + :returns: A new ChatMessage instance. + """ + return cls(content, ChatRole.SYSTEM, name=None, image_url=None) + + @classmethod + def from_function(cls, content: str, name: str) -> "ChatMessage": + """ + Create a message from a function call. + + :param content: The text content of the message. + :param name: The name of the function being called. + :returns: A new ChatMessage instance. + """ + return cls(content, ChatRole.FUNCTION, name=name, image_url=None, meta=None) + + def to_dict(self) -> Dict[str, Any]: + """ + Converts ChatMessage into a dictionary. + + :returns: + Serialized version of the object. + """ + data = asdict(self) + data["role"] = self.role.value + + return data + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "ChatMessage": + """ + Creates a new ChatMessage object from a dictionary. + + :param data: + The dictionary to build the ChatMessage object. + :returns: + The created object. + """ + data["role"] = ChatRole(data["role"]) + + return cls(**data) + + +@dataclass +class StreamingChunk: + """ + The StreamingChunk class encapsulates a segment of streamed content along with associated metadata. + + This structure facilitates the handling and processing of streamed data in a systematic manner. + + :param content: The content of the message chunk as a string. + :param meta: A dictionary containing metadata related to the message chunk. + """ + + content: str + meta: Dict[str, Any] = field(default_factory=dict, hash=False) + + def build_message(completion: Any, choice: Any) -> ChatMessage: """ Converts the response from the OpenAI API to a ChatMessage. @@ -96,3 +219,34 @@ def build_chunk(chunk: Any) -> StreamingChunk: } ) return chunk_message + + +def convert_message_to_openai_format(message: ChatMessage) -> Dict[str, str]: + """ + Convert a message to the format expected by OpenAI's Chat API. + + See the [API reference](https://platform.openai.com/docs/api-reference/chat/create) for details. + + :returns: A dictionary with the following key: + - `role` + - `content` + - `name` (optional) + """ + openai_msg = {"role": message.role.value} + + if message.content and hasattr(message, "image_url") and message.image_url: + openai_msg["content"] = [ + {"type": "text", "text": message.content}, + {"type": "image_url", "image_url": {"url": message.image_url}}, + ] + elif message.content: + openai_msg["content"] = message.content + elif hasattr(message, "image_url") and message.image_url: + openai_msg["content"] = [ + {"type": "image_url", "image_url": {"url": message.image_url}} + ] + + if hasattr(message, "name") and message.name: + openai_msg["name"] = message.name + + return openai_msg diff --git a/wren-ai-service/src/providers/llm/litellm.py b/wren-ai-service/src/providers/llm/litellm.py index f7f193ba15..3748a945da 100644 --- a/wren-ai-service/src/providers/llm/litellm.py +++ b/wren-ai-service/src/providers/llm/litellm.py @@ -3,19 +3,17 @@ import backoff import openai -from haystack.components.generators.openai_utils import ( - _convert_message_to_openai_format, -) -from haystack.dataclasses import ChatMessage, StreamingChunk -from litellm import acompletion -from litellm.router import Router +from litellm import Router, acompletion from src.core.provider import LLMProvider from src.providers.llm import ( + ChatMessage, + StreamingChunk, build_chunk, build_message, check_finish_reason, connect_chunks, + convert_message_to_openai_format, ) from src.providers.loader import provider from src.utils import extract_braces_content, remove_trailing_slash @@ -75,11 +73,12 @@ def get_generator( @backoff.on_exception(backoff.expo, openai.APIError, max_time=60.0, max_tries=3) async def _run( prompt: str, + image_url: Optional[str] = None, history_messages: Optional[List[ChatMessage]] = None, generation_kwargs: Optional[Dict[str, Any]] = None, query_id: Optional[str] = None, ): - message = ChatMessage.from_user(prompt) + message = ChatMessage.from_user(prompt, image_url) if system_prompt: messages = [ChatMessage.from_system(system_prompt)] if history_messages: @@ -92,7 +91,7 @@ async def _run( messages = [message] openai_formatted_messages = [ - _convert_message_to_openai_format(message) for message in messages + convert_message_to_openai_format(message) for message in messages ] generation_kwargs = { @@ -100,16 +99,16 @@ async def _run( **(generation_kwargs or {}), } - allowed_params = ( - ["reasoning_effort"] if self._model.startswith("gpt-5") else None - ) + allowed_openai_params = generation_kwargs.get( + "allowed_openai_params", [] + ) + (["reasoning_effort"] if self._model.startswith("gpt-5") else []) if self._has_fallbacks: completion = await self._router.acompletion( model=self._model, messages=openai_formatted_messages, stream=streaming_callback is not None, - allowed_openai_params=allowed_params, + allowed_openai_params=allowed_openai_params, mock_testing_fallbacks=self._enable_fallback_testing, **generation_kwargs, ) @@ -122,7 +121,7 @@ async def _run( timeout=self._timeout, messages=openai_formatted_messages, stream=streaming_callback is not None, - allowed_openai_params=allowed_params, + allowed_openai_params=allowed_openai_params, **generation_kwargs, ) diff --git a/wren-ai-service/src/web/v1/routers/__init__.py b/wren-ai-service/src/web/v1/routers/__init__.py index dfaffa0360..cf20756992 100644 --- a/wren-ai-service/src/web/v1/routers/__init__.py +++ b/wren-ai-service/src/web/v1/routers/__init__.py @@ -2,6 +2,7 @@ from src.web.v1.routers import ( ask, + ask_feedbacks, chart, chart_adjustment, instructions, @@ -17,6 +18,7 @@ router = APIRouter() router.include_router(ask.router) +router.include_router(ask_feedbacks.router) router.include_router(question_recommendation.router) router.include_router(relationship_recommendation.router) router.include_router(semantics_description.router) diff --git a/wren-ai-service/src/web/v1/routers/ask.py b/wren-ai-service/src/web/v1/routers/ask.py index 061c85c357..84bc5dbb37 100644 --- a/wren-ai-service/src/web/v1/routers/ask.py +++ b/wren-ai-service/src/web/v1/routers/ask.py @@ -11,16 +11,10 @@ get_service_metadata, ) from src.web.v1.services.ask import ( - AskFeedbackRequest, - AskFeedbackResponse, - AskFeedbackResultRequest, - AskFeedbackResultResponse, AskRequest, AskResponse, AskResultRequest, AskResultResponse, - StopAskFeedbackRequest, - StopAskFeedbackResponse, StopAskRequest, StopAskResponse, ) @@ -83,51 +77,3 @@ async def get_ask_streaming_result( service_container.ask_service.get_ask_streaming_result(query_id), media_type="text/event-stream", ) - - -@router.post("/ask-feedbacks") -async def ask_feedback( - ask_feedback_request: AskFeedbackRequest, - background_tasks: BackgroundTasks, - service_container: ServiceContainer = Depends(get_service_container), - service_metadata: ServiceMetadata = Depends(get_service_metadata), -) -> AskFeedbackResponse: - query_id = str(uuid.uuid4()) - ask_feedback_request.query_id = query_id - service_container.ask_service._ask_feedback_results[ - query_id - ] = AskFeedbackResultResponse( - status="searching", - ) - - background_tasks.add_task( - service_container.ask_service.ask_feedback, - ask_feedback_request, - service_metadata=asdict(service_metadata), - ) - return AskFeedbackResponse(query_id=query_id) - - -@router.patch("/ask-feedbacks/{query_id}") -async def stop_ask_feedback( - query_id: str, - stop_ask_feedback_request: StopAskFeedbackRequest, - background_tasks: BackgroundTasks, - service_container: ServiceContainer = Depends(get_service_container), -) -> StopAskFeedbackResponse: - stop_ask_feedback_request.query_id = query_id - background_tasks.add_task( - service_container.ask_service.stop_ask_feedback, - stop_ask_feedback_request, - ) - return StopAskFeedbackResponse(query_id=query_id) - - -@router.get("/ask-feedbacks/{query_id}") -async def get_ask_feedback_result( - query_id: str, - service_container: ServiceContainer = Depends(get_service_container), -) -> AskFeedbackResultResponse: - return service_container.ask_service.get_ask_feedback_result( - AskFeedbackResultRequest(query_id=query_id) - ) diff --git a/wren-ai-service/src/web/v1/routers/ask_feedbacks.py b/wren-ai-service/src/web/v1/routers/ask_feedbacks.py new file mode 100644 index 0000000000..0c095acc7c --- /dev/null +++ b/wren-ai-service/src/web/v1/routers/ask_feedbacks.py @@ -0,0 +1,69 @@ +import uuid +from dataclasses import asdict + +from fastapi import APIRouter, BackgroundTasks, Depends + +from src.globals import ( + ServiceContainer, + ServiceMetadata, + get_service_container, + get_service_metadata, +) +from src.web.v1.services.ask_feedback import ( + AskFeedbackRequest, + AskFeedbackResponse, + AskFeedbackResultRequest, + AskFeedbackResultResponse, + StopAskFeedbackRequest, + StopAskFeedbackResponse, +) + +router = APIRouter() + + +@router.post("/ask-feedbacks") +async def ask_feedback( + ask_feedback_request: AskFeedbackRequest, + background_tasks: BackgroundTasks, + service_container: ServiceContainer = Depends(get_service_container), + service_metadata: ServiceMetadata = Depends(get_service_metadata), +) -> AskFeedbackResponse: + query_id = str(uuid.uuid4()) + ask_feedback_request.query_id = query_id + service_container.ask_feedback_service._ask_feedback_results[ + query_id + ] = AskFeedbackResultResponse( + status="searching", + ) + + background_tasks.add_task( + service_container.ask_feedback_service.ask_feedback, + ask_feedback_request, + service_metadata=asdict(service_metadata), + ) + return AskFeedbackResponse(query_id=query_id) + + +@router.patch("/ask-feedbacks/{query_id}") +async def stop_ask_feedback( + query_id: str, + stop_ask_feedback_request: StopAskFeedbackRequest, + background_tasks: BackgroundTasks, + service_container: ServiceContainer = Depends(get_service_container), +) -> StopAskFeedbackResponse: + stop_ask_feedback_request.query_id = query_id + background_tasks.add_task( + service_container.ask_feedback_service.stop_ask_feedback, + stop_ask_feedback_request, + ) + return StopAskFeedbackResponse(query_id=query_id) + + +@router.get("/ask-feedbacks/{query_id}") +async def get_ask_feedback_result( + query_id: str, + service_container: ServiceContainer = Depends(get_service_container), +) -> AskFeedbackResultResponse: + return service_container.ask_feedback_service.get_ask_feedback_result( + AskFeedbackResultRequest(query_id=query_id) + ) diff --git a/wren-ai-service/src/web/v1/services/__init__.py b/wren-ai-service/src/web/v1/services/__init__.py index 525975f886..5c76e4f3fb 100644 --- a/wren-ai-service/src/web/v1/services/__init__.py +++ b/wren-ai-service/src/web/v1/services/__init__.py @@ -76,6 +76,7 @@ def query_id(self, query_id: str): # Put the services imports here to avoid circular imports and make them accessible directly to the rest of packages from .ask import AskService # noqa: E402 +from .ask_feedback import AskFeedbackService # noqa: E402 from .chart import ChartService # noqa: E402 from .chart_adjustment import ChartAdjustmentService # noqa: E402 from .instructions import InstructionsService # noqa: E402 @@ -90,6 +91,7 @@ def query_id(self, query_id: str): __all__ = [ "AskService", + "AskFeedbackService", "ChartService", "ChartAdjustmentService", "QuestionRecommendation", diff --git a/wren-ai-service/src/web/v1/services/ask.py b/wren-ai-service/src/web/v1/services/ask.py index 9843307011..5e36c8a8ad 100644 --- a/wren-ai-service/src/web/v1/services/ask.py +++ b/wren-ai-service/src/web/v1/services/ask.py @@ -94,47 +94,6 @@ class AskResultResponse(_AskResultResponse): ] = Field(None, exclude=True) -# POST /v1/ask-feedbacks -class AskFeedbackRequest(BaseRequest): - question: str - tables: List[str] - sql_generation_reasoning: str - sql: str - - -class AskFeedbackResponse(BaseModel): - query_id: str - - -# PATCH /v1/ask-feedbacks/{query_id} -class StopAskFeedbackRequest(BaseRequest): - status: Literal["stopped"] - - -class StopAskFeedbackResponse(BaseModel): - query_id: str - - -# GET /v1/ask-feedbacks/{query_id} -class AskFeedbackResultRequest(BaseModel): - query_id: str - - -class AskFeedbackResultResponse(BaseModel): - status: Literal[ - "searching", - "generating", - "correcting", - "finished", - "failed", - "stopped", - ] - invalid_sql: Optional[str] = None - error: Optional[AskError] = None - response: Optional[List[AskResult]] = None - trace_id: Optional[str] = None - - class AskService: def __init__( self, @@ -152,9 +111,6 @@ def __init__( self._ask_results: Dict[str, AskResultResponse] = TTLCache( maxsize=maxsize, ttl=ttl ) - self._ask_feedback_results: Dict[str, AskFeedbackResultResponse] = TTLCache( - maxsize=maxsize, ttl=ttl - ) self._allow_sql_generation_reasoning = allow_sql_generation_reasoning self._allow_sql_functions_retrieval = allow_sql_functions_retrieval self._allow_intent_classification = allow_intent_classification @@ -697,220 +653,3 @@ async def get_ask_streaming_result( data=SSEEvent.SSEEventMessage(message=chunk), ) yield event.serialize() - - @observe(name="Ask Feedback") - @trace_metadata - async def ask_feedback( - self, - ask_feedback_request: AskFeedbackRequest, - **kwargs, - ): - trace_id = kwargs.get("trace_id") - results = { - "ask_feedback_result": {}, - "metadata": { - "error_type": "", - "error_message": "", - "request_from": ask_feedback_request.request_from, - }, - } - - query_id = ask_feedback_request.query_id - api_results = [] - error_message = None - invalid_sql = None - - try: - if not self._is_stopped(query_id, self._ask_feedback_results): - self._ask_feedback_results[query_id] = AskFeedbackResultResponse( - status="searching", - trace_id=trace_id, - ) - - ( - retrieval_task, - sql_samples_task, - instructions_task, - ) = await asyncio.gather( - self._pipelines["db_schema_retrieval"].run( - tables=ask_feedback_request.tables, - project_id=ask_feedback_request.project_id, - ), - self._pipelines["sql_pairs_retrieval"].run( - query=ask_feedback_request.question, - project_id=ask_feedback_request.project_id, - ), - self._pipelines["instructions_retrieval"].run( - query=ask_feedback_request.question, - project_id=ask_feedback_request.project_id, - scope="sql", - ), - ) - - if self._allow_sql_functions_retrieval: - sql_functions = await self._pipelines[ - "sql_functions_retrieval" - ].run( - project_id=ask_feedback_request.project_id, - ) - else: - sql_functions = [] - - # Extract results from completed tasks - _retrieval_result = retrieval_task.get( - "construct_retrieval_results", {} - ) - has_calculated_field = _retrieval_result.get( - "has_calculated_field", False - ) - has_metric = _retrieval_result.get("has_metric", False) - has_json_field = _retrieval_result.get("has_json_field", False) - documents = _retrieval_result.get("retrieval_results", []) - table_ddls = [document.get("table_ddl") for document in documents] - sql_samples = sql_samples_task["formatted_output"].get("documents", []) - instructions = instructions_task["formatted_output"].get( - "documents", [] - ) - - if not self._is_stopped(query_id, self._ask_feedback_results): - self._ask_feedback_results[query_id] = AskFeedbackResultResponse( - status="generating", - trace_id=trace_id, - ) - - text_to_sql_generation_results = await self._pipelines[ - "sql_regeneration" - ].run( - contexts=table_ddls, - sql_generation_reasoning=ask_feedback_request.sql_generation_reasoning, - sql=ask_feedback_request.sql, - project_id=ask_feedback_request.project_id, - sql_samples=sql_samples, - instructions=instructions, - has_calculated_field=has_calculated_field, - has_metric=has_metric, - has_json_field=has_json_field, - sql_functions=sql_functions, - ) - - if sql_valid_result := text_to_sql_generation_results["post_process"][ - "valid_generation_result" - ]: - api_results = [ - AskResult( - **{ - "sql": sql_valid_result.get("sql"), - "type": "llm", - } - ) - ] - elif failed_dry_run_result := text_to_sql_generation_results[ - "post_process" - ]["invalid_generation_result"]: - if failed_dry_run_result["type"] != "TIME_OUT": - self._ask_feedback_results[ - query_id - ] = AskFeedbackResultResponse( - status="correcting", - trace_id=trace_id, - ) - sql_correction_results = await self._pipelines[ - "sql_correction" - ].run( - contexts=[], - invalid_generation_result=failed_dry_run_result, - project_id=ask_feedback_request.project_id, - ) - - if valid_generation_result := sql_correction_results[ - "post_process" - ]["valid_generation_result"]: - api_results = [ - AskResult( - **{ - "sql": valid_generation_result.get("sql"), - "type": "llm", - } - ) - ] - elif failed_dry_run_result := sql_correction_results[ - "post_process" - ]["invalid_generation_result"]: - invalid_sql = failed_dry_run_result["sql"] - error_message = failed_dry_run_result["error"] - else: - invalid_sql = failed_dry_run_result["sql"] - error_message = failed_dry_run_result["error"] - - if api_results: - if not self._is_stopped(query_id, self._ask_feedback_results): - self._ask_feedback_results[query_id] = AskFeedbackResultResponse( - status="finished", - response=api_results, - trace_id=trace_id, - ) - results["ask_feedback_result"] = api_results - else: - logger.exception("ask feedback pipeline - NO_RELEVANT_SQL") - if not self._is_stopped(query_id, self._ask_feedback_results): - self._ask_feedback_results[query_id] = AskFeedbackResultResponse( - status="failed", - error=AskError( - code="NO_RELEVANT_SQL", - message=error_message or "No relevant SQL", - ), - invalid_sql=invalid_sql, - trace_id=trace_id, - ) - results["metadata"]["error_type"] = "NO_RELEVANT_SQL" - results["metadata"]["error_message"] = error_message - - return results - - except Exception as e: - logger.exception(f"ask feedback pipeline - OTHERS: {e}") - - self._ask_feedback_results[query_id] = AskFeedbackResultResponse( - status="failed", - error=AskError( - code="OTHERS", - message=str(e), - ), - trace_id=trace_id, - ) - - results["metadata"]["error_type"] = "OTHERS" - results["metadata"]["error_message"] = str(e) - return results - - def stop_ask_feedback( - self, - stop_ask_feedback_request: StopAskFeedbackRequest, - ): - self._ask_feedback_results[ - stop_ask_feedback_request.query_id - ] = AskFeedbackResultResponse( - status="stopped", - ) - - def get_ask_feedback_result( - self, - ask_feedback_result_request: AskFeedbackResultRequest, - ) -> AskFeedbackResultResponse: - if ( - result := self._ask_feedback_results.get( - ask_feedback_result_request.query_id - ) - ) is None: - logger.exception( - f"ask feedback pipeline - OTHERS: {ask_feedback_result_request.query_id} is not found" - ) - return AskFeedbackResultResponse( - status="failed", - error=AskError( - code="OTHERS", - message=f"{ask_feedback_result_request.query_id} is not found", - ), - ) - - return result diff --git a/wren-ai-service/src/web/v1/services/ask_feedback.py b/wren-ai-service/src/web/v1/services/ask_feedback.py new file mode 100644 index 0000000000..2e9cc227df --- /dev/null +++ b/wren-ai-service/src/web/v1/services/ask_feedback.py @@ -0,0 +1,295 @@ +import asyncio +import logging +from typing import Dict, List, Literal, Optional + +from cachetools import TTLCache +from langfuse.decorators import observe +from pydantic import BaseModel + +from src.core.pipeline import BasicPipeline +from src.utils import trace_metadata +from src.web.v1.services import BaseRequest +from src.web.v1.services.ask import AskError, AskResult + +logger = logging.getLogger("wren-ai-service") + + +# POST /v1/ask-feedbacks +class AskFeedbackRequest(BaseRequest): + question: str + tables: List[str] + sql_generation_reasoning: str + sql: str + + +class AskFeedbackResponse(BaseModel): + query_id: str + + +# PATCH /v1/ask-feedbacks/{query_id} +class StopAskFeedbackRequest(BaseRequest): + status: Literal["stopped"] + + +class StopAskFeedbackResponse(BaseModel): + query_id: str + + +# GET /v1/ask-feedbacks/{query_id} +class AskFeedbackResultRequest(BaseModel): + query_id: str + + +class AskFeedbackResultResponse(BaseModel): + status: Literal[ + "searching", + "generating", + "correcting", + "finished", + "failed", + "stopped", + ] + invalid_sql: Optional[str] = None + error: Optional[AskError] = None + response: Optional[List[AskResult]] = None + trace_id: Optional[str] = None + + +class AskFeedbackService: + def __init__( + self, + pipelines: Dict[str, BasicPipeline], + allow_sql_functions_retrieval: bool = True, + maxsize: int = 1_000_000, + ttl: int = 120, + ): + self._pipelines = pipelines + self._ask_feedback_results: Dict[str, AskFeedbackResultResponse] = TTLCache( + maxsize=maxsize, ttl=ttl + ) + self._allow_sql_functions_retrieval = allow_sql_functions_retrieval + + def _is_stopped(self, query_id: str, container: dict): + if ( + result := container.get(query_id) + ) is not None and result.status == "stopped": + return True + + return False + + @observe(name="Ask Feedback") + @trace_metadata + async def ask_feedback( + self, + ask_feedback_request: AskFeedbackRequest, + **kwargs, + ): + trace_id = kwargs.get("trace_id") + results = { + "ask_feedback_result": {}, + "metadata": { + "error_type": "", + "error_message": "", + "request_from": ask_feedback_request.request_from, + }, + } + + query_id = ask_feedback_request.query_id + api_results = [] + error_message = None + invalid_sql = None + + try: + if not self._is_stopped(query_id, self._ask_feedback_results): + self._ask_feedback_results[query_id] = AskFeedbackResultResponse( + status="searching", + trace_id=trace_id, + ) + + ( + retrieval_task, + sql_samples_task, + instructions_task, + ) = await asyncio.gather( + self._pipelines["db_schema_retrieval"].run( + tables=ask_feedback_request.tables, + project_id=ask_feedback_request.project_id, + ), + self._pipelines["sql_pairs_retrieval"].run( + query=ask_feedback_request.question, + project_id=ask_feedback_request.project_id, + ), + self._pipelines["instructions_retrieval"].run( + query=ask_feedback_request.question, + project_id=ask_feedback_request.project_id, + scope="sql", + ), + ) + + if self._allow_sql_functions_retrieval: + sql_functions = await self._pipelines[ + "sql_functions_retrieval" + ].run( + project_id=ask_feedback_request.project_id, + ) + else: + sql_functions = [] + + # Extract results from completed tasks + _retrieval_result = retrieval_task.get( + "construct_retrieval_results", {} + ) + has_calculated_field = _retrieval_result.get( + "has_calculated_field", False + ) + has_metric = _retrieval_result.get("has_metric", False) + has_json_field = _retrieval_result.get("has_json_field", False) + documents = _retrieval_result.get("retrieval_results", []) + table_ddls = [document.get("table_ddl") for document in documents] + sql_samples = sql_samples_task["formatted_output"].get("documents", []) + instructions = instructions_task["formatted_output"].get( + "documents", [] + ) + + if not self._is_stopped(query_id, self._ask_feedback_results): + self._ask_feedback_results[query_id] = AskFeedbackResultResponse( + status="generating", + trace_id=trace_id, + ) + + text_to_sql_generation_results = await self._pipelines[ + "sql_regeneration" + ].run( + contexts=table_ddls, + sql_generation_reasoning=ask_feedback_request.sql_generation_reasoning, + sql=ask_feedback_request.sql, + project_id=ask_feedback_request.project_id, + sql_samples=sql_samples, + instructions=instructions, + has_calculated_field=has_calculated_field, + has_metric=has_metric, + has_json_field=has_json_field, + sql_functions=sql_functions, + ) + + if sql_valid_result := text_to_sql_generation_results["post_process"][ + "valid_generation_result" + ]: + api_results = [ + AskResult( + **{ + "sql": sql_valid_result.get("sql"), + "type": "llm", + } + ) + ] + elif failed_dry_run_result := text_to_sql_generation_results[ + "post_process" + ]["invalid_generation_result"]: + if failed_dry_run_result["type"] != "TIME_OUT": + self._ask_feedback_results[ + query_id + ] = AskFeedbackResultResponse( + status="correcting", + trace_id=trace_id, + ) + sql_correction_results = await self._pipelines[ + "sql_correction" + ].run( + contexts=[], + invalid_generation_result=failed_dry_run_result, + project_id=ask_feedback_request.project_id, + ) + + if valid_generation_result := sql_correction_results[ + "post_process" + ]["valid_generation_result"]: + api_results = [ + AskResult( + **{ + "sql": valid_generation_result.get("sql"), + "type": "llm", + } + ) + ] + elif failed_dry_run_result := sql_correction_results[ + "post_process" + ]["invalid_generation_result"]: + invalid_sql = failed_dry_run_result["sql"] + error_message = failed_dry_run_result["error"] + else: + invalid_sql = failed_dry_run_result["sql"] + error_message = failed_dry_run_result["error"] + + if api_results: + if not self._is_stopped(query_id, self._ask_feedback_results): + self._ask_feedback_results[query_id] = AskFeedbackResultResponse( + status="finished", + response=api_results, + trace_id=trace_id, + ) + results["ask_feedback_result"] = api_results + else: + logger.exception("ask feedback pipeline - NO_RELEVANT_SQL") + if not self._is_stopped(query_id, self._ask_feedback_results): + self._ask_feedback_results[query_id] = AskFeedbackResultResponse( + status="failed", + error=AskError( + code="NO_RELEVANT_SQL", + message=error_message or "No relevant SQL", + ), + invalid_sql=invalid_sql, + trace_id=trace_id, + ) + results["metadata"]["error_type"] = "NO_RELEVANT_SQL" + results["metadata"]["error_message"] = error_message + + return results + + except Exception as e: + logger.exception(f"ask feedback pipeline - OTHERS: {e}") + + self._ask_feedback_results[query_id] = AskFeedbackResultResponse( + status="failed", + error=AskError( + code="OTHERS", + message=str(e), + ), + trace_id=trace_id, + ) + + results["metadata"]["error_type"] = "OTHERS" + results["metadata"]["error_message"] = str(e) + return results + + def stop_ask_feedback( + self, + stop_ask_feedback_request: StopAskFeedbackRequest, + ): + self._ask_feedback_results[ + stop_ask_feedback_request.query_id + ] = AskFeedbackResultResponse( + status="stopped", + ) + + def get_ask_feedback_result( + self, + ask_feedback_result_request: AskFeedbackResultRequest, + ) -> AskFeedbackResultResponse: + if ( + result := self._ask_feedback_results.get( + ask_feedback_result_request.query_id + ) + ) is None: + logger.exception( + f"ask feedback pipeline - OTHERS: {ask_feedback_result_request.query_id} is not found" + ) + return AskFeedbackResultResponse( + status="failed", + error=AskError( + code="OTHERS", + message=f"{ask_feedback_result_request.query_id} is not found", + ), + ) + + return result diff --git a/wren-ai-service/src/web/v1/services/question_recommendation.py b/wren-ai-service/src/web/v1/services/question_recommendation.py index 8aea7222de..24df1e3e62 100644 --- a/wren-ai-service/src/web/v1/services/question_recommendation.py +++ b/wren-ai-service/src/web/v1/services/question_recommendation.py @@ -164,17 +164,17 @@ class Request(BaseRequest): regenerate: bool = False allow_data_preview: bool = True - async def _recommend(self, request: dict, input: Request): + async def _recommend(self, request: dict): resp = await self._pipelines["question_recommendation"].run(**request) questions = resp.get("normalized", {}).get("questions", []) validation_tasks = [ self._validate_question( question, - input.event_id, - input.max_questions, - input.max_categories, - input.project_id, - input.allow_data_preview, + request["event_id"], + request["max_questions"], + request["max_categories"], + project_id=request["project_id"], + allow_data_preview=request["allow_data_preview"], ) for question in questions ] @@ -190,15 +190,27 @@ async def recommend(self, input: Request, **kwargs) -> Event: trace_id = kwargs.get("trace_id") try: + mdl = orjson.loads(input.mdl) + retrieval_result = await self._pipelines["db_schema_retrieval"].run( + tables=[model["name"] for model in mdl["models"]], + project_id=input.project_id, + ) + _retrieval_result = retrieval_result.get("construct_retrieval_results", {}) + documents = _retrieval_result.get("retrieval_results", []) + table_ddls = [document.get("table_ddl") for document in documents] + request = { - "mdl": orjson.loads(input.mdl), + "contexts": table_ddls, "previous_questions": input.previous_questions, "language": input.configurations.language, "max_questions": input.max_questions, "max_categories": input.max_categories, + "project_id": input.project_id, + "event_id": input.event_id, + "allow_data_preview": input.allow_data_preview, } - await self._recommend(request, input) + await self._recommend(request) resource = self._cache[input.event_id] resource.trace_id = trace_id @@ -223,7 +235,6 @@ async def recommend(self, input: Request, **kwargs) -> Event: "categories": categories, "max_categories": len(categories), }, - input, ) self._cache[input.event_id].status = "finished" diff --git a/wren-ai-service/tools/config/config.example.yaml b/wren-ai-service/tools/config/config.example.yaml index 8dc7dd6594..f2c01e95a5 100644 --- a/wren-ai-service/tools/config/config.example.yaml +++ b/wren-ai-service/tools/config/config.example.yaml @@ -117,13 +117,8 @@ pipes: llm: litellm_llm.default - name: relationship_recommendation llm: litellm_llm.default - engine: wren_ui - name: question_recommendation llm: litellm_llm.default - - name: question_recommendation_db_schema_retrieval - llm: litellm_llm.default - embedder: litellm_embedder.default - document_store: qdrant - name: question_recommendation_sql_generation llm: litellm_llm.default engine: wren_ui diff --git a/wren-ai-service/tools/config/config.full.yaml b/wren-ai-service/tools/config/config.full.yaml index 2f2a954dcf..bb688dcfd8 100644 --- a/wren-ai-service/tools/config/config.full.yaml +++ b/wren-ai-service/tools/config/config.full.yaml @@ -117,13 +117,8 @@ pipes: llm: litellm_llm.default - name: relationship_recommendation llm: litellm_llm.default - engine: wren_ui - name: question_recommendation llm: litellm_llm.default - - name: question_recommendation_db_schema_retrieval - llm: litellm_llm.default - embedder: litellm_embedder.default - document_store: qdrant - name: question_recommendation_sql_generation llm: litellm_llm.default engine: wren_ui