chore(wren-ai-service): minor updates (ai-env-changed) #1893
Conversation
Walkthrough
Removes per-call engine timeouts across generation/retrieval pipelines; refactors question recommendation to use contexts/documents instead of MDL; introduces AskFeedbackService with dedicated router; centralizes shared pipelines in the service container; updates LLM provider types and litellm integration; adjusts engine default timeouts via settings; tidies prompts in semantics/relationship modules; updates configs.
Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Client
    participant Router as ask_feedbacks router
    participant Service as AskFeedbackService
    participant Pipelines as Shared Pipelines
    Client->>Router: POST /ask-feedbacks (AskFeedbackRequest)
    Router->>Service: ask_feedback(request)
    Service->>Service: seed result (status=searching)
    Service->>Pipelines: db_schema/sql_pairs/instructions retrieval (parallel)
    Pipelines-->>Service: retrieval results
    Service->>Pipelines: sql regeneration + post-process
    Pipelines-->>Service: results or correction required
    Service->>Service: update status (generating/correcting/finished/failed)
    Client-->>Router: GET /ask-feedbacks/{query_id}
    Router->>Service: get_ask_feedback_result(query_id)
    Service-->>Router: current result
    Router-->>Client: AskFeedbackResultResponse
```

```mermaid
sequenceDiagram
    participant Client
    participant QRService as QuestionRecommendation Service
    participant Retrieval as DB Schema Retrieval
    participant QRPipe as QuestionRecommendation Pipeline
    Client->>QRService: recommend(input with mdl)
    QRService->>Retrieval: fetch schema from mdl.models
    Retrieval-->>QRService: table DDLs
    QRService->>QRPipe: run(contexts=table DDLs, previous_questions, ...)
    QRPipe-->>QRService: categories + questions
    QRService->>QRPipe: (if needed) run with categories to fill
    QRPipe-->>QRService: final questions
    QRService-->>Client: response
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Possibly related PRs
Suggested labels
Suggested reviewers
Poem
Actionable comments posted: 8
🔭 Outside diff range comments (5)
wren-ai-service/src/providers/engine/wren.py (3)
127-176: Fix return type annotation and unify return shape in WrenIbis.execute_sql
The function returns three values throughout, but the annotation declares a 2-tuple. Additionally, the timeout except branch returns a string (not a metadata dict), diverging from other branches.
- Align the return type to a 3-tuple with a metadata dict.
- Normalize the timeout branch to return a metadata dict (with error_message and correlation_id) for consistency.
Apply this diff:
```diff
-    ) -> Tuple[bool, Optional[Dict[str, Any]]]:
+    ) -> Tuple[bool, Optional[Dict[str, Any]], Optional[Dict[str, Any]]]:
@@
-        except asyncio.TimeoutError:
-            return False, None, f"Request timed out: {timeout} seconds"
+        except asyncio.TimeoutError:
+            return (
+                False,
+                None,
+                {
+                    "error_message": f"Request timed out: {timeout} seconds",
+                    "correlation_id": "",
+                },
+            )
```
246-297: Correct WrenEngine.execute_sql return annotation and timeout return shape
The function returns (bool, payload, metadata_dict) but the annotation declares the third element as an Optional[str]. The timeout handler also returns a plain string instead of a structured metadata dict.
- Update the annotation to a 3-tuple where the third element is a metadata dict.
- Normalize the timeout branch to return a consistent metadata dict.
```diff
-    ) -> Tuple[bool, Optional[Dict[str, Any]], Optional[str]]:
+    ) -> Tuple[bool, Optional[Dict[str, Any]], Optional[Dict[str, Any]]]:
@@
-        except asyncio.TimeoutError:
-            return False, None, f"Request timed out: {timeout} seconds"
+        except asyncio.TimeoutError:
+            return (
+                False,
+                None,
+                {
+                    "error_message": f"Request timed out: {timeout} seconds",
+                    "correlation_id": "",
+                },
+            )
```
212-221: Use aiohttp.ClientTimeout for get_func_list to avoid passing a raw float
Other methods already pass aiohttp.ClientTimeout(total=timeout). Passing a float to the request timeout parameter is not the intended API and may not behave as expected.
```diff
-            async with session.get(api_endpoint, timeout=timeout) as response:
+            async with session.get(
+                api_endpoint, timeout=aiohttp.ClientTimeout(total=timeout)
+            ) as response:
```
wren-ai-service/src/providers/llm/__init__.py (1)
224-253: Incorrect return type for convert_message_to_openai_format; doc nit; tighten vision handling
- The function returns nested structures (lists/dicts), not Dict[str, str].
- Minor doc grammar: “keys” not “key”.
- Optional: restrict image content formatting to user messages; sending vision payloads for system/assistant is typically invalid.
Apply this diff to fix the type and doc nit:
```diff
-def convert_message_to_openai_format(message: ChatMessage) -> Dict[str, str]:
+def convert_message_to_openai_format(message: ChatMessage) -> Dict[str, Any]:
@@
-    :returns: A dictionary with the following key:
+    :returns: A dictionary with the following keys:
```
Optionally, tighten image handling to user role only and remove redundant hasattr checks:
```diff
-    if message.content and hasattr(message, "image_url") and message.image_url:
+    if message.role == ChatRole.USER and message.image_url and message.content:
         openai_msg["content"] = [
             {"type": "text", "text": message.content},
             {"type": "image_url", "image_url": {"url": message.image_url}},
         ]
-    elif message.content:
+    elif message.content:
         openai_msg["content"] = message.content
-    elif hasattr(message, "image_url") and message.image_url:
+    elif message.role == ChatRole.USER and message.image_url:
         openai_msg["content"] = [
             {"type": "image_url", "image_url": {"url": message.image_url}}
         ]
-    if hasattr(message, "name") and message.name:
+    if message.name:
         openai_msg["name"] = message.name
```
wren-ai-service/src/providers/llm/litellm.py (1)
141-143: Fix streaming_callback signature mismatch in LitellmLLMProvider
litellm.get_generator currently types streaming_callback as Callable[[StreamingChunk], None] but calls it with (chunk_delta, query_id). Pipeline callbacks are defined as def _streaming_callback(self, chunk, query_id): — update the provider signature to accept the second (optional) query_id.
Files to change:
- wren-ai-service/src/providers/llm/litellm.py — update get_generator parameter annotation (around the streaming_callback definition and its call at lines ~141-143).
Suggested diff:
```diff
-        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
+        streaming_callback: Optional[Callable[[StreamingChunk, Optional[str]], None]] = None,
```
After this change, the invocation streaming_callback(chunk_delta, query_id) will match the annotated signature. Please also scan other LLM provider implementations, if any are added, to keep signatures consistent.
🧹 Nitpick comments (9)
wren-ai-service/src/providers/engine/wren.py (1)
32-35: Optional: Avoid binding a config value as a default argument at definition time
Default arguments are evaluated at definition time. If settings.engine_timeout is changed at runtime (e.g., hot-reload), methods won't pick up the new default. Consider defaulting to None and reading settings at call time.
```diff
-        timeout: float = settings.engine_timeout,
+        timeout: float | None = None,
@@
-            async with session.post(
+            timeout = timeout if timeout is not None else settings.engine_timeout
+            async with session.post(
```
Note: Apply the same pattern to WrenIbis.execute_sql, dry_plan, get_func_list and WrenEngine.execute_sql for consistency.
wren-ai-service/src/web/v1/services/ask_feedback.py (2)
265-274: Stopping does not cancel in-flight work; consider cooperative cancellation
Marking the status as "stopped" prevents subsequent updates, but ongoing awaits continue running. For large workloads, this wastes resources.
Approach (a minimal sketch follows this list):
- Track asyncio.Tasks per query_id (e.g., self._tasks[query_id] = current task).
- In stop_ask_feedback, cancel and handle asyncio.CancelledError in ask_feedback to exit early.
- Periodically check _is_stopped between major awaits if full cancellation is not feasible.
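A minimal, self-contained sketch of this task-tracking pattern (the class and field names below only mirror AskFeedbackService for illustration; they are not the current implementation):

```python
import asyncio
from typing import Dict, Optional


class FeedbackServiceSketch:
    """Illustrative stand-in for AskFeedbackService."""

    def __init__(self) -> None:
        self._results: Dict[str, str] = {}
        self._tasks: Dict[str, Optional[asyncio.Task]] = {}  # hypothetical per-query registry

    async def ask_feedback(self, query_id: str) -> None:
        # Register the running task so stop_ask_feedback can cancel it.
        self._tasks[query_id] = asyncio.current_task()
        try:
            self._results[query_id] = "searching"
            await asyncio.sleep(10)  # stands in for the retrieval / regeneration awaits
            self._results[query_id] = "finished"
        except asyncio.CancelledError:
            # Exit early and record the stopped state instead of propagating.
            self._results[query_id] = "stopped"
        finally:
            self._tasks.pop(query_id, None)

    def stop_ask_feedback(self, query_id: str) -> None:
        task = self._tasks.get(query_id)
        if task is not None:
            task.cancel()  # awaits inside ask_feedback raise CancelledError


async def main() -> None:
    service = FeedbackServiceSketch()
    task = asyncio.create_task(service.ask_feedback("q1"))
    await asyncio.sleep(0.1)  # let the pipeline start
    service.stop_ask_feedback("q1")
    await task
    print(service._results["q1"])  # -> "stopped"


asyncio.run(main())
```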
279-287: Use warning instead of exception level for "not found" query_id
This isn't an exceptional scenario; logging it at exception level may create noisy stack traces.
```diff
-        logger.exception(
+        logger.warning(
             f"ask feedback pipeline - OTHERS: {ask_feedback_result_request.query_id} is not found"
         )
```
wren-ai-service/src/pipelines/generation/semantics_description.py (1)
20-90: Prompt cleanup: fix example placeholders and JSON hinting to reduce invalid JSON
- The example shows “description for column_1” repeated for columns 2 and 3. Adjust to match the column names to avoid confusing the model.
- In step 2, the example uses single quotes around JSON keys/values. Since the output must be valid JSON (and you enforce a json_schema), prefer double quotes to reduce invalid generations.
```diff
-2. **For each `column`**: Ask the user to describe each column's role or significance. Each column's description should be added under its respective `properties` field in the format: `'description': 'user-provided text'`.
+2. **For each `column`**: Ask the user to describe each column's role or significance. Each column's description should be added under its respective `properties` field in the format: `"description": "user-provided text"`.
@@
-            "description": "<description for column_1>"
+            "description": "<description for column_1>"
@@
-            "description": "<description for column_1>"
+            "description": "<description for column_2>"
@@
-            "description": "<description for column_1>"
+            "description": "<description for column_3>"
```
wren-ai-service/src/globals.py (1)
51-82: Shared pipelines: confirm reentrancy/concurrency-safety and consider configuration parity
You've centralized core pipelines (retrieval/indexing/executor). This is great for consistency and reuse. Two follow-ups:
- Concurrency: These pipeline instances (and their internal components/caches) will now be shared across multiple services and requests. Please confirm all involved components (retrievers, generators, caches, AsyncDriver) are reentrant and safe under concurrent access.
- Configuration parity:
`SqlFunctions` is instantiated with defaults (e.g., TTL). If previously configurable via `settings`, consider threading those through to keep behavior parity.
Would you like me to generate a quick script to scan for stateful members (e.g., in-memory caches) across these pipeline classes to help assess reentrancy?
wren-ai-service/src/web/v1/services/question_recommendation.py (1)
168-176: Harden handling of malformed normalized payloads
Even with the upstream fix to always return a dict, hardening here avoids surprises if upstream behavior changes.
Apply this diff:
```diff
-        resp = await self._pipelines["question_recommendation"].run(**request)
-        questions = resp.get("normalized", {}).get("questions", [])
+        resp = await self._pipelines["question_recommendation"].run(**request)
+        normalized = resp.get("normalized") or {}
+        if not isinstance(normalized, dict):
+            normalized = {}
+        questions = normalized.get("questions", [])
```
wren-ai-service/src/providers/llm/__init__.py (3)
59-68: Document the image_url parameter in from_user
The docstring omits the `image_url` parameter.
```diff
 def from_user(cls, content: str, image_url: Optional[str] = None) -> "ChatMessage":
     """
     Create a message from the user.
-    :param content: The text content of the message.
+    :param content: The text content of the message.
+    :param image_url: Optional image URL to include with the user message.
     :returns: A new ChatMessage instance.
     """
```
102-115: Avoid mutating caller’s dict in from_dict
`from_dict` mutates the input dict (`data["role"] = ...`). Prefer copying to prevent side effects on callers that reuse the input.
```diff
-        data["role"] = ChatRole(data["role"])
-
-        return cls(**data)
+        # Avoid mutating caller's dict
+        return cls(**{**data, "role": ChatRole(data["role"])})
```
9-16: Consider adding TOOL role for parity with modern Chat API tool messages
If you plan to support tool outputs/messages, include `TOOL = "tool"` for completeness (OpenAI returns tool role messages). Safe to add now; adopt where needed later.
```diff
 class ChatRole(str, Enum):
@@
     ASSISTANT = "assistant"
     USER = "user"
     SYSTEM = "system"
     FUNCTION = "function"
+    TOOL = "tool"
```
Would you like me to propagate TOOL handling in conversion and constructors?
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (25)
- deployment/kustomizations/base/cm.yaml (0 hunks)
- docker/config.example.yaml (0 hunks)
- wren-ai-service/src/globals.py (9 hunks)
- wren-ai-service/src/pipelines/generation/followup_sql_generation.py (0 hunks)
- wren-ai-service/src/pipelines/generation/question_recommendation.py (6 hunks)
- wren-ai-service/src/pipelines/generation/relationship_recommendation.py (2 hunks)
- wren-ai-service/src/pipelines/generation/semantics_description.py (1 hunk)
- wren-ai-service/src/pipelines/generation/sql_correction.py (0 hunks)
- wren-ai-service/src/pipelines/generation/sql_generation.py (0 hunks)
- wren-ai-service/src/pipelines/generation/sql_regeneration.py (0 hunks)
- wren-ai-service/src/pipelines/generation/utils/sql.py (0 hunks)
- wren-ai-service/src/pipelines/retrieval/sql_executor.py (0 hunks)
- wren-ai-service/src/pipelines/retrieval/sql_functions.py (0 hunks)
- wren-ai-service/src/providers/engine/wren.py (6 hunks)
- wren-ai-service/src/providers/llm/__init__.py (2 hunks)
- wren-ai-service/src/providers/llm/litellm.py (4 hunks)
- wren-ai-service/src/web/v1/routers/__init__.py (2 hunks)
- wren-ai-service/src/web/v1/routers/ask.py (0 hunks)
- wren-ai-service/src/web/v1/routers/ask_feedbacks.py (1 hunk)
- wren-ai-service/src/web/v1/services/__init__.py (2 hunks)
- wren-ai-service/src/web/v1/services/ask.py (0 hunks)
- wren-ai-service/src/web/v1/services/ask_feedback.py (1 hunk)
- wren-ai-service/src/web/v1/services/question_recommendation.py (2 hunks)
- wren-ai-service/tools/config/config.example.yaml (0 hunks)
- wren-ai-service/tools/config/config.full.yaml (0 hunks)
💤 Files with no reviewable changes (13)
- docker/config.example.yaml
- wren-ai-service/src/pipelines/generation/followup_sql_generation.py
- wren-ai-service/src/pipelines/retrieval/sql_functions.py
- wren-ai-service/src/web/v1/services/ask.py
- wren-ai-service/tools/config/config.full.yaml
- wren-ai-service/src/pipelines/generation/sql_regeneration.py
- deployment/kustomizations/base/cm.yaml
- wren-ai-service/src/pipelines/generation/sql_correction.py
- wren-ai-service/tools/config/config.example.yaml
- wren-ai-service/src/pipelines/generation/sql_generation.py
- wren-ai-service/src/web/v1/routers/ask.py
- wren-ai-service/src/pipelines/retrieval/sql_executor.py
- wren-ai-service/src/pipelines/generation/utils/sql.py
🧰 Additional context used
🧬 Code Graph Analysis (8)
- wren-ai-service/src/web/v1/services/__init__.py (2)
  - wren-ai-service/src/web/v1/routers/ask_feedbacks.py (1)
    - ask_feedback (25-44)
  - wren-ai-service/src/web/v1/services/ask_feedback.py (2)
    - ask_feedback (82-263)
    - AskFeedbackService (58-295)
- wren-ai-service/src/web/v1/services/ask_feedback.py (4)
  - wren-ai-service/src/core/pipeline.py (1)
    - BasicPipeline (14-20)
  - wren-ai-service/src/web/v1/services/__init__.py (3)
    - BaseRequest (58-74)
    - query_id (69-70)
    - query_id (73-74)
  - wren-ai-service/src/web/v1/services/ask.py (1)
    - AskError (55-57)
  - wren-ai-service/src/web/v1/routers/ask_feedbacks.py (3)
    - ask_feedback (25-44)
    - stop_ask_feedback (48-59)
    - get_ask_feedback_result (63-69)
- wren-ai-service/src/web/v1/routers/ask_feedbacks.py (3)
  - wren-ai-service/src/globals.py (4)
    - ServiceContainer (17-30)
    - ServiceMetadata (34-36)
    - get_service_container (259-262)
    - get_service_metadata (314-317)
  - wren-ai-service/src/web/v1/services/ask_feedback.py (9)
    - ask_feedback (82-263)
    - AskFeedbackRequest (18-22)
    - AskFeedbackResponse (25-26)
    - AskFeedbackResultRequest (39-40)
    - AskFeedbackResultResponse (43-55)
    - StopAskFeedbackRequest (30-31)
    - StopAskFeedbackResponse (34-35)
    - stop_ask_feedback (265-273)
    - get_ask_feedback_result (275-295)
  - wren-ai-service/src/web/v1/services/__init__.py (2)
    - query_id (69-70)
    - query_id (73-74)
- wren-ai-service/src/globals.py (11)
  - wren-ai-service/src/web/v1/services/ask_feedback.py (1)
    - AskFeedbackService (58-295)
  - wren-ai-service/src/pipelines/retrieval/instructions.py (2)
    - retrieval (85-105)
    - Instructions (185-228)
  - wren-ai-service/src/pipelines/retrieval/sql_pairs_retrieval.py (2)
    - retrieval (64-83)
    - SqlPairsRetrieval (116-156)
  - wren-ai-service/src/pipelines/indexing/sql_pairs.py (1)
    - SqlPairs (166-227)
  - wren-ai-service/src/pipelines/indexing/instructions.py (1)
    - Instructions (127-180)
  - wren-ai-service/src/pipelines/generation/sql_correction.py (1)
    - SQLCorrection (117-174)
  - wren-ai-service/src/pipelines/retrieval/sql_functions.py (1)
    - SqlFunctions (82-125)
  - wren-ai-service/src/pipelines/retrieval/sql_executor.py (1)
    - SQLExecutor (61-88)
  - wren-ai-service/src/pipelines/generation/sql_generation.py (1)
    - SQLGeneration (148-219)
  - wren-ai-service/src/pipelines/generation/sql_regeneration.py (1)
    - SQLRegeneration (153-206)
  - wren-ai-service/src/web/v1/services/chart.py (1)
    - ChartService (62-193)
- wren-ai-service/src/web/v1/services/question_recommendation.py (1)
  - wren-ai-service/src/pipelines/generation/question_recommendation.py (1)
    - run (256-278)
- wren-ai-service/src/providers/llm/litellm.py (1)
  - wren-ai-service/src/providers/llm/__init__.py (8)
    - ChatMessage (19-114)
    - StreamingChunk (118-129)
    - build_chunk (200-221)
    - build_message (132-154)
    - check_finish_reason (157-179)
    - connect_chunks (182-197)
    - convert_message_to_openai_format (224-252)
    - from_user (60-67)
- wren-ai-service/src/pipelines/generation/relationship_recommendation.py (1)
  - wren-ai-service/src/pipelines/generation/question_recommendation.py (1)
    - normalized (196-210)
- wren-ai-service/src/pipelines/generation/question_recommendation.py (10)
  - wren-ai-service/src/pipelines/generation/relationship_recommendation.py (2)
    - prompt (102-108)
    - run (207-220)
  - wren-ai-service/src/pipelines/generation/semantics_description.py (2)
    - prompt (130-141)
    - run (233-250)
  - wren-ai-service/src/pipelines/generation/misleading_assistance.py (2)
    - prompt (56-75)
    - run (150-171)
  - wren-ai-service/src/pipelines/generation/user_guide_assistance.py (2)
    - prompt (52-65)
    - run (142-160)
  - wren-ai-service/src/pipelines/generation/data_assistance.py (2)
    - prompt (56-75)
    - run (150-171)
  - wren-ai-service/src/pipelines/generation/sql_generation_reasoning.py (2)
    - prompt (58-76)
    - run (155-176)
  - wren-ai-service/src/pipelines/generation/chart_generation.py (2)
    - prompt (66-85)
    - run (153-175)
  - wren-ai-service/src/pipelines/generation/sql_tables_extraction.py (2)
    - prompt (60-65)
    - run (121-132)
  - wren-ai-service/src/pipelines/generation/intent_classification.py (2)
    - prompt (269-290)
    - run (378-400)
  - wren-ai-service/src/pipelines/common.py (1)
    - clean_up_new_lines (111-112)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: pytest
- GitHub Check: pytest
- GitHub Check: Analyze (javascript-typescript)
- GitHub Check: Analyze (go)
🔇 Additional comments (27)
wren-ai-service/src/providers/engine/wren.py (2)
85-101: Correlation ID extraction inconsistency in WrenUI.execute_sql
On success, correlation_id is read via res_json.get("correlationId", ""), while on error it's read from extensions.other.correlationId. Please verify the GraphQL response shape and standardize extraction logic to avoid missing correlation IDs.
Do success and error responses provide correlationId in the same location? If not, can we harmonize using a helper (e.g., prefer extensions.other.correlationId when available, else fallback to top-level)?
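For illustration, one possible shape for that helper (a sketch; the field locations come from the comment above and the helper name is hypothetical, so verify against actual Wren UI GraphQL responses):

```python
def extract_correlation_id(res_json: dict) -> str:
    """Prefer extensions.other.correlationId when present, else fall back to the top level."""
    other = (res_json.get("extensions") or {}).get("other") or {}
    return other.get("correlationId") or res_json.get("correlationId", "")
```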
261-273: Verify HTTP method and body semantics for WrenEngine endpoints
The code uses GET with a JSON body (json=...). This is uncommon and some servers/proxies may drop bodies on GET. Verify the WrenEngine API expects GET+JSON; otherwise switch to POST.
If the API accepts POST, change to session.post(...) to avoid interoperability issues.
wren-ai-service/src/web/v1/services/ask_feedback.py (2)
80-87: Verify trace_id propagation from trace_metadata
ask_feedback reads trace_id from kwargs. Ensure the trace_metadata decorator injects a "trace_id" kwarg, or update to read from a well-defined context.
If trace_metadata populates a different key (e.g., service_metadata["trace_id"]), adapt the code accordingly to avoid None trace IDs in responses.
175-223: Defensive access to post_process fields to avoid KeyError
Accessing ["post_process"]["valid_generation_result"] and ["invalid_generation_result"] assumes a fixed shape. If upstream pipelines change or fail silently, this can raise KeyError.
Consider using .get with sensible defaults, or validate structure before access. If you want, I can draft a small helper to unpack generation/correction results safely.
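For example, a small helper along these lines could centralize the unpacking (a sketch; the helper name is hypothetical and the key names are taken from the comment above, so check them against the real pipeline output):

```python
from typing import Any, Optional, Tuple


def unpack_post_process(result: dict) -> Tuple[Optional[Any], Optional[Any]]:
    """Return (valid_generation_result, invalid_generation_result), or (None, None)
    when the post_process payload is missing or malformed."""
    post_process = result.get("post_process")
    if not isinstance(post_process, dict):
        return None, None
    return (
        post_process.get("valid_generation_result"),
        post_process.get("invalid_generation_result"),
    )
```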
wren-ai-service/src/web/v1/routers/__init__.py (1)
21-21: LGTM: ask-feedback router integration
Router registration is correct and ordering is consistent with other routes.
wren-ai-service/src/web/v1/services/__init__.py (1)
79-95: LGTM: AskFeedbackService exported in services package
Public export and import look good; aligns with the new router and DI wiring.
wren-ai-service/src/globals.py (9)
19-19: Wiring AskFeedbackService into the container looks correct
The new `ask_feedback_service` field is properly added to `ServiceContainer` and later instantiated in `create_service_container`.
104-106: Good reuse for indexing pipelines
Wiring the shared `sql_pairs` and `instructions` indexing pipelines into `SemanticsPreparationService` reduces duplication and ensures consistent indexing behavior across services.
128-149: AskService wiring leverages shared retrieval and correction pipelines
Using the shared `db_schema_retrieval`, `sql_pairs_retrieval`, `instructions_retrieval`, `sql_correction`, and `sql_functions_retrieval` pipelines is a solid consolidation that should simplify maintenance.
172-180: Chart services: good consolidation around a shared SQLExecutor
Centralizing `sql_executor` avoids duplicate engine wiring and should reduce resource usage.
Also applies to: 183-189
214-221: QuestionRecommendation: retrieval/function reuse aligns with contexts-based prompt
The wiring for `db_schema_retrieval`, `sql_pairs_retrieval`, `instructions_retrieval`, and `sql_functions_retrieval` matches the updated contexts-driven question recommendation pipeline.
227-227: SqlPairsService: indexing pipeline reuse looks good
Passing the shared `_sql_pair_indexing_pipeline` keeps indexing consistent.
241-241: InstructionsService: indexing reuse looks good
Same consolidation benefit as SQL pairs.
250-252: SqlCorrectionService: good reuse of shared retrieval and correction pipelines
Consolidation should help with behavior parity across services that do SQL correction.
158-171: Incorrect — AskFeedbackService does not require adding "sql_generation"
I checked wren-ai-service/src/web/v1/services/ask_feedback.py: the service calls self._pipelines["sql_regeneration"] (and "sql_correction"), but never self._pipelines["sql_generation"]. Therefore the original suggestion to add an "sql_generation" mapping for AskFeedbackService is not needed.
Relevant locations:
- wren-ai-service/src/web/v1/services/ask_feedback.py — uses "sql_regeneration" and "sql_correction" (no "sql_generation").
- wren-ai-service/src/web/v1/services/question_recommendation.py:115 — does call self._pipelines["sql_generation"].
- wren-ai-service/src/globals.py — already defines "sql_generation" elsewhere.
Action: ignore the suggested diff for AskFeedbackService; no change required here.
Likely an incorrect or invalid review comment.
wren-ai-service/src/pipelines/generation/question_recommendation.py (2)
152-157: Template integration for documents is aligned with contexts-based flow
The Jinja block for documents looks correct and should render the schema contexts as intended.
255-279: Run method aligns with contexts-driven inputs
Switching to `contexts` and passing them as `documents` matches the prompt signature change and the service layer's new flow.
wren-ai-service/src/web/v1/services/question_recommendation.py (3)
167-178: Good decoupling of _recommend from the Pydantic Request
Switching `_recommend` to accept a plain dict simplifies reuse across the two passes.
202-211: Contexts construction is correct and consistent with the pipeline changes
Passing `contexts` based on `table_ddls` aligns with the new documents/contexts-driven prompt.
213-231: Two-pass flow is sound; status transition logic is clear
The regenerate pass runs only when needed and sets status appropriately. With the cache seeding fix above, this flow should work reliably.
wren-ai-service/src/pipelines/generation/relationship_recommendation.py (2)
21-72: Top-level prompts moved to module scope; Engine removed — good simplification
Centralizing `system_prompt` and `user_prompt_template` and removing the Engine dependency aligns with the broader engine-agnostic refactor and reduces configuration surface.
137-147: Validation simplified to type filtering — matches the prompt contract
Filtering by `RelationType.is_include` is sufficient now that the model is constrained by JSON schema.
wren-ai-service/src/web/v1/routers/ask_feedbacks.py (3)
39-44: Background task wiring passes metadata correctly
Scheduling `ask_feedback` with serialized `service_metadata` looks good; this keeps the background context minimal and serializable.
47-59: Stop endpoint is straightforward and consistent with the service API
Setting `query_id` from the path and dispatching a background task to stop aligns with the service contract.
62-70: Result polling endpoint is minimal and correct
Delegates to the service method and returns the Pydantic model, which is appropriate for a router layer.
wren-ai-service/src/providers/llm/litellm.py (2)
76-82: Image URL plumbing LGTM
Plumbing `image_url` through `_run` and `ChatMessage.from_user` aligns with the new formatter and enables vision use-cases.
93-96: Formatter swap to convert_message_to_openai_format LGTM
Replacing the previous haystack formatter with the local `convert_message_to_openai_format` is consistent and centralizes message shaping.
```python
## Start of Pipeline
@observe(capture_input=False)
def prompt(
    previous_questions: list[str],
    documents: list,
    language: str,
    max_questions: int,
    max_categories: int,
    prompt_builder: PromptBuilder,
) -> dict:
    """
    If previous_questions is provided, the MDL is omitted to allow the LLM to focus on
    generating recommendations based on the question history. This helps provide more
    contextually relevant questions that build on previous questions.
    """

    _prompt = prompt_builder.run(
        documents=documents,
        previous_questions=previous_questions,
        language=language,
        max_questions=max_questions,
        max_categories=max_categories,
    )
    return {"prompt": clean_up_new_lines(_prompt.get("prompt"))}
```
Categories are ignored in the prompt; wire them through to the PromptBuilder
The prompt node does not accept/pass categories, so the second pass with category constraints has no effect on the generated prompt. This will degrade the intended fill-up logic.
Apply this diff to accept and forward categories:
```diff
 @observe(capture_input=False)
 def prompt(
-    previous_questions: list[str],
-    documents: list,
+    previous_questions: list[str],
+    documents: list,
+    categories: list[str],
     language: str,
     max_questions: int,
     max_categories: int,
     prompt_builder: PromptBuilder,
 ) -> dict:
@@
-    _prompt = prompt_builder.run(
-        documents=documents,
-        previous_questions=previous_questions,
-        language=language,
-        max_questions=max_questions,
-        max_categories=max_categories,
-    )
+    _prompt = prompt_builder.run(
+        documents=documents,
+        previous_questions=previous_questions,
+        categories=categories,
+        language=language,
+        max_questions=max_questions,
+        max_categories=max_categories,
+    )
     return {"prompt": clean_up_new_lines(_prompt.get("prompt"))}
```
🤖 Prompt for AI Agents
In wren-ai-service/src/pipelines/generation/question_recommendation.py around
lines 163 to 186, the prompt function currently ignores categories so
category-constrained second-pass prompts have no effect; modify the function
signature to accept a categories parameter (e.g., categories: list[str] | None),
include it in the docstring/typing, and pass categories through to
prompt_builder.run alongside documents, previous_questions, language,
max_questions, and max_categories; ensure the returned prompt remains cleaned
via clean_up_new_lines and that any callers of this node are updated to provide
categories or handle its absence.
```python
@observe(capture_input=False)
def normalized(generate: dict) -> dict:
    def wrapper(text: str) -> list:
        text = text.replace("\n", " ")
        text = " ".join(text.split())
        try:
            text_list = orjson.loads(text.strip())
            return text_list
        except orjson.JSONDecodeError as e:
            logger.error(f"Error decoding JSON: {e}")
            return []  # Return an empty list if JSON decoding fails

    reply = generate.get("replies")[0]  # Expecting only one reply
    normalized = wrapper(reply)

    return normalized
```
normalized() returns a list on JSON error; this breaks callers expecting a dict
_recommend() reads resp.get("normalized", {}).get("questions", []). Returning a list here (on decode error) will raise at runtime because list has no .get. Ensure normalized always returns a dict.
Apply this diff:
```diff
 def normalized(generate: dict) -> dict:
-    def wrapper(text: str) -> list:
+    def wrapper(text: str) -> dict:
         text = text.replace("\n", " ")
         text = " ".join(text.split())
         try:
-            text_list = orjson.loads(text.strip())
-            return text_list
+            text_dict = orjson.loads(text.strip())
+            return text_dict if isinstance(text_dict, dict) else {}
         except orjson.JSONDecodeError as e:
             logger.error(f"Error decoding JSON: {e}")
-            return []  # Return an empty list if JSON decoding fails
+            return {}  # Return an empty dict if JSON decoding fails
@@
-    normalized = wrapper(reply)
-
-    return normalized
+    normalized = wrapper(reply)
+    return normalized
```
Optional hardening:
- Guard for empty `replies` to avoid `IndexError` (see the sketch below).
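One way that guard could look inside normalized() (a sketch that reuses the existing wrapper and logger from the snippet above; the {"questions": []} fallback assumes callers only read the questions key):

```python
    replies = generate.get("replies") or []
    if not replies:
        logger.error("question recommendation: no replies returned from generation")
        return {"questions": []}  # same shape callers read via .get("questions", [])
    reply = replies[0]  # Expecting only one reply
    return wrapper(reply)
```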
🤖 Prompt for AI Agents
wren-ai-service/src/pipelines/generation/question_recommendation.py around lines
195 to 210: the normalized() function can return a list when JSON decoding fails
which breaks callers expecting a dict; change wrapper and normalized to always
return a dict shape (e.g., {"questions": []} or parsed dict) so callers can
safely call .get, and add a guard that verify generate.get("replies") exists and
is non-empty before indexing (return {"questions": []} if replies missing or
empty); on JSON decode or unexpected parsed types, log the error and return an
appropriate empty dict structure.
```python
@classmethod
def from_function(cls, content: str, name: str) -> "ChatMessage":
    """
    Create a message from a function call.
    :param content: The text content of the message.
    :param name: The name of the function being called.
    :returns: A new ChatMessage instance.
    """
    return cls(content, ChatRole.FUNCTION, name=name, image_url=None, meta=None)
```
from_function sets meta=None (breaks downstream .meta.update(...))
from_function passes meta=None, overriding the dataclass’s default_factory and causing ChatMessage.meta to be None. Callers that assume a dict (e.g., .meta.update(...)) will crash.
Apply this diff to preserve the dict default:
```diff
-        return cls(content, ChatRole.FUNCTION, name=name, image_url=None, meta=None)
+        return cls(content, ChatRole.FUNCTION, name=name, image_url=None)
```
🤖 Prompt for AI Agents
In wren-ai-service/src/providers/llm/__init__.py around lines 79 to 89,
from_function explicitly sets meta=None which overrides the dataclass
default_factory and results in ChatMessage.meta being None; change the
constructor call to omit the meta argument so the dataclass default (an empty
dict) is preserved (e.g., remove meta=None from the return call), ensuring
callers can safely call .meta.update(...).
```python
allowed_openai_params = generation_kwargs.get(
    "allowed_openai_params", []
) + (["reasoning_effort"] if self._model.startswith("gpt-5") else [])
```
Possible “multiple values for keyword argument 'allowed_openai_params'”
You extract allowed_openai_params from generation_kwargs but still pass **generation_kwargs, which may still contain the same key, causing a TypeError. Pop it before forwarding.
```diff
-        allowed_openai_params = generation_kwargs.get(
-            "allowed_openai_params", []
-        ) + (["reasoning_effort"] if self._model.startswith("gpt-5") else [])
+        allowed_openai_params = generation_kwargs.pop("allowed_openai_params", [])
+        if self._model.startswith("gpt-5"):
+            allowed_openai_params = [*allowed_openai_params, "reasoning_effort"]
```
+ allowed_openai_params = [*allowed_openai_params, "reasoning_effort"]Note: The same duplication risk exists for other explicitly-passed kwargs like stream. Consider popping stream as well if you allow callers to set it.
🤖 Prompt for AI Agents
In wren-ai-service/src/providers/llm/litellm.py around lines 102 to 105, you
build allowed_openai_params from generation_kwargs but still forward
**generation_kwargs which can contain the same key and cause a TypeError; remove
the key(s) from generation_kwargs before forwarding by popping
allowed_openai_params (and other explicitly set keys like stream if applicable)
then use the popped/constructed values when calling the downstream function so
no duplicate keyword arguments are passed.
```python
query_id = str(uuid.uuid4())
ask_feedback_request.query_id = query_id
service_container.ask_feedback_service._ask_feedback_results[
    query_id
] = AskFeedbackResultResponse(
    status="searching",
)

background_tasks.add_task(
    service_container.ask_feedback_service.ask_feedback,
    ask_feedback_request,
    service_metadata=asdict(service_metadata),
)
return AskFeedbackResponse(query_id=query_id)
```
🛠️ Refactor suggestion
Avoid mutating service internals from the router; add a service API to initialize results
Directly writing to _ask_feedback_results couples the router to the service’s internal storage. Prefer exposing a method on AskFeedbackService to initialize a result (status="searching") and return the query_id.
Example change (service method signature, not a diff here since it spans other files):
- In AskFeedbackService, add: `def init_result(self, query_id: str) -> None`
- Replace the direct write with: `service_container.ask_feedback_service.init_result(query_id)`
This improves encapsulation and eases future refactors (e.g., if storage moves out of process).
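A sketch of what that could look like (init_result follows the naming suggested above and is not part of the current code):

```python
# wren-ai-service/src/web/v1/services/ask_feedback.py (sketch)
class AskFeedbackService:
    ...

    def init_result(self, query_id: str) -> None:
        """Seed a result entry so pollers immediately see a 'searching' status."""
        self._ask_feedback_results[query_id] = AskFeedbackResultResponse(
            status="searching",
        )
```

The router would then call service_container.ask_feedback_service.init_result(query_id) before scheduling the background task, keeping _ask_feedback_results private to the service.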
🤖 Prompt for AI Agents
In wren-ai-service/src/web/v1/routers/ask_feedbacks.py around lines 31 to 44,
the router currently mutates the service's internal _ask_feedback_results dict;
instead add and call a service API to initialize the result: add an
AskFeedbackService.init_result(self, query_id: str) method that creates and
stores AskFeedbackResultResponse(status="searching") under the given query_id
(encapsulating storage), then replace the direct assignment in the router with a
call to service_container.ask_feedback_service.init_result(query_id) before
scheduling the background task and returning the
AskFeedbackResponse(query_id=query_id).
```python
        maxsize: int = 1_000_000,
        ttl: int = 120,
    ):
        self._pipelines = pipelines
        self._ask_feedback_results: Dict[str, AskFeedbackResultResponse] = TTLCache(
            maxsize=maxsize, ttl=ttl
        )
        self._allow_sql_functions_retrieval = allow_sql_functions_retrieval
```
🛠️ Refactor suggestion
Reconsider TTLCache defaults to prevent premature eviction and memory spikes
- maxsize=1_000_000 can cause significant memory usage.
- ttl=120 seconds risks evicting in-flight results for longer-running queries, leading to “...is not found” while still processing.
Consider:
- Making both parameters configurable (via settings).
- Increasing TTL to accommodate worst-case pipeline durations.
- Using a smaller default maxsize (e.g., 50_000) and/or per-tenant buckets.
Example:
```diff
+from src.config import settings
@@
-        allow_sql_functions_retrieval: bool = True,
-        maxsize: int = 1_000_000,
-        ttl: int = 120,
+        allow_sql_functions_retrieval: bool = True,
+        maxsize: int = getattr(settings, "ask_feedback_cache_maxsize", 50_000),
+        ttl: int = getattr(settings, "ask_feedback_cache_ttl", 600),
```
+ ttl: int = getattr(settings, "ask_feedback_cache_ttl", 600),📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| maxsize: int = 1_000_000, | |
| ttl: int = 120, | |
| ): | |
| self._pipelines = pipelines | |
| self._ask_feedback_results: Dict[str, AskFeedbackResultResponse] = TTLCache( | |
| maxsize=maxsize, ttl=ttl | |
| ) | |
| self._allow_sql_functions_retrieval = allow_sql_functions_retrieval | |
| from src.config import settings | |
| allow_sql_functions_retrieval: bool = True, | |
| maxsize: int = getattr(settings, "ask_feedback_cache_maxsize", 50_000), | |
| ttl: int = getattr(settings, "ask_feedback_cache_ttl", 600), | |
| ): | |
| self._pipelines = pipelines | |
| self._ask_feedback_results: Dict[str, AskFeedbackResultResponse] = TTLCache( | |
| maxsize=maxsize, ttl=ttl | |
| ) | |
| self._allow_sql_functions_retrieval = allow_sql_functions_retrieval |
```python
# Extract results from completed tasks
_retrieval_result = retrieval_task.get(
    "construct_retrieval_results", {}
)
has_calculated_field = _retrieval_result.get(
    "has_calculated_field", False
)
has_metric = _retrieval_result.get("has_metric", False)
has_json_field = _retrieval_result.get("has_json_field", False)
documents = _retrieval_result.get("retrieval_results", [])
table_ddls = [document.get("table_ddl") for document in documents]
sql_samples = sql_samples_task["formatted_output"].get("documents", [])
instructions = instructions_task["formatted_output"].get(
    "documents", []
)
```
🛠️ Refactor suggestion
Filter out None values when extracting table_ddl documents
documents may not have table_ddl for every entry, yielding None in contexts. Filter to avoid passing None to downstream pipelines.
```diff
-            table_ddls = [document.get("table_ddl") for document in documents]
+            table_ddls = [
+                doc.get("table_ddl")
+                for doc in documents
+                if doc and doc.get("table_ddl")
+            ]
```
🤖 Prompt for AI Agents
In wren-ai-service/src/web/v1/services/ask_feedback.py around lines 138 to 153,
the list comprehension that builds table_ddls may include None values when a
document lacks "table_ddl"; update the extraction to filter out None entries
(e.g., only include document.get("table_ddl") when it is not None) so downstream
pipelines never receive None values.
```python
mdl = orjson.loads(input.mdl)
retrieval_result = await self._pipelines["db_schema_retrieval"].run(
    tables=[model["name"] for model in mdl["models"]],
    project_id=input.project_id,
)
_retrieval_result = retrieval_result.get("construct_retrieval_results", {})
documents = _retrieval_result.get("retrieval_results", [])
table_ddls = [document.get("table_ddl") for document in documents]
```
Seed the cache entry before kicking off validation to avoid KeyError in _validate_question
_validate_question updates self._cache[request_id], but recommend() never initializes it. Seed the entry before _recommend() to prevent KeyError under load.
Apply this diff:
```diff
     try:
         mdl = orjson.loads(input.mdl)
         retrieval_result = await self._pipelines["db_schema_retrieval"].run(
             tables=[model["name"] for model in mdl["models"]],
             project_id=input.project_id,
         )
         _retrieval_result = retrieval_result.get("construct_retrieval_results", {})
         documents = _retrieval_result.get("retrieval_results", [])
         table_ddls = [document.get("table_ddl") for document in documents]
+        # Initialize event state in cache before validation updates occur.
+        self._cache[input.event_id] = self.Event(
+            event_id=input.event_id,
+            request_from=input.request_from,
+        )
```
🤖 Prompt for AI Agents
In wren-ai-service/src/web/v1/services/question_recommendation.py around lines
193 to 201, recommend() can call _validate_question which updates
self._cache[request_id] but recommend() never initializes that key, which can
lead to a KeyError under load; seed the cache entry before any async work or
before calling _recommend() by inserting self._cache[request_id] = {} (or the
minimal dict structure _validate_question expects) immediately after you compute
request_id so subsequent validations can safely update the dict.
Summary by CodeRabbit
New Features
Improvements
Chores