Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions .github/workflows/test_server.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ jobs:
uv run -m pytest -v tests

docker-amd64:
runs-on: linux-amd64
runs-on: [linux-amd64]
concurrency:
group: docker-amd64-${{ github.ref }}
cancel-in-progress: true
Expand All @@ -52,12 +52,14 @@ jobs:
github-token: ${{ secrets.GHA_CACHE_TOKEN }}

docker-arm64:
runs-on: linux-arm64
runs-on: [linux-arm64]
concurrency:
group: docker-arm64-${{ github.ref }}
cancel-in-progress: true
steps:
- uses: actions/checkout@v4
- name: Wait for Docker daemon
run: while ! docker version; do sleep 1; done
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Build ARM64
Expand Down
281 changes: 83 additions & 198 deletions server/reflector/llm.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,23 @@
import logging
from contextvars import ContextVar
from typing import Generic, Type, TypeVar
from typing import Type, TypeVar
from uuid import uuid4

from llama_index.core import Settings
from llama_index.core.output_parsers import PydanticOutputParser
from llama_index.core.prompts import PromptTemplate
from llama_index.core.response_synthesizers import TreeSummarize
from llama_index.core.workflow import (
Context,
Event,
StartEvent,
StopEvent,
Workflow,
step,
)
from llama_index.llms.openai_like import OpenAILike
from pydantic import BaseModel, ValidationError
from workflows.errors import WorkflowTimeoutError

from reflector.utils.retry import retry

T = TypeVar("T", bound=BaseModel)
OutputT = TypeVar("OutputT", bound=BaseModel)

# Session ID for LiteLLM request grouping - set per processing run
llm_session_id: ContextVar[str | None] = ContextVar("llm_session_id", default=None)

logger = logging.getLogger(__name__)

STRUCTURED_RESPONSE_PROMPT_TEMPLATE = """
Based on the following analysis, provide the information in the requested JSON format:

Analysis:
{analysis}

{format_instructions}
"""


class LLMParseError(Exception):
"""Raised when LLM output cannot be parsed after retries."""
Expand All @@ -50,157 +31,6 @@ def __init__(self, output_cls: Type[BaseModel], error_msg: str, attempts: int):
)


class ExtractionDone(Event):
    """Event emitted when LLM JSON formatting completes."""

    # Raw completion text returned by the LLM; expected (but not guaranteed)
    # to be JSON matching the workflow's output schema — validated downstream.
    output: str


class ValidationErrorEvent(Event):
    """Event emitted when validation fails."""

    # Human/LLM-readable description of why parsing failed.
    error: str
    # The raw LLM output that failed to parse, fed back for reflection.
    wrong_output: str


class StructuredOutputWorkflow(Workflow, Generic[OutputT]):
    """Workflow for structured output extraction with validation retry.

    This workflow handles parse/validation retries only. Network error retries
    are handled internally by Settings.llm (OpenAILike max_retries=3).
    The caller should NOT wrap this workflow in additional retry logic.

    Flow: StartEvent -> extract -> ExtractionDone -> validate -> StopEvent,
    with validate emitting ValidationErrorEvent back to extract on failure.
    """

    def __init__(
        self,
        output_cls: Type[OutputT],
        max_retries: int = 3,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Pydantic model class the LLM output must conform to.
        self.output_cls: Type[OutputT] = output_cls
        # Maximum parse/validation attempts before giving up with an error result.
        self.max_retries = max_retries
        # Parser used both to build format instructions and to validate output.
        self.output_parser = PydanticOutputParser(output_cls)

    @step
    async def extract(
        self, ctx: Context, ev: StartEvent | ValidationErrorEvent
    ) -> StopEvent | ExtractionDone:
        """Extract structured data from text using two-step LLM process.

        Step 1 (first call only): TreeSummarize generates text analysis
        Step 2 (every call): Settings.llm.acomplete formats analysis as JSON
        """
        # Attempt counter lives in the workflow context store so it survives
        # the extract -> validate -> extract retry loop.
        current_retries = await ctx.store.get("retries", default=0)
        await ctx.store.set("retries", current_retries + 1)

        if current_retries >= self.max_retries:
            # Give up: surface the last recorded error to the caller rather
            # than raising, so the workflow terminates with a result dict.
            last_error = await ctx.store.get("last_error", default=None)
            logger.error(
                f"Max retries ({self.max_retries}) reached for {self.output_cls.__name__}"
            )
            return StopEvent(result={"error": last_error, "attempts": current_retries})

        if isinstance(ev, StartEvent):
            # First call: run TreeSummarize to get analysis, store in context
            prompt = ev.get("prompt")
            texts = ev.get("texts")
            tone_name = ev.get("tone_name")
            if not prompt or not isinstance(texts, list):
                raise ValueError(
                    "StartEvent must contain 'prompt' (str) and 'texts' (list)"
                )

            summarizer = TreeSummarize(verbose=False)
            analysis = await summarizer.aget_response(
                prompt, texts, tone_name=tone_name
            )
            # Cache the analysis so retries skip the expensive summarize step.
            await ctx.store.set("analysis", str(analysis))
            reflection = ""
        else:
            # Retry: reuse analysis from context
            analysis = await ctx.store.get("analysis")
            if not analysis:
                raise RuntimeError("Internal error: analysis not found in context")

            # Truncate the failed output so the reflection prompt stays bounded.
            wrong_output = ev.wrong_output
            if len(wrong_output) > 2000:
                wrong_output = wrong_output[:2000] + "... [truncated]"
            reflection = (
                f"\n\nYour previous response could not be parsed:\n{wrong_output}\n\n"
                f"Error:\n{ev.error}\n\n"
                "Please try again. Return ONLY valid JSON matching the schema above, "
                "with no markdown formatting or extra text."
            )

        # Step 2: Format analysis as JSON using LLM completion
        format_instructions = self.output_parser.format(
            "Please structure the above information in the following JSON format:"
        )

        json_prompt = STRUCTURED_RESPONSE_PROMPT_TEMPLATE.format(
            analysis=analysis,
            format_instructions=format_instructions + reflection,
        )

        # Network retries handled by OpenAILike (max_retries=3)
        # response_format enables grammar-based constrained decoding on backends
        # that support it (DMR/llama.cpp, vLLM, Ollama, OpenAI).
        response = await Settings.llm.acomplete(
            json_prompt,
            response_format={
                "type": "json_schema",
                "json_schema": {
                    "name": self.output_cls.__name__,
                    "schema": self.output_cls.model_json_schema(),
                },
            },
        )
        return ExtractionDone(output=response.text)

    @step
    async def validate(
        self, ctx: Context, ev: ExtractionDone
    ) -> StopEvent | ValidationErrorEvent:
        """Validate extracted output against Pydantic schema."""
        raw_output = ev.output
        # Read-only here: extract() already incremented the counter, so this
        # value equals the current attempt number.
        retries = await ctx.store.get("retries", default=0)

        try:
            parsed = self.output_parser.parse(raw_output)
            if retries > 1:
                logger.info(
                    f"LLM parse succeeded on attempt {retries}/{self.max_retries} "
                    f"for {self.output_cls.__name__}"
                )
            return StopEvent(result={"success": parsed})

        except (ValidationError, ValueError) as e:
            # Record the formatted error so the final give-up StopEvent in
            # extract() can report it.
            error_msg = self._format_error(e, raw_output)
            await ctx.store.set("last_error", error_msg)

            logger.error(
                f"LLM parse error (attempt {retries}/{self.max_retries}): "
                f"{type(e).__name__}: {e}\nRaw response: {raw_output[:500]}"
            )

            # Routed back to extract() for a reflection retry.
            return ValidationErrorEvent(
                error=error_msg,
                wrong_output=raw_output,
            )

    def _format_error(self, error: Exception, raw_output: str) -> str:
        """Format error for LLM feedback.

        ValidationError is expanded into one line per failing field; any
        other error is reported verbatim. `raw_output` is currently unused
        here (the caller attaches it to the event separately).
        """
        if isinstance(error, ValidationError):
            error_messages = []
            for err in error.errors():
                field = ".".join(str(loc) for loc in err["loc"])
                error_messages.append(f"- {err['msg']} in field '{field}'")
            return "Schema validation errors:\n" + "\n".join(error_messages)
        else:
            return f"Parse error: {str(error)}"


class LLM:
def __init__(
self, settings, temperature: float = 0.4, max_tokens: int | None = None
Expand All @@ -225,7 +55,7 @@ def _configure_llamaindex(self):
api_key=self.api_key,
context_window=self.context_window,
is_chat_model=True,
is_function_calling_model=False,
is_function_calling_model=True,
temperature=self.temperature,
max_tokens=self.max_tokens,
timeout=self.settings_obj.LLM_REQUEST_TIMEOUT,
Expand All @@ -248,36 +78,91 @@ async def get_structured_response(
tone_name: str | None = None,
timeout: int | None = None,
) -> T:
"""Get structured output from LLM with validation retry via Workflow."""
if timeout is None:
timeout = self.settings_obj.LLM_STRUCTURED_RESPONSE_TIMEOUT
"""Get structured output from LLM using tool-call with reflection retry.

async def run_workflow():
workflow = StructuredOutputWorkflow(
output_cls=output_cls,
max_retries=self.settings_obj.LLM_PARSE_MAX_RETRIES + 1,
timeout=timeout,
)
Uses astructured_predict (function-calling / tool-call mode) for the
first attempt. On ValidationError or parse failure the wrong output
and error are fed back as a reflection prompt and the call is retried
up to LLM_PARSE_MAX_RETRIES times.

result = await workflow.run(
prompt=prompt,
texts=texts,
tone_name=tone_name,
The outer retry() wrapper handles transient network errors with
exponential back-off.
"""
max_retries = self.settings_obj.LLM_PARSE_MAX_RETRIES

async def _call_with_reflection():
    """Call astructured_predict, retrying with a reflection prompt on parse failure.

    Closes over `prompt`, `texts`, `output_cls`, and `max_retries` from the
    enclosing get_structured_response scope. Raises LLMParseError when all
    attempts are exhausted.
    """
    # Build full prompt: instruction + source texts
    if texts:
        texts_block = "\n\n".join(texts)
        full_prompt = f"{prompt}\n\n{texts_block}"
    else:
        full_prompt = prompt

    prompt_tmpl = PromptTemplate("{user_prompt}")
    last_error: str | None = None

    for attempt in range(1, max_retries + 2):  # +2: first try + retries
        try:
            if attempt == 1:
                result = await Settings.llm.astructured_predict(
                    output_cls, prompt_tmpl, user_prompt=full_prompt
                )
            else:
                # `reflection` is always bound here: attempt > 1 implies the
                # previous iteration went through the except branch below.
                reflection_tmpl = PromptTemplate(
                    "{user_prompt}\n\n{reflection}"
                )
                result = await Settings.llm.astructured_predict(
                    output_cls,
                    reflection_tmpl,
                    user_prompt=full_prompt,
                    reflection=reflection,
                )

            if attempt > 1:
                logger.info(
                    f"LLM structured_predict succeeded on attempt "
                    f"{attempt}/{max_retries + 1} for {output_cls.__name__}"
                )
            return result

        except (ValidationError, ValueError) as e:
            # NOTE(review): the truncated wrong_output is only logged below
            # (at 500 chars); unlike the previous workflow implementation it
            # is NOT included in the reflection prompt — confirm intentional.
            wrong_output = str(e)
            if len(wrong_output) > 2000:
                wrong_output = wrong_output[:2000] + "... [truncated]"

            last_error = self._format_validation_error(e)
            reflection = (
                f"Your previous response could not be parsed.\n\n"
                f"Error:\n{last_error}\n\n"
                "Please try again and return valid data matching the schema."
            )

            logger.error(
                f"LLM parse error (attempt {attempt}/{max_retries + 1}): "
                f"{type(e).__name__}: {e}\n"
                f"Raw response: {wrong_output[:500]}"
            )

    # All attempts failed with parse/validation errors.
    raise LLMParseError(
        output_cls=output_cls,
        error_msg=last_error or "Max retries exceeded",
        attempts=max_retries + 1,
    )

if "error" in result:
error_msg = result["error"] or "Max retries exceeded"
raise LLMParseError(
output_cls=output_cls,
error_msg=error_msg,
attempts=result.get("attempts", 0),
)

return result["success"]

return await retry(run_workflow)(
return await retry(_call_with_reflection)(
retry_attempts=3,
retry_backoff_interval=1.0,
retry_backoff_max=30.0,
retry_ignore_exc_types=(WorkflowTimeoutError,),
retry_ignore_exc_types=(ConnectionError, TimeoutError, OSError),
)

@staticmethod
def _format_validation_error(error: Exception) -> str:
    """Format a validation/parse error for LLM reflection feedback.

    Pydantic ValidationError is expanded into a bullet list with one line
    per failing field; any other exception is reported verbatim.
    """
    if not isinstance(error, ValidationError):
        return f"Parse error: {str(error)}"
    bullets = [
        f"- {item['msg']} in field '{'.'.join(str(part) for part in item['loc'])}'"
        for item in error.errors()
    ]
    return "Schema validation errors:\n" + "\n".join(bullets)
2 changes: 2 additions & 0 deletions server/reflector/processors/transcript_topic_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@ class TopicResponse(BaseModel):
title: str = Field(
description="A descriptive title for the topic being discussed",
validation_alias=AliasChoices("title", "Title"),
min_length=8,
)
summary: str = Field(
description="A concise 1-2 sentence summary of the discussion",
validation_alias=AliasChoices("summary", "Summary"),
min_length=8,
)


Expand Down
Loading
Loading