diff --git a/.github/workflows/test_server.yml b/.github/workflows/test_server.yml index f03d020e5..ac96974af 100644 --- a/.github/workflows/test_server.yml +++ b/.github/workflows/test_server.yml @@ -34,7 +34,7 @@ jobs: uv run -m pytest -v tests docker-amd64: - runs-on: linux-amd64 + runs-on: [linux-amd64] concurrency: group: docker-amd64-${{ github.ref }} cancel-in-progress: true @@ -52,12 +52,14 @@ jobs: github-token: ${{ secrets.GHA_CACHE_TOKEN }} docker-arm64: - runs-on: linux-arm64 + runs-on: [linux-arm64] concurrency: group: docker-arm64-${{ github.ref }} cancel-in-progress: true steps: - uses: actions/checkout@v4 + - name: Wait for Docker daemon + run: while ! docker version; do sleep 1; done - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 - name: Build ARM64 diff --git a/server/reflector/llm.py b/server/reflector/llm.py index 50077b4b8..1ea7dd44a 100644 --- a/server/reflector/llm.py +++ b/server/reflector/llm.py @@ -1,42 +1,23 @@ import logging from contextvars import ContextVar -from typing import Generic, Type, TypeVar +from typing import Type, TypeVar from uuid import uuid4 from llama_index.core import Settings -from llama_index.core.output_parsers import PydanticOutputParser +from llama_index.core.prompts import PromptTemplate from llama_index.core.response_synthesizers import TreeSummarize -from llama_index.core.workflow import ( - Context, - Event, - StartEvent, - StopEvent, - Workflow, - step, -) from llama_index.llms.openai_like import OpenAILike from pydantic import BaseModel, ValidationError -from workflows.errors import WorkflowTimeoutError from reflector.utils.retry import retry T = TypeVar("T", bound=BaseModel) -OutputT = TypeVar("OutputT", bound=BaseModel) # Session ID for LiteLLM request grouping - set per processing run llm_session_id: ContextVar[str | None] = ContextVar("llm_session_id", default=None) logger = logging.getLogger(__name__) -STRUCTURED_RESPONSE_PROMPT_TEMPLATE = """ -Based on the following analysis, provide the information in the requested JSON format: - -Analysis: -{analysis} - -{format_instructions} -""" - class LLMParseError(Exception): """Raised when LLM output cannot be parsed after retries.""" @@ -50,157 +31,6 @@ def __init__(self, output_cls: Type[BaseModel], error_msg: str, attempts: int): ) -class ExtractionDone(Event): - """Event emitted when LLM JSON formatting completes.""" - - output: str - - -class ValidationErrorEvent(Event): - """Event emitted when validation fails.""" - - error: str - wrong_output: str - - -class StructuredOutputWorkflow(Workflow, Generic[OutputT]): - """Workflow for structured output extraction with validation retry. - - This workflow handles parse/validation retries only. Network error retries - are handled internally by Settings.llm (OpenAILike max_retries=3). - The caller should NOT wrap this workflow in additional retry logic. - """ - - def __init__( - self, - output_cls: Type[OutputT], - max_retries: int = 3, - **kwargs, - ): - super().__init__(**kwargs) - self.output_cls: Type[OutputT] = output_cls - self.max_retries = max_retries - self.output_parser = PydanticOutputParser(output_cls) - - @step - async def extract( - self, ctx: Context, ev: StartEvent | ValidationErrorEvent - ) -> StopEvent | ExtractionDone: - """Extract structured data from text using two-step LLM process. - - Step 1 (first call only): TreeSummarize generates text analysis - Step 2 (every call): Settings.llm.acomplete formats analysis as JSON - """ - current_retries = await ctx.store.get("retries", default=0) - await ctx.store.set("retries", current_retries + 1) - - if current_retries >= self.max_retries: - last_error = await ctx.store.get("last_error", default=None) - logger.error( - f"Max retries ({self.max_retries}) reached for {self.output_cls.__name__}" - ) - return StopEvent(result={"error": last_error, "attempts": current_retries}) - - if isinstance(ev, StartEvent): - # First call: run TreeSummarize to get analysis, store in context - prompt = ev.get("prompt") - texts = ev.get("texts") - tone_name = ev.get("tone_name") - if not prompt or not isinstance(texts, list): - raise ValueError( - "StartEvent must contain 'prompt' (str) and 'texts' (list)" - ) - - summarizer = TreeSummarize(verbose=False) - analysis = await summarizer.aget_response( - prompt, texts, tone_name=tone_name - ) - await ctx.store.set("analysis", str(analysis)) - reflection = "" - else: - # Retry: reuse analysis from context - analysis = await ctx.store.get("analysis") - if not analysis: - raise RuntimeError("Internal error: analysis not found in context") - - wrong_output = ev.wrong_output - if len(wrong_output) > 2000: - wrong_output = wrong_output[:2000] + "... [truncated]" - reflection = ( - f"\n\nYour previous response could not be parsed:\n{wrong_output}\n\n" - f"Error:\n{ev.error}\n\n" - "Please try again. Return ONLY valid JSON matching the schema above, " - "with no markdown formatting or extra text." - ) - - # Step 2: Format analysis as JSON using LLM completion - format_instructions = self.output_parser.format( - "Please structure the above information in the following JSON format:" - ) - - json_prompt = STRUCTURED_RESPONSE_PROMPT_TEMPLATE.format( - analysis=analysis, - format_instructions=format_instructions + reflection, - ) - - # Network retries handled by OpenAILike (max_retries=3) - # response_format enables grammar-based constrained decoding on backends - # that support it (DMR/llama.cpp, vLLM, Ollama, OpenAI). - response = await Settings.llm.acomplete( - json_prompt, - response_format={ - "type": "json_schema", - "json_schema": { - "name": self.output_cls.__name__, - "schema": self.output_cls.model_json_schema(), - }, - }, - ) - return ExtractionDone(output=response.text) - - @step - async def validate( - self, ctx: Context, ev: ExtractionDone - ) -> StopEvent | ValidationErrorEvent: - """Validate extracted output against Pydantic schema.""" - raw_output = ev.output - retries = await ctx.store.get("retries", default=0) - - try: - parsed = self.output_parser.parse(raw_output) - if retries > 1: - logger.info( - f"LLM parse succeeded on attempt {retries}/{self.max_retries} " - f"for {self.output_cls.__name__}" - ) - return StopEvent(result={"success": parsed}) - - except (ValidationError, ValueError) as e: - error_msg = self._format_error(e, raw_output) - await ctx.store.set("last_error", error_msg) - - logger.error( - f"LLM parse error (attempt {retries}/{self.max_retries}): " - f"{type(e).__name__}: {e}\nRaw response: {raw_output[:500]}" - ) - - return ValidationErrorEvent( - error=error_msg, - wrong_output=raw_output, - ) - - def _format_error(self, error: Exception, raw_output: str) -> str: - """Format error for LLM feedback.""" - if isinstance(error, ValidationError): - error_messages = [] - for err in error.errors(): - field = ".".join(str(loc) for loc in err["loc"]) - error_messages.append(f"- {err['msg']} in field '{field}'") - return "Schema validation errors:\n" + "\n".join(error_messages) - else: - return f"Parse error: {str(error)}" - - class LLM: def __init__( self, settings, temperature: float = 0.4, max_tokens: int | None = None @@ -225,7 +55,7 @@ def _configure_llamaindex(self): api_key=self.api_key, context_window=self.context_window, is_chat_model=True, - is_function_calling_model=False, + is_function_calling_model=True, temperature=self.temperature, max_tokens=self.max_tokens, timeout=self.settings_obj.LLM_REQUEST_TIMEOUT, @@ -248,36 +78,91 @@ async def get_structured_response( tone_name: str | None = None, timeout: int | None = None, ) -> T: - """Get structured output from LLM with validation retry via Workflow.""" - if timeout is None: - timeout = self.settings_obj.LLM_STRUCTURED_RESPONSE_TIMEOUT + """Get structured output from LLM using tool-call with reflection retry. - async def run_workflow(): - workflow = StructuredOutputWorkflow( - output_cls=output_cls, - max_retries=self.settings_obj.LLM_PARSE_MAX_RETRIES + 1, - timeout=timeout, - ) + Uses astructured_predict (function-calling / tool-call mode) for the + first attempt. On ValidationError or parse failure the wrong output + and error are fed back as a reflection prompt and the call is retried + up to LLM_PARSE_MAX_RETRIES times. - result = await workflow.run( - prompt=prompt, - texts=texts, - tone_name=tone_name, + The outer retry() wrapper handles transient network errors with + exponential back-off. + """ + max_retries = self.settings_obj.LLM_PARSE_MAX_RETRIES + + async def _call_with_reflection(): + # Build full prompt: instruction + source texts + if texts: + texts_block = "\n\n".join(texts) + full_prompt = f"{prompt}\n\n{texts_block}" + else: + full_prompt = prompt + + prompt_tmpl = PromptTemplate("{user_prompt}") + last_error: str | None = None + + for attempt in range(1, max_retries + 2): # +2: first try + retries + try: + if attempt == 1: + result = await Settings.llm.astructured_predict( + output_cls, prompt_tmpl, user_prompt=full_prompt + ) + else: + reflection_tmpl = PromptTemplate( + "{user_prompt}\n\n{reflection}" + ) + result = await Settings.llm.astructured_predict( + output_cls, + reflection_tmpl, + user_prompt=full_prompt, + reflection=reflection, + ) + + if attempt > 1: + logger.info( + f"LLM structured_predict succeeded on attempt " + f"{attempt}/{max_retries + 1} for {output_cls.__name__}" + ) + return result + + except (ValidationError, ValueError) as e: + wrong_output = str(e) + if len(wrong_output) > 2000: + wrong_output = wrong_output[:2000] + "... [truncated]" + + last_error = self._format_validation_error(e) + reflection = ( + f"Your previous response could not be parsed.\n\n" + f"Error:\n{last_error}\n\n" + "Please try again and return valid data matching the schema." + ) + + logger.error( + f"LLM parse error (attempt {attempt}/{max_retries + 1}): " + f"{type(e).__name__}: {e}\n" + f"Raw response: {wrong_output[:500]}" + ) + + raise LLMParseError( + output_cls=output_cls, + error_msg=last_error or "Max retries exceeded", + attempts=max_retries + 1, ) - if "error" in result: - error_msg = result["error"] or "Max retries exceeded" - raise LLMParseError( - output_cls=output_cls, - error_msg=error_msg, - attempts=result.get("attempts", 0), - ) - - return result["success"] - - return await retry(run_workflow)( + return await retry(_call_with_reflection)( retry_attempts=3, retry_backoff_interval=1.0, retry_backoff_max=30.0, - retry_ignore_exc_types=(WorkflowTimeoutError,), + retry_ignore_exc_types=(ConnectionError, TimeoutError, OSError), ) + + @staticmethod + def _format_validation_error(error: Exception) -> str: + """Format a validation/parse error for LLM reflection feedback.""" + if isinstance(error, ValidationError): + error_messages = [] + for err in error.errors(): + field = ".".join(str(loc) for loc in err["loc"]) + error_messages.append(f"- {err['msg']} in field '{field}'") + return "Schema validation errors:\n" + "\n".join(error_messages) + return f"Parse error: {str(error)}" diff --git a/server/reflector/processors/transcript_topic_detector.py b/server/reflector/processors/transcript_topic_detector.py index e775e0736..d8ea313be 100644 --- a/server/reflector/processors/transcript_topic_detector.py +++ b/server/reflector/processors/transcript_topic_detector.py @@ -14,10 +14,12 @@ class TopicResponse(BaseModel): title: str = Field( description="A descriptive title for the topic being discussed", validation_alias=AliasChoices("title", "Title"), + min_length=8, ) summary: str = Field( description="A concise 1-2 sentence summary of the discussion", validation_alias=AliasChoices("summary", "Summary"), + min_length=8, ) diff --git a/server/tests/test_llm_retry.py b/server/tests/test_llm_retry.py index 5a43c8c59..5c28ff5f8 100644 --- a/server/tests/test_llm_retry.py +++ b/server/tests/test_llm_retry.py @@ -1,13 +1,11 @@ -"""Tests for LLM parse error recovery using llama-index Workflow""" +"""Tests for LLM structured output with astructured_predict + reflection retry""" -from time import monotonic -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, patch import pytest -from pydantic import BaseModel, Field -from workflows.errors import WorkflowRuntimeError, WorkflowTimeoutError +from pydantic import BaseModel, Field, ValidationError -from reflector.llm import LLM, LLMParseError, StructuredOutputWorkflow +from reflector.llm import LLM, LLMParseError from reflector.utils.retry import RetryException @@ -19,50 +17,42 @@ class TestResponse(BaseModel): confidence: float = Field(description="Confidence score", ge=0, le=1) -def make_completion_response(text: str): - """Create a mock CompletionResponse with .text attribute""" - response = MagicMock() - response.text = text - return response - - class TestLLMParseErrorRecovery: - """Test parse error recovery with Workflow feedback loop""" + """Test parse error recovery with astructured_predict reflection loop""" @pytest.mark.asyncio async def test_parse_error_recovery_with_feedback(self, test_settings): - """Test that parse errors trigger retry with error feedback""" + """Test that parse errors trigger retry with reflection prompt""" llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100) - with ( - patch("reflector.llm.TreeSummarize") as mock_summarize, - patch("reflector.llm.Settings") as mock_settings, - ): - mock_summarizer = MagicMock() - mock_summarize.return_value = mock_summarizer - # TreeSummarize returns plain text analysis (step 1) - mock_summarizer.aget_response = AsyncMock( - return_value="The analysis shows a test with summary and high confidence." - ) + call_count = {"count": 0} - call_count = {"count": 0} - - async def acomplete_handler(prompt, *args, **kwargs): - call_count["count"] += 1 - if call_count["count"] == 1: - # First JSON formatting call returns invalid JSON - return make_completion_response('{"title": "Test"}') - else: - # Second call should have error feedback in prompt - assert "Your previous response could not be parsed:" in prompt - assert '{"title": "Test"}' in prompt - assert "Error:" in prompt - assert "Please try again" in prompt - return make_completion_response( - '{"title": "Test", "summary": "Summary", "confidence": 0.95}' - ) - - mock_settings.llm.acomplete = AsyncMock(side_effect=acomplete_handler) + async def astructured_predict_handler(output_cls, prompt_tmpl, **kwargs): + call_count["count"] += 1 + if call_count["count"] == 1: + # First call: raise ValidationError (missing fields) + raise ValidationError.from_exception_data( + title="TestResponse", + line_errors=[ + { + "type": "missing", + "loc": ("summary",), + "msg": "Field required", + "input": {"title": "Test"}, + } + ], + ) + else: + # Second call: should have reflection in the prompt + assert "reflection" in kwargs + assert "could not be parsed" in kwargs["reflection"] + assert "Error:" in kwargs["reflection"] + return TestResponse(title="Test", summary="Summary", confidence=0.95) + + with patch("reflector.llm.Settings") as mock_settings: + mock_settings.llm.astructured_predict = AsyncMock( + side_effect=astructured_predict_handler + ) result = await llm.get_structured_response( prompt="Test prompt", texts=["Test text"], output_cls=TestResponse @@ -71,8 +61,6 @@ async def acomplete_handler(prompt, *args, **kwargs): assert result.title == "Test" assert result.summary == "Summary" assert result.confidence == 0.95 - # TreeSummarize called once, Settings.llm.acomplete called twice - assert mock_summarizer.aget_response.call_count == 1 assert call_count["count"] == 2 @pytest.mark.asyncio @@ -80,56 +68,61 @@ async def test_max_parse_retry_attempts(self, test_settings): """Test that parse error retry stops after max attempts""" llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100) - with ( - patch("reflector.llm.TreeSummarize") as mock_summarize, - patch("reflector.llm.Settings") as mock_settings, - ): - mock_summarizer = MagicMock() - mock_summarize.return_value = mock_summarizer - mock_summarizer.aget_response = AsyncMock(return_value="Some analysis") - - # Always return invalid JSON from acomplete - mock_settings.llm.acomplete = AsyncMock( - return_value=make_completion_response( - '{"invalid": "missing required fields"}' - ) + # Always raise ValidationError + async def always_fail(output_cls, prompt_tmpl, **kwargs): + raise ValidationError.from_exception_data( + title="TestResponse", + line_errors=[ + { + "type": "missing", + "loc": ("summary",), + "msg": "Field required", + "input": {}, + } + ], ) + with patch("reflector.llm.Settings") as mock_settings: + mock_settings.llm.astructured_predict = AsyncMock(side_effect=always_fail) + with pytest.raises(LLMParseError, match="Failed to parse"): await llm.get_structured_response( prompt="Test prompt", texts=["Test text"], output_cls=TestResponse ) expected_attempts = test_settings.LLM_PARSE_MAX_RETRIES + 1 - # TreeSummarize called once, acomplete called max_retries times - assert mock_summarizer.aget_response.call_count == 1 - assert mock_settings.llm.acomplete.call_count == expected_attempts + assert mock_settings.llm.astructured_predict.call_count == expected_attempts @pytest.mark.asyncio async def test_raw_response_logging_on_parse_error(self, test_settings, caplog): """Test that raw response is logged when parse error occurs""" llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100) + call_count = {"count": 0} + + async def astructured_predict_handler(output_cls, prompt_tmpl, **kwargs): + call_count["count"] += 1 + if call_count["count"] == 1: + raise ValidationError.from_exception_data( + title="TestResponse", + line_errors=[ + { + "type": "missing", + "loc": ("summary",), + "msg": "Field required", + "input": {"title": "Test"}, + } + ], + ) + return TestResponse(title="Test", summary="Summary", confidence=0.95) + with ( - patch("reflector.llm.TreeSummarize") as mock_summarize, patch("reflector.llm.Settings") as mock_settings, caplog.at_level("ERROR"), ): - mock_summarizer = MagicMock() - mock_summarize.return_value = mock_summarizer - mock_summarizer.aget_response = AsyncMock(return_value="Some analysis") - - call_count = {"count": 0} - - async def acomplete_handler(*args, **kwargs): - call_count["count"] += 1 - if call_count["count"] == 1: - return make_completion_response('{"title": "Test"}') # Invalid - return make_completion_response( - '{"title": "Test", "summary": "Summary", "confidence": 0.95}' - ) - - mock_settings.llm.acomplete = AsyncMock(side_effect=acomplete_handler) + mock_settings.llm.astructured_predict = AsyncMock( + side_effect=astructured_predict_handler + ) result = await llm.get_structured_response( prompt="Test prompt", texts=["Test text"], output_cls=TestResponse @@ -143,35 +136,45 @@ async def acomplete_handler(*args, **kwargs): @pytest.mark.asyncio async def test_multiple_validation_errors_in_feedback(self, test_settings): - """Test that validation errors are included in feedback""" + """Test that validation errors are included in reflection feedback""" llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100) - with ( - patch("reflector.llm.TreeSummarize") as mock_summarize, - patch("reflector.llm.Settings") as mock_settings, - ): - mock_summarizer = MagicMock() - mock_summarize.return_value = mock_summarizer - mock_summarizer.aget_response = AsyncMock(return_value="Some analysis") - - call_count = {"count": 0} - - async def acomplete_handler(prompt, *args, **kwargs): - call_count["count"] += 1 - if call_count["count"] == 1: - # Missing title and summary - return make_completion_response('{"confidence": 0.5}') - else: - # Should have schema validation errors in prompt - assert ( - "Schema validation errors" in prompt - or "error" in prompt.lower() - ) - return make_completion_response( - '{"title": "Test", "summary": "Summary", "confidence": 0.95}' - ) - - mock_settings.llm.acomplete = AsyncMock(side_effect=acomplete_handler) + call_count = {"count": 0} + + async def astructured_predict_handler(output_cls, prompt_tmpl, **kwargs): + call_count["count"] += 1 + if call_count["count"] == 1: + # Missing title and summary + raise ValidationError.from_exception_data( + title="TestResponse", + line_errors=[ + { + "type": "missing", + "loc": ("title",), + "msg": "Field required", + "input": {}, + }, + { + "type": "missing", + "loc": ("summary",), + "msg": "Field required", + "input": {}, + }, + ], + ) + else: + # Should have schema validation errors in reflection + assert "reflection" in kwargs + assert ( + "Schema validation errors" in kwargs["reflection"] + or "error" in kwargs["reflection"].lower() + ) + return TestResponse(title="Test", summary="Summary", confidence=0.95) + + with patch("reflector.llm.Settings") as mock_settings: + mock_settings.llm.astructured_predict = AsyncMock( + side_effect=astructured_predict_handler + ) result = await llm.get_structured_response( prompt="Test prompt", texts=["Test text"], output_cls=TestResponse @@ -185,17 +188,10 @@ async def test_success_on_first_attempt(self, test_settings): """Test that no retry happens when first attempt succeeds""" llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100) - with ( - patch("reflector.llm.TreeSummarize") as mock_summarize, - patch("reflector.llm.Settings") as mock_settings, - ): - mock_summarizer = MagicMock() - mock_summarize.return_value = mock_summarizer - mock_summarizer.aget_response = AsyncMock(return_value="Some analysis") - - mock_settings.llm.acomplete = AsyncMock( - return_value=make_completion_response( - '{"title": "Test", "summary": "Summary", "confidence": 0.95}' + with patch("reflector.llm.Settings") as mock_settings: + mock_settings.llm.astructured_predict = AsyncMock( + return_value=TestResponse( + title="Test", summary="Summary", confidence=0.95 ) ) @@ -206,274 +202,175 @@ async def test_success_on_first_attempt(self, test_settings): assert result.title == "Test" assert result.summary == "Summary" assert result.confidence == 0.95 - assert mock_summarizer.aget_response.call_count == 1 - assert mock_settings.llm.acomplete.call_count == 1 + assert mock_settings.llm.astructured_predict.call_count == 1 -class TestStructuredOutputWorkflow: - """Direct tests for the StructuredOutputWorkflow""" +class TestNetworkErrorRetries: + """Test that network errors are retried by the outer retry() wrapper""" @pytest.mark.asyncio - async def test_workflow_retries_on_validation_error(self): - """Test workflow retries when validation fails""" - workflow = StructuredOutputWorkflow( - output_cls=TestResponse, - max_retries=3, - timeout=30, - ) - - with ( - patch("reflector.llm.TreeSummarize") as mock_summarize, - patch("reflector.llm.Settings") as mock_settings, - ): - mock_summarizer = MagicMock() - mock_summarize.return_value = mock_summarizer - mock_summarizer.aget_response = AsyncMock(return_value="Some analysis") - - call_count = {"count": 0} - - async def acomplete_handler(*args, **kwargs): - call_count["count"] += 1 - if call_count["count"] < 2: - return make_completion_response('{"title": "Only title"}') - return make_completion_response( - '{"title": "Test", "summary": "Summary", "confidence": 0.9}' - ) - - mock_settings.llm.acomplete = AsyncMock(side_effect=acomplete_handler) - - result = await workflow.run( - prompt="Extract data", - texts=["Some text"], - tone_name=None, - ) - - assert "success" in result - assert result["success"].title == "Test" - assert call_count["count"] == 2 + async def test_network_error_retried_by_outer_wrapper(self, test_settings): + """Test that network errors trigger the outer retry wrapper""" + llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100) - @pytest.mark.asyncio - async def test_workflow_returns_error_after_max_retries(self): - """Test workflow returns error after exhausting retries""" - workflow = StructuredOutputWorkflow( - output_cls=TestResponse, - max_retries=2, - timeout=30, - ) + call_count = {"count": 0} - with ( - patch("reflector.llm.TreeSummarize") as mock_summarize, - patch("reflector.llm.Settings") as mock_settings, - ): - mock_summarizer = MagicMock() - mock_summarize.return_value = mock_summarizer - mock_summarizer.aget_response = AsyncMock(return_value="Some analysis") + async def astructured_predict_handler(output_cls, prompt_tmpl, **kwargs): + call_count["count"] += 1 + if call_count["count"] == 1: + raise ConnectionError("Connection refused") + return TestResponse(title="Test", summary="Summary", confidence=0.95) - # Always return invalid JSON - mock_settings.llm.acomplete = AsyncMock( - return_value=make_completion_response('{"invalid": true}') + with patch("reflector.llm.Settings") as mock_settings: + mock_settings.llm.astructured_predict = AsyncMock( + side_effect=astructured_predict_handler ) - result = await workflow.run( - prompt="Extract data", - texts=["Some text"], - tone_name=None, + result = await llm.get_structured_response( + prompt="Test prompt", texts=["Test text"], output_cls=TestResponse ) - assert "error" in result - # TreeSummarize called once, acomplete called max_retries times - assert mock_summarizer.aget_response.call_count == 1 - assert mock_settings.llm.acomplete.call_count == 2 - - -class TestNetworkErrorRetries: - """Test that network error retries are handled by OpenAILike, not Workflow""" + assert result.title == "Test" + assert call_count["count"] == 2 @pytest.mark.asyncio - async def test_network_error_propagates_after_openai_retries(self, test_settings): - """Test that network errors are retried by OpenAILike and then propagate. - - Network retries are handled by OpenAILike (max_retries=3), not by our - StructuredOutputWorkflow. This test verifies that network errors propagate - up after OpenAILike exhausts its retries. - """ + async def test_network_error_exhausts_retries(self, test_settings): + """Test that persistent network errors exhaust retry attempts""" llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100) - with ( - patch("reflector.llm.TreeSummarize") as mock_summarize, - patch("reflector.llm.Settings") as mock_settings, - ): - mock_summarizer = MagicMock() - mock_summarize.return_value = mock_summarizer - mock_summarizer.aget_response = AsyncMock(return_value="Some analysis") - - # Simulate network error from acomplete (after OpenAILike retries exhausted) - network_error = ConnectionError("Connection refused") - mock_settings.llm.acomplete = AsyncMock(side_effect=network_error) + with patch("reflector.llm.Settings") as mock_settings: + mock_settings.llm.astructured_predict = AsyncMock( + side_effect=ConnectionError("Connection refused") + ) - # Network error wrapped in WorkflowRuntimeError - with pytest.raises(WorkflowRuntimeError, match="Connection refused"): + with pytest.raises(RetryException, match="Retry attempts exceeded"): await llm.get_structured_response( prompt="Test prompt", texts=["Test text"], output_cls=TestResponse ) - # acomplete called only once - network error propagates, not retried by Workflow - assert mock_settings.llm.acomplete.call_count == 1 + # 3 retry attempts + assert mock_settings.llm.astructured_predict.call_count == 3 - @pytest.mark.asyncio - async def test_network_error_not_retried_by_workflow(self, test_settings): - """Test that Workflow does NOT retry network errors (OpenAILike handles those). - - This verifies the separation of concerns: - - StructuredOutputWorkflow: retries parse/validation errors - - OpenAILike: retries network errors (internally, max_retries=3) - """ - workflow = StructuredOutputWorkflow( - output_cls=TestResponse, - max_retries=3, - timeout=30, - ) - with ( - patch("reflector.llm.TreeSummarize") as mock_summarize, - patch("reflector.llm.Settings") as mock_settings, - ): - mock_summarizer = MagicMock() - mock_summarize.return_value = mock_summarizer - mock_summarizer.aget_response = AsyncMock(return_value="Some analysis") +class TestTextsInclusion: + """Test that texts parameter is included in the prompt sent to astructured_predict""" - # Network error should propagate immediately, not trigger Workflow retry - mock_settings.llm.acomplete = AsyncMock( - side_effect=TimeoutError("Request timed out") - ) + @pytest.mark.asyncio + async def test_texts_included_in_prompt(self, test_settings): + """Test that texts content is appended to the prompt for astructured_predict""" + llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100) - # Network error wrapped in WorkflowRuntimeError - with pytest.raises(WorkflowRuntimeError, match="Request timed out"): - await workflow.run( - prompt="Extract data", - texts=["Some text"], - tone_name=None, - ) + captured_prompts = [] - # Only called once - Workflow doesn't retry network errors - assert mock_settings.llm.acomplete.call_count == 1 + async def capture_prompt(output_cls, prompt_tmpl, **kwargs): + captured_prompts.append(kwargs.get("user_prompt", "")) + return TestResponse(title="Test", summary="Summary", confidence=0.95) + with patch("reflector.llm.Settings") as mock_settings: + mock_settings.llm.astructured_predict = AsyncMock( + side_effect=capture_prompt + ) + + await llm.get_structured_response( + prompt="Identify all participants", + texts=["Alice: Hello everyone", "Bob: Hi Alice"], + output_cls=TestResponse, + ) -class TestWorkflowTimeoutRetry: - """Test timeout retry mechanism in get_structured_response""" + assert len(captured_prompts) == 1 + prompt_sent = captured_prompts[0] + assert "Identify all participants" in prompt_sent + assert "Alice: Hello everyone" in prompt_sent + assert "Bob: Hi Alice" in prompt_sent @pytest.mark.asyncio - async def test_timeout_retry_succeeds_on_retry(self, test_settings): - """Test that WorkflowTimeoutError triggers retry and succeeds""" + async def test_empty_texts_uses_prompt_only(self, test_settings): + """Test that empty texts list sends only the prompt""" llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100) - call_count = {"count": 0} + captured_prompts = [] - async def workflow_run_side_effect(*args, **kwargs): - call_count["count"] += 1 - if call_count["count"] == 1: - raise WorkflowTimeoutError("Operation timed out after 120 seconds") - return { - "success": TestResponse( - title="Test", summary="Summary", confidence=0.95 - ) - } + async def capture_prompt(output_cls, prompt_tmpl, **kwargs): + captured_prompts.append(kwargs.get("user_prompt", "")) + return TestResponse(title="Test", summary="Summary", confidence=0.95) - with ( - patch("reflector.llm.StructuredOutputWorkflow") as mock_workflow_class, - patch("reflector.llm.TreeSummarize") as mock_summarize, - patch("reflector.llm.Settings") as mock_settings, - ): - mock_workflow = MagicMock() - mock_workflow.run = AsyncMock(side_effect=workflow_run_side_effect) - mock_workflow_class.return_value = mock_workflow - - mock_summarizer = MagicMock() - mock_summarize.return_value = mock_summarizer - mock_summarizer.aget_response = AsyncMock(return_value="Some analysis") - mock_settings.llm.acomplete = AsyncMock( - return_value=make_completion_response( - '{"title": "Test", "summary": "Summary", "confidence": 0.95}' - ) + with patch("reflector.llm.Settings") as mock_settings: + mock_settings.llm.astructured_predict = AsyncMock( + side_effect=capture_prompt ) - result = await llm.get_structured_response( - prompt="Test prompt", texts=["Test text"], output_cls=TestResponse + await llm.get_structured_response( + prompt="Identify all participants", + texts=[], + output_cls=TestResponse, ) - assert result.title == "Test" - assert result.summary == "Summary" - assert call_count["count"] == 2 + assert len(captured_prompts) == 1 + assert captured_prompts[0] == "Identify all participants" @pytest.mark.asyncio - async def test_timeout_retry_exhausts_after_max_attempts(self, test_settings): - """Test that timeout retry stops after max attempts""" + async def test_texts_included_in_reflection_retry(self, test_settings): + """Test that texts are included in the prompt even during reflection retries""" llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100) + captured_prompts = [] call_count = {"count": 0} - async def workflow_run_side_effect(*args, **kwargs): + async def capture_and_fail_first(output_cls, prompt_tmpl, **kwargs): call_count["count"] += 1 - raise WorkflowTimeoutError("Operation timed out after 120 seconds") - - with ( - patch("reflector.llm.StructuredOutputWorkflow") as mock_workflow_class, - patch("reflector.llm.TreeSummarize") as mock_summarize, - patch("reflector.llm.Settings") as mock_settings, - ): - mock_workflow = MagicMock() - mock_workflow.run = AsyncMock(side_effect=workflow_run_side_effect) - mock_workflow_class.return_value = mock_workflow - - mock_summarizer = MagicMock() - mock_summarize.return_value = mock_summarizer - mock_summarizer.aget_response = AsyncMock(return_value="Some analysis") - mock_settings.llm.acomplete = AsyncMock( - return_value=make_completion_response( - '{"title": "Test", "summary": "Summary", "confidence": 0.95}' + captured_prompts.append(kwargs.get("user_prompt", "")) + if call_count["count"] == 1: + raise ValidationError.from_exception_data( + title="TestResponse", + line_errors=[ + { + "type": "missing", + "loc": ("summary",), + "msg": "Field required", + "input": {}, + } + ], ) + return TestResponse(title="Test", summary="Summary", confidence=0.95) + + with patch("reflector.llm.Settings") as mock_settings: + mock_settings.llm.astructured_predict = AsyncMock( + side_effect=capture_and_fail_first ) - with pytest.raises(RetryException, match="Retry attempts exceeded"): - await llm.get_structured_response( - prompt="Test prompt", texts=["Test text"], output_cls=TestResponse - ) + await llm.get_structured_response( + prompt="Summarize this", + texts=["The meeting covered project updates"], + output_cls=TestResponse, + ) - assert call_count["count"] == 3 + # Both first attempt and reflection retry should include the texts + assert len(captured_prompts) == 2 + for prompt_sent in captured_prompts: + assert "Summarize this" in prompt_sent + assert "The meeting covered project updates" in prompt_sent + + +class TestReflectionRetryBackoff: + """Test the reflection retry timing behavior""" @pytest.mark.asyncio - async def test_timeout_retry_with_backoff(self, test_settings): - """Test that exponential backoff is applied between retries""" + async def test_value_error_triggers_reflection(self, test_settings): + """Test that ValueError (parse failure) also triggers reflection retry""" llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100) - call_times = [] + call_count = {"count": 0} - async def workflow_run_side_effect(*args, **kwargs): - call_times.append(monotonic()) - if len(call_times) < 3: - raise WorkflowTimeoutError("Operation timed out after 120 seconds") - return { - "success": TestResponse( - title="Test", summary="Summary", confidence=0.95 - ) - } + async def astructured_predict_handler(output_cls, prompt_tmpl, **kwargs): + call_count["count"] += 1 + if call_count["count"] == 1: + raise ValueError("Could not parse output") + assert "reflection" in kwargs + return TestResponse(title="Test", summary="Summary", confidence=0.95) - with ( - patch("reflector.llm.StructuredOutputWorkflow") as mock_workflow_class, - patch("reflector.llm.TreeSummarize") as mock_summarize, - patch("reflector.llm.Settings") as mock_settings, - ): - mock_workflow = MagicMock() - mock_workflow.run = AsyncMock(side_effect=workflow_run_side_effect) - mock_workflow_class.return_value = mock_workflow - - mock_summarizer = MagicMock() - mock_summarize.return_value = mock_summarizer - mock_summarizer.aget_response = AsyncMock(return_value="Some analysis") - mock_settings.llm.acomplete = AsyncMock( - return_value=make_completion_response( - '{"title": "Test", "summary": "Summary", "confidence": 0.95}' - ) + with patch("reflector.llm.Settings") as mock_settings: + mock_settings.llm.astructured_predict = AsyncMock( + side_effect=astructured_predict_handler ) result = await llm.get_structured_response( @@ -481,8 +378,20 @@ async def workflow_run_side_effect(*args, **kwargs): ) assert result.title == "Test" - if len(call_times) >= 2: - time_between_calls = call_times[1] - call_times[0] - assert ( - time_between_calls >= 1.5 - ), f"Expected ~2s backoff, got {time_between_calls}s" + assert call_count["count"] == 2 + + @pytest.mark.asyncio + async def test_format_validation_error_method(self, test_settings): + """Test _format_validation_error produces correct feedback""" + # ValidationError + try: + TestResponse(title="x", summary="y", confidence=5.0) # confidence > 1 + except ValidationError as e: + result = LLM._format_validation_error(e) + assert "Schema validation errors" in result + assert "confidence" in result + + # ValueError + result = LLM._format_validation_error(ValueError("bad input")) + assert "Parse error:" in result + assert "bad input" in result